#!/usr/bin/python3

# Copyright (C) 2014-2020, IONOS SE
# Author: Benjamin Drung <benjamin.drung@cloud.ionos.com>
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# pylint: disable=invalid-name
# pylint: enable=invalid-name

import argparse
import configparser
import datetime
import errno
import glob
import hashlib
import logging
import os
import re
import shutil
import stat
import subprocess
import sys
import tempfile
import urllib.parse

import httplib2
import parted

# Size passed to create_raw_image() when no "image-size" is configured.
DEFAULT_IMAGE_SIZE = "2G"
# Log record layout used for the console output and for the optional log file.
DEFAULT_LOGGING_FORMAT = "%(asctime)s %(name)s [%(process)d] %(levelname)s: %(message)s"
# Host part of the URLs built by create_url().
# NOTE(review): presumably the address under which the guest reaches the HTTP server - confirm.
GUEST_HTTP_SERVER = "10.0.2.4"
# Unique sentinel object. NOTE(review): appears to be unused in this file.
_UNSET = object()
# Absolute paths of the external tools that post_installation() executes.
LOSETUP = "/sbin/losetup"
ZEROFREE = "/usr/sbin/zerofree"
# Name of the logger used throughout this file: the basename of the script when it is
# executed directly, the module name when it is imported.
__logger_name__ = os.path.basename(sys.argv[0]) if __name__ == "__main__" else __name__


def get_config():
    """Build the configuration parser for image-factory.

    Unless overridden, two files are read: the system-wide
    /etc/image-factory.conf first and the per-user
    ~/.config/image-factory.conf second, so that settings from the user
    file take precedence. Files that do not exist are skipped.

    Setting the IMAGE_FACTORY_CONFIG environment variable replaces both
    files by the single file named there. An exception is raised if that
    path is not an existing file.
    """
    env_name = "IMAGE_FACTORY_CONFIG"
    override = os.environ.get(env_name)
    if override is None:
        candidates = [
            "/etc/image-factory.conf",
            os.path.expanduser("~/.config/image-factory.conf"),
        ]
    elif os.path.isfile(override):
        candidates = [override]
    else:
        raise Exception("'{}' (set in {}) is not a valid file.".format(override, env_name))

    # ExtendedInterpolation enables ${option} / ${section:option} references.
    parser = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())
    parser.read(candidates)
    return parser


def call_command(command, as_root=False):
    """Execute an external command and terminate the program if it fails.

    command is the argument list to execute. If as_root is true and the
    current user is not root, the command is run through
    "sudo image-factory-sudo-helper". The command line is logged before
    it is executed. On a non-zero exit status the failure is logged and
    the program exits with that status.
    """
    if as_root and os.getuid() != 0:
        command = ["sudo", "image-factory-sudo-helper"] + command

    # Only for the log message: wrap arguments containing spaces in double quotes.
    printable = " ".join(
        '"' + argument.replace('"', r"\"") + '"' if " " in argument else argument
        for argument in command
    )
    logger = logging.getLogger(__logger_name__)
    logger.info("Calling %s", printable)

    return_code = subprocess.call(command)
    if return_code == 0:
        return
    logger.error("'%s' failed with exit code %i.", " ".join(command), return_code)
    sys.exit(return_code)


def create_raw_image(filename, size):
    """Create the raw disk image filename with the given size using qemu-img.

    size is handed to "qemu-img create" unchanged (for example "2G").
    The program exits if qemu-img fails.
    """
    call_command(["qemu-img", "create", "-f", "raw", filename, size])


def create_url(image):
    """Return the http:// URL of image on GUEST_HTTP_SERVER."""
    return "".join(("http://", GUEST_HTTP_SERVER, "/", image))


def get_default_cache_dir():
    """Return the cache directory to use when none is configured.

    root gets a system-wide directory, every other user a directory
    below the home directory (returned with an unexpanded "~").
    """
    return "/var/cache/image-factory" if os.getuid() == 0 else "~/.cache/image-factory"


def cache_file(cache_dir, source):
    """Cache a file locally and return the relative location of the cached file.

    cache_dir is created if it does not exist. source may be an absolute
    local path, a file: URL, an http:/https: URL, or an rsync: URL. The
    file is stored in cache_dir under the basename of source, and that
    basename is returned.

    Raises an Exception if no handler exists for source or if an HTTP
    download is answered with an error status.
    """
    logger = logging.getLogger(__logger_name__)
    if not os.path.exists(cache_dir):
        logger.info("Creating directory %s", cache_dir)
        os.makedirs(cache_dir, exist_ok=True)
    relative_destination = os.path.basename(source)
    destination = os.path.join(cache_dir, relative_destination)
    if source.startswith(("file:", "/")):
        local_path = source
        if source.startswith("file:"):
            # shutil needs a filesystem path, not the URL itself.
            local_path = urllib.parse.unquote(urllib.parse.urlparse(source).path)
        logger.info("Copying '%s' to cache '%s'...", local_path, destination)
        shutil.copy(local_path, destination)
    elif source.startswith(("http:", "https:")):
        logger.info("Downloading %s...", source)
        http_client = httplib2.Http(cache_dir)
        http_client.ignore_etag = True
        (response, content) = http_client.request(source)
        if response.status >= 400:
            # Do not store an error page under the name of the requested file.
            raise Exception(
                "Downloading '{source}' failed with HTTP status {status}.".format(
                    source=source, status=response.status
                )
            )
        if response.fromcache:
            logger.info("Copy cached download to %s...", destination)
        else:
            logger.info("Save download to %s...", destination)
        with open(destination, "wb") as cached_file:
            cached_file.write(content)
    elif source.startswith("rsync:"):
        call_command(["rsync", "--no-motd", source, destination])
    else:
        raise Exception("No download handler for file '{source}' found.".format(source=source))
    return relative_destination


def download_and_publish(config, image, source, filename):
    """Fetch source and store it below the HTTP directory of image.

    The destination is <[http] path>/<image>/<filename>; missing
    directories are created. source may be an absolute local path, a
    file: URL, an http:/https: URL, or an rsync: URL. HTTP downloads use
    <cache_dir of image>/<image> as httplib2 cache directory.

    Returns the path of the destination file. Raises an Exception if no
    handler exists for source or if an HTTP download is answered with an
    error status.
    """
    destination = os.path.join(config.get("http", "path"), image, filename)
    target_dir = os.path.dirname(destination)
    logger = logging.getLogger(__logger_name__)
    if not os.path.exists(target_dir):
        logger.info("Creating directory %s", target_dir)
        os.makedirs(target_dir, exist_ok=True)
    if source.startswith(("file:", "/")):
        local_path = source
        if source.startswith("file:"):
            # shutil needs a filesystem path, not the URL itself.
            local_path = urllib.parse.unquote(urllib.parse.urlparse(source).path)
        logger.info("Copying %s to %s", local_path, destination)
        shutil.copyfile(local_path, destination)
    elif source.startswith(("http:", "https:")):
        logger.info("Downloading %s...", source)
        http_client = httplib2.Http(os.path.join(config[image]["cache_dir"], image))
        http_client.ignore_etag = True
        (response, content) = http_client.request(source)
        if response.status >= 400:
            # Do not publish an error page under the name of the requested file.
            raise Exception(
                "Downloading '{source}' failed with HTTP status {status}.".format(
                    source=source, status=response.status
                )
            )
        if response.fromcache:
            logger.info("Copy cached download to %s...", destination)
        else:
            logger.info("Save download to %s...", destination)
        with open(destination, "wb") as cached_file:
            cached_file.write(content)
    elif source.startswith("rsync:"):
        call_command(["rsync", "--no-motd", source, destination])
    else:
        raise Exception("No download handler for file '{source}' found.".format(source=source))
    return destination


def check_one_partition(partitions, image):
    """Check that the image has exactly one partition.

    The image-handler (that adds the root password and SSH keys to instantiate
    the template) requires that there is only one partition in the image.

    partitions is the list of partitions found in image. The entries may be
    strings or partition objects. If the list does not contain exactly one
    entry, an error is logged and the program exits with status 1.
    """
    if len(partitions) == 1:
        return

    if partitions:
        # The entries are not necessarily strings (post_installation() passes
        # disk.partitions), so convert them before joining. Prefer a "path"
        # attribute if the object offers one, since it is shorter than str().
        names = [str(getattr(partition, "path", None) or partition) for partition in partitions]
        msg = "{n} partitions ({partitions})".format(n=len(names), partitions=", ".join(names))
    else:
        msg = "no partitions"
    logger = logging.getLogger(__logger_name__)
    logger.error("Expected exactly one partition in %s, but found %s.", image, msg)
    sys.exit(1)


def parse_bytes(data):
    """Parse bytes from given string.

    The SI prefixes (kB, MB, etc.) and binary prefixes (KiB, MiB, etc.) are supported,
    as well as plain bytes (B). For backward compatibility, the units K, M, and G
    (and likewise T, P, E) are mapped to KiB, MiB, and GiB.

    Returns the number of bytes as int. Raises ValueError if data is not a
    non-negative integer followed by one of the supported units.
    """
    match = re.match(r"^([0-9]+)\s*([kMGTPE]?B|[KMGTPE]iB|[KMGTPE])$", data.strip())
    if not match:
        raise ValueError(
            "Failed to parse bytes from '%s'. Please use SI or binary prefixes "
            "for bytes (e.g. '2 GB' or '512 MiB')." % (data)
        )

    value = int(match.group(1))
    unit = match.group(2)
    if unit == "B":
        # Plain bytes: no prefix to apply.
        return value
    # Position of the prefix letter gives the power: K/k -> 1, M -> 2, ...
    exponent = "KMGTPE".index(unit[0].upper()) + 1
    if unit.endswith("B") and not unit.endswith("iB"):
        # SI prefix (kB, MB, ...): powers of 1000.
        return value * 1000 ** exponent
    # Binary prefix (KiB, MiB, ...) or legacy single letter (K, M, ...): powers of 1024.
    return value << (10 * exponent)


def get_session():
    """Return the libvirt connection URI for virt-install and virsh.

    root uses the system instance, every other user the session instance.
    """
    return "qemu:///system" if os.getuid() == 0 else "qemu:///session"


def run_installation(config, image, image_name):  # pylint: disable=too-many-locals
    """Install the operating system into image_name by running virt-install.

    config is the configuration parser and image the name of the section
    describing the installation. The section must define "ram" and either
    "installer_image" or both "initrd" and "linux"; these files are fetched
    into <cache_dir>/<image> first. Optional settings: "cores", "append",
    "preseed", "kickstart", "yast", "mac", and "vnc" (as "listen:display").

    image_name is used both as the name of the temporary libvirt domain and
    as the path of the raw disk image. The domain is undefined again in all
    cases; on a keyboard interrupt it is destroyed before.
    """
    cores = config.get(image, "cores", fallback="1")
    ram_in_mib = parse_bytes(config.get(image, "ram")) >> 20

    cache_dir = os.path.join(config[image]["cache_dir"], image)
    if config.has_option(image, "installer_image"):
        installer_image = cache_file(cache_dir, config.get(image, "installer_image"))
        location = os.path.join(cache_dir, installer_image)
    else:
        initrd = cache_file(cache_dir, config.get(image, "initrd"))
        kernel = cache_file(cache_dir, config.get(image, "linux"))
        location = "{},kernel={},initrd={}".format(cache_dir, kernel, initrd)

    # File to inject into the initrd. It stays None if no automation file is
    # configured; if several are configured, the last one below wins.
    initrd_inject = None
    append = config.get(image, "append", fallback="")
    if config.has_option(image, "preseed"):
        initrd_inject = config.get(image, "preseed")
        append = "preseed/url=file:///{} {}".format(os.path.basename(initrd_inject), append)
    if config.has_option(image, "kickstart"):
        initrd_inject = config.get(image, "kickstart")
        # Both spellings (ks= and inst.ks=) are put on the kernel command line.
        append = "ks=file:///{} {}".format(os.path.basename(initrd_inject), append)
        append = "inst.ks=file:///{} {}".format(os.path.basename(initrd_inject), append)
    if config.has_option(image, "yast"):
        initrd_inject = config.get(image, "yast")
        append = "autoyast=file:///{} {}".format(os.path.basename(initrd_inject), append)

    network = "user,model=virtio"
    if config.has_option(image, "mac"):
        network += ",mac=" + config.get(image, "mac")

    graphics = "none"
    if config.has_option(image, "vnc"):
        listen, port = config.get(image, "vnc").split(":")
        # The configured number is relative to port 5900.
        graphics = "vnc,port={}".format(5900 + int(port))
        if listen:
            graphics += ",listen={}".format(listen)

    session = get_session()
    cmd = [
        "virt-install",
        "--connect",
        session,
        "--noreboot",
        "--wait",
        "-1",
        "--name",
        image_name,
        "--vcpus",
        cores,
        "--memory",
        str(ram_in_mib),
        "--disk",
        "path=" + image_name + ",bus=virtio,format=raw",
        "--network",
        network,
        "--graphics",
        graphics,
        "--console",
        "pty,target_type=serial",
        "--noautoconsole",
        "--location",
        location,
        "--extra-args",
        append,
    ]
    if initrd_inject is not None:
        # Only pass the option when there is a file to inject.
        cmd += ["--initrd-inject", initrd_inject]

    try:
        call_command(cmd)
    except KeyboardInterrupt:
        call_command(["virsh", "-c", session, "destroy", image_name])
        raise
    finally:
        call_command(["virsh", "-c", session, "undefine", image_name])


def open_as_user(filename):
    """Open filename for reading, granting read access first if necessary.

    If opening fails with "permission denied", the file is made readable
    for others (chmod o+r, executed as root) and opened again. Every other
    error is passed on to the caller. Returns the open file object.
    """
    try:
        return open(filename)
    except IOError as error:
        if error.errno != errno.EACCES:
            raise
        call_command(["chmod", "o+r", filename], as_root=True)
        return open(filename)


def remove(path, recursive=False):
    """Delete a file or directory, temporarily widening permissions if needed.

    If the parent directory of path is not writable, write permission for
    others is added (as root) for the duration of the call and taken away
    again afterwards. Directories are removed with rmdir, or, if recursive
    is true, made accessible for others (chmod -R o+rwx as root) and
    deleted including their content.
    """
    logger = logging.getLogger(__logger_name__)
    parent_dir = os.path.dirname(path)
    grant_write = not os.access(parent_dir, os.W_OK)
    if grant_write:
        # Others must not have write permission yet. Otherwise taking the
        # bit away again at the end would alter the original mode.
        assert not os.stat(parent_dir).st_mode & stat.S_IWOTH
        call_command(["chmod", "o+w", parent_dir], as_root=True)
    try:
        if not os.path.isdir(path):
            logger.info("Removing file %s", path)
            os.remove(path)
        else:
            logger.info("Removing directory %s", path)
            if not recursive:
                os.rmdir(path)
            else:
                call_command(["chmod", "-R", "o+rwx", path], as_root=True)
                shutil.rmtree(path)
    finally:
        if grant_write:
            call_command(["chmod", "o-w", parent_dir], as_root=True)


def remove_logs(tmpdir, print_installer_logs):
    """Remove the installer logs from the file system mounted at tmpdir.

    If print_installer_logs is true, the content of the installer logs is
    written to the logging output before they are removed.

    Missing permissions are added as root where needed: <tmpdir>/root is
    opened for others during the call (and closed again afterwards), and
    unreadable log directories are made readable before they are deleted.
    """
    root_dir = os.path.join(tmpdir, "root")
    missing_permission = not os.access(root_dir, os.W_OK)
    if missing_permission:
        # Assert that others cannot read/write/execute.
        # Then we can correctly remove the permissions again later (without altering the state).
        assert os.stat(root_dir).st_mode & (stat.S_IROTH | stat.S_IWOTH | stat.S_IXOTH) == 0
        call_command(["chmod", "o+rwx", root_dir], as_root=True)

    try:
        # Fix read permission for log directories (that we will remove later)
        check_directories = glob.glob(os.path.join(tmpdir, "var/adm/autoinstall/*")) + glob.glob(
            os.path.join(tmpdir, "var/log/YaST2")
        )
        for directory in check_directories:
            if os.path.isdir(directory) and not os.access(directory, os.R_OK):
                call_command(["chmod", "-R", "o+rx", directory], as_root=True)

        # Print installer logs
        installer_logs = (
            glob.glob(os.path.join(tmpdir, "root/anaconda-ks.cfg"))
            + glob.glob(os.path.join(tmpdir, "var/adm/autoinstall/logs/*"))
            + glob.glob(os.path.join(tmpdir, "var/log/anaconda.*"))
            + glob.glob(os.path.join(tmpdir, "var/log/anaconda/*"))
            + glob.glob(os.path.join(tmpdir, "var/log/installer/syslog"))
            + glob.glob(os.path.join(tmpdir, "var/log/YaST2/y2log"))
        )
        if print_installer_logs:
            logger = logging.getLogger(__logger_name__)
            for installer_log in installer_logs:
                # Close every log again right after reading it.
                with open_as_user(installer_log) as log_file:
                    content = log_file.read()
                logger.info("Content of /%s:\n%s", os.path.relpath(installer_log, tmpdir), content)

        # Remove installer logs
        remove_log_globs = [
            "root/anaconda-ks.cfg",
            "root/install.log*",
            "var/adm/autoinstall",
            "var/lib/YaST2",
            "var/log/anaconda",
            "var/log/installer",
            "var/log/YaST2",
        ]
        for log_glob in remove_log_globs:
            for log in glob.glob(os.path.join(tmpdir, log_glob)):
                remove(log, recursive=True)
    finally:
        if missing_permission:
            call_command(["chmod", "o-rwx", root_dir], as_root=True)


def post_installation(script_name, image, print_installer_logs):
    """Clean up the freshly installed image.

    The image must contain exactly one partition. That partition is
    attached to a loop device and mounted on a temporary directory (whose
    name starts with script_name) to remove the installer logs; the logs
    are printed before if print_installer_logs is true. Afterwards
    zerofree is run on the unmounted loop device. The loop device is
    detached again in all cases.

    The program exits if the partition table cannot be read or if one of
    the external commands fails.
    """
    logger = logging.getLogger(__logger_name__)
    device = parted.getDevice(image)
    try:
        disk = parted.Disk(device)
    except parted.DiskLabelException as error:
        logger.error("Failed to read the disk %s. Maybe the disk is still empty?", error)
        sys.exit(1)
    check_one_partition(disk.partitions, image)
    partition = disk.partitions[0]
    # Byte offset of the partition inside the image file.
    offset = partition.geometry.start * device.sectorSize
    loopdev = None
    try:
        cmd = [LOSETUP, "-o", str(offset), "--show", "-f", image]
        logger.info("Calling %s", " ".join(cmd))
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        # With --show, losetup prints the name of the assigned loop device.
        loopdev = process.communicate()[0].decode().strip()
        if process.returncode != 0 or not loopdev:
            # Without a loop device there is nothing to mount.
            logger.error("'%s' failed with exit code %i.", " ".join(cmd), process.returncode)
            sys.exit(process.returncode or 1)

        tmpdir = tempfile.mkdtemp(prefix=script_name + ".")
        mounted = False
        try:
            call_command(["mount", loopdev, tmpdir], as_root=True)
            mounted = True
            remove_logs(tmpdir, print_installer_logs)
        finally:
            if mounted:
                # A failing umount exits here, so rmtree never runs on a
                # directory that is still mounted.
                call_command(["umount", tmpdir], as_root=True)
                shutil.rmtree(tmpdir)
            else:
                # mount did not report success. rmdir only deletes an empty
                # directory, so it cannot touch the content of the image.
                try:
                    os.rmdir(tmpdir)
                except OSError as error:
                    logger.warning("Failed to remove %s: %s", tmpdir, error)

        call_command([ZEROFREE, loopdev])

    finally:
        if loopdev:
            cmd = [LOSETUP, "-d", loopdev]
            logger.info("Calling %s", " ".join(cmd))
            return_code = subprocess.call(cmd)
            if return_code != 0:
                logger.warning("losetup failed with exit code %i.", return_code)


def create_hashsum(image):
    """Write the SHA-256 checksum of image to the file <image>.sha256sum.

    The checksum file consists of one line: the hexadecimal digest, two
    spaces, and the image name as given. The digest is logged as well.
    """
    logger = logging.getLogger(__logger_name__)
    logger.info("Calculating SHA 256 sum of %s...", image)
    digest = hashlib.sha256()
    with open(image, "rb") as image_file:
        # Hash in 1 MiB pieces; images are too big to be read at once.
        for chunk in iter(lambda: image_file.read(1 << 20), b""):
            digest.update(chunk)
    sha256_sum = digest.hexdigest()
    logger.info("SHA 256 sum of %s: %s", image, sha256_sum)
    with open(image + ".sha256sum", "w") as checksum_file:
        checksum_file.write(sha256_sum + "  " + image + "\n")


def create_qcow2(image, keep_raw_image):
    """Convert the raw image into a qcow2 image and return the new file name.

    The qcow2 file gets the name of image with its extension replaced by
    ".qcow2". Unless keep_raw_image is true, the raw image is deleted after
    the conversion (a raw image that is already gone is ignored). A
    checksum file is created for the qcow2 image.
    """
    base_name, _ = os.path.splitext(image)
    qcow2_name = base_name + ".qcow2"
    call_command(["qemu-img", "convert", "-O", "qcow2", image, qcow2_name])
    if not keep_raw_image:
        logger = logging.getLogger(__logger_name__)
        try:
            logger.info("Removing %s...", image)
            os.remove(image)
        except FileNotFoundError:
            pass
    create_hashsum(qcow2_name)
    return qcow2_name


def upload_image(config, image, image_file, checksum_file):
    """Upload the image and its checksum file to the configured destinations.

    The option 'upload_destinations' of the image section is a
    comma-separated list of section names; an undefined or empty option
    disables the upload. Every listed section has to define 'upload_type'
    (only "rsync" is supported) and 'upload_target', and may define
    'upload_args' (extra rsync arguments, separated by spaces).

    After a successful upload the commands from 'post-upload-command' and
    from 'post-upload-command1', 'post-upload-command2', ... (counting up
    until the first missing number) are executed. ${image} can be used as
    parameter in these commands.

    The program exits with status 1 on an incomplete destination section
    or an unknown upload type.
    """
    logger = logging.getLogger(__logger_name__)

    def split_words(text):
        """Split text at spaces and drop the empty parts."""
        return [word for word in text.split(" ") if word.strip() != ""]

    configured = config.get(image, "upload_destinations", fallback="")
    stripped_names = [name.strip() for name in configured.split(",")]
    for destination in [name for name in stripped_names if name != ""]:
        try:
            upload_type = config.get(destination, "upload_type")
        except (configparser.NoOptionError, configparser.NoSectionError):
            logger.error(
                "No 'upload_type' defined in the upload destination section '%s'.", destination
            )
            sys.exit(1)
        if upload_type.lower() != "rsync":
            logger.error("Unknown upload type '%s' specified. Supported types: rsync", upload_type)
            sys.exit(1)

        try:
            upload_target = config.get(destination, "upload_target")
        except (configparser.NoOptionError, configparser.NoSectionError):
            logger.error(
                "No 'upload_target' defined in the upload destination section '%s'.",
                destination,
            )
            sys.exit(1)

        rsync_command = ["rsync"]
        if config.has_option(destination, "upload_args"):
            rsync_command += split_words(config.get(destination, "upload_args"))
        rsync_command += [image_file, checksum_file, upload_target]
        call_command(rsync_command)
        logger.info("Successfully uploaded " + image_file + " to " + upload_target)

        option = "post-upload-command"
        if config.has_option(destination, option):
            # Make ${image} available before the command is read.
            config.set(destination, "image", image_file)
            call_command(split_words(config.get(destination, option)))

        number = 1
        while config.has_option(destination, option + str(number)):
            config.set(destination, "image", image_file)
            call_command(split_words(config.get(destination, option + str(number))))
            number += 1


def parse_args(argv):
    """Parse the command line arguments given in argv.

    The image name is taken from the IMAGE environment variable if it is
    not given on the command line; without any image name the parser
    reports an error. Options that were not given are None in the
    returned namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "image",
        nargs="?",
        default=os.environ.get("IMAGE"),
        help="Image to build (date and suffix will be added).",
    )
    parser.add_argument(
        "-c",
        "--cache-dir",
        help="Cache directory (default: {})".format(get_default_cache_dir()),
    )
    parser.add_argument(
        "-f", "--format", choices=["qcow2", "raw"], help="Image format to use (default: raw)"
    )
    parser.add_argument("--image-size", help="Size of the raw image (default: 2G)")
    parser.add_argument("--mac", help="MAC address used in the installation machine")

    # Pairs of --x / --no-x switches sharing one destination. The default
    # None allows telling "not given" apart from an explicit choice.
    switches = (
        (
            "installer_logs",
            "--installer-logs",
            "Print installer logs into logging output",
            "--no-installer-logs",
            "Do not print installer logs into logging output",
        ),
        (
            "log_file",
            "--log-file",
            "Store logs into a file (in addition to stdout/stderr)",
            "--no-log-file",
            "Do not store logs into a file (in addition to stdout/stderr)",
        ),
    )
    for dest, on_flag, on_help, off_flag, off_help in switches:
        parser.add_argument(on_flag, dest=dest, action="store_true", default=None, help=on_help)
        parser.add_argument(off_flag, dest=dest, action="store_false", default=None, help=off_help)

    parser.add_argument("--log-filename", help="log into specified file")

    args = parser.parse_args(argv)
    if not args.image:
        parser.error("No image specified.")
    return args


def override_configs_by_args(config, args):
    """Merge the command line arguments into the configuration of the image.

    Every argument that was given on the command line (i.e. is not None)
    replaces the corresponding option in the section named args.image; the
    section is created if it does not exist. If no cache directory is set
    afterwards, the default cache directory is filled in.

    config is modified in place and returned.
    """
    if args.image not in config:
        config[args.image] = {}
    image_conf = config[args.image]

    # (option in the configuration, attribute of args, store as str(value)?)
    overrides = (
        ("cache_dir", "cache_dir", False),
        ("mac", "mac", False),
        ("format", "format", False),
        ("image-size", "image_size", False),
        ("installer-logs", "installer_logs", True),
        ("log-file", "log_file", True),
        ("log-filename", "log_filename", False),
    )
    for option, attribute, stringify in overrides:
        value = getattr(args, attribute)
        if value is None:
            continue
        image_conf[option] = str(value) if stringify else value

    if "cache_dir" not in image_conf:
        image_conf["cache_dir"] = get_default_cache_dir()

    return config


def main():
    """Build an image as requested on the command line.

    Steps: read the configuration and merge the command line arguments into
    it, set up logging (optionally into a file), create an empty raw image,
    run the installation, clean up the installed image, optionally convert
    it to qcow2, write the checksum file, upload the result, and finally
    run the 'post-build-command' if one is configured.

    The program exits with status 1 if the configuration of the image
    lacks the installer files.
    """
    args = parse_args(sys.argv[1:])
    config = get_config()
    # Remember this before override_configs_by_args() creates the section.
    missing_image_section = args.image not in config
    override_configs_by_args(config, args)
    image_conf = config[args.image]
    image_conf["cache_dir"] = os.path.expanduser(image_conf["cache_dir"])

    logger = logging.getLogger(__logger_name__)
    logging.basicConfig(format=DEFAULT_LOGGING_FORMAT, level=logging.INFO)
    if image_conf.getboolean("log-file", fallback=False):
        if "log-filename" in image_conf:
            log_filename = image_conf["log-filename"]
        else:
            log_filename = args.image + "-" + datetime.date.today().isoformat() + ".log"
        file_handler = logging.FileHandler(log_filename, mode="w")
        file_handler.setFormatter(logging.Formatter(DEFAULT_LOGGING_FORMAT))
        logger.addHandler(file_handler)

    # Check that configuration for installer_image or (initrd and linux) exists
    if "installer_image" not in image_conf:
        required_options = ["initrd", "linux"]
        missing_options = [option for option in required_options if option not in image_conf]
        if missing_options:
            if missing_image_section:
                logger.error("No section '%s' defined in image-factory.conf.", args.image)
            else:
                for option in missing_options:
                    logger.error(
                        "No option '%s' or 'installer_image' defined in section '%s' "
                        "in image-factory.conf.",
                        option,
                        args.image,
                    )
            sys.exit(1)

    image = args.image + "-" + datetime.date.today().isoformat() + ".raw"
    create_raw_image(image, image_conf.get("image-size", DEFAULT_IMAGE_SIZE))
    run_installation(config, args.image, image)
    post_installation(
        os.path.basename(sys.argv[0]),
        image,
        image_conf.getboolean("installer-logs", fallback=False),
    )
    # The format is optional and defaults to raw (as stated in the --format help).
    if image_conf.get("format", "raw") == "qcow2":
        image = create_qcow2(image, image_conf.getboolean("keep-raw", fallback=False))
    else:
        create_hashsum(image)
    logger.info("Successfully created %s", image)
    upload_image(config, args.image, image, image + ".sha256sum")
    if config.has_option(args.image, "post-build-command"):
        # The command is called with the image file as its only argument.
        cmd = [config.get(args.image, "post-build-command"), image]
        call_command(cmd)


if __name__ == "__main__":
    # Script entry point: run the build. An interruption by the user
    # (KeyboardInterrupt) ends the program with a short message instead of
    # a traceback.
    try:
        main()
    except KeyboardInterrupt:
        print("User abort.")
