#! /usr/bin/python
# -*- coding: UTF-8 -*-
#
# createdebrepo: Make a Debian repository
#
# Copyright © 2014 Multapplied Networks, Inc.
#
import argparse
import bz2
import io
import datetime
import errno
import fnmatch
import gzip
import hashlib
import os
import shutil
import subprocess
import sys


if not hasattr(subprocess, "check_output"):
    # Compatibility shim: only installed when the running interpreter's
    # subprocess module does not provide check_output itself.
    def check_output(*popenargs, **kwargs):
        """
        Run a command and return what it wrote to stdout.

        Arguments are forwarded to subprocess.Popen. Raises ValueError when
        the caller passes `stdout`, and subprocess.CalledProcessError when
        the command exits with a non-zero status.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        proc = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        captured = proc.communicate()[0]
        status = proc.poll()
        if not status:
            return captured
        # Report the command the same way it was given to Popen
        command = kwargs.get("args")
        if command is None:
            command = popenargs[0]
        raise subprocess.CalledProcessError(status, command)
    subprocess.check_output = check_output


def parse_attrs(data):
    """
    Parse deb attrs from raw data

    `data` is control-file style text. A non-empty line that does not start
    with a space and contains ': ' starts a new field. Any other non-empty
    line is appended, preceded by a newline, to the value of the most recent
    field.

    Returns a list of [key, value] lists in the order the fields appear.
    """
    attrs = []
    for line in data.split('\n'):
        if not line:
            continue
        if not line.startswith(' ') and ': ' in line:
            # New field
            key, val = line.split(': ', 1)
            attrs.append([key, val])
        elif attrs:
            # Append a line to the last attr
            attrs[-1][1] += '\n' + line
        # A continuation line with no preceding field has nothing to attach
        # to; it is skipped instead of failing with an IndexError.

    return attrs


def get_attr(attrs, key):
    """
    Look up `key` in `attrs`, an iterable of (key, value) pairs.

    Returns the value of the first pair whose key equals `key`, or None
    when no pair matches.
    """
    matching = [value for name, value in attrs if name == key]
    return matching[0] if matching else None


class RepoException(Exception):
    """
    Raised for repository problems, e.g. an invalid package source or a
    dist without any packages.
    """


class DistException(Exception):
    """
    Error type reserved for distribution build failures; main() catches it
    around the dist build.
    """


class Package(object):
    """
    Represents a package

    Attributes set by the constructor:
      path     -- location of the package file
      filename -- base name of `path`
      attrs    -- list of [key, value] control fields
      arch     -- value of the 'Architecture' field (None if absent)
      name     -- value of the 'Package' field (None if absent)

    Two packages compare equal, and hash alike, when their file names match.
    """
    def __init__(self, path, attrs):
        self.path = path
        self.filename = os.path.basename(path)
        self.attrs = attrs
        self.arch = get_attr(attrs, 'Architecture')
        self.name = get_attr(attrs, 'Package')

    @classmethod
    def from_path(cls, path):
        """
        Get a Package object from a file path

        The control fields are read by running `dpkg-deb -f` on the file;
        subprocess.CalledProcessError propagates if that command fails.
        """
        data = subprocess.check_output(['/usr/bin/dpkg-deb', '-f', path])

        return cls(path, parse_attrs(data.decode('UTF-8')))

    def get_dest_path(self, dist, component):
        """
        Get destination path for dist and component

        `dist` must provide a `pool_dir` attribute.
        """
        return os.sep.join((dist.pool_dir, component, self.filename))

    def get_raw_package_data(self):
        """
        Get raw package data

        Returns the complete content of the package file as bytes.
        """
        # Read-only mode: the file is never modified, so write permission
        # must not be required, and the handle is closed deterministically.
        with open(self.path, 'rb') as f:
            return f.read()

    def get_package_data(self, dist, component):
        """
        Get package data for the Packages file for dist and component

        `dist` and `component` are interpolated into the 'Filename' entry
        (pool/<dist>/<component>/<filename>). Size, MD5sum, SHA1 and SHA256
        are computed from the package file unless the package's own fields
        already provide them. The Description is always written last.
        """
        lines = []
        description = ""
        written_attrs = set()

        # Get all attrs except for description, add Filename, Size, and
        # Checksums, then add the description
        for key, val in self.attrs:
            if key == 'Description':
                # Put this at the end
                description = val
                continue
            if key == 'Filename':
                # We have our own path
                continue
            written_attrs.add(key)
            lines.append('%s: %s\n' % (key, val))

        # Add filename
        lines.append('Filename: pool/%s/%s/%s\n' % (dist, component, self.filename))

        # Get raw package data
        raw_data = self.get_raw_package_data()

        # Add Size
        if 'Size' not in written_attrs:
            lines.append('Size: %s\n' % len(raw_data))

        # Add checksums
        if 'MD5sum' not in written_attrs:
            lines.append('MD5sum: %s\n' % hashlib.md5(raw_data).hexdigest())
        if 'SHA1' not in written_attrs:
            lines.append('SHA1: %s\n' % hashlib.sha1(raw_data).hexdigest())
        if 'SHA256' not in written_attrs:
            lines.append('SHA256: %s\n' % hashlib.sha256(raw_data).hexdigest())

        lines.append('Description: %s\n' % description)

        return ''.join(lines)

    def __hash__(self):
        return hash(self.filename)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing `filename` attribute.
        if not isinstance(other, Package):
            return NotImplemented
        return self.filename == other.filename


class Dist(object):
    """
    Represents a distribution

    Metadata is written below <dest>/dists/<name> and package files are
    placed below <dest>/pool/<name>. Packages are registered per component
    with add_component()/add_package(); build() then creates the directory
    layout, copies or links the packages and writes the metadata.
    """
    def __init__(self, name, dest, description, signkey=None, update=False):
        self.name = name
        self.dest = dest
        self.description = description
        self.signkey = signkey
        self.update = update

        self.dist_dir = os.sep.join((self.dest, 'dists', self.name))
        self.pool_dir = os.sep.join((self.dest, 'pool', self.name))

        # Architectures seen so far, in first-seen order ('all' is excluded)
        self.architectures = []
        # Each component contains a set of Package objects
        self.components = {}

        # Each entry is a tuple: (filename_relative_to_dist, size, md5sum, sha1sum, sha256sum)
        self.dist_files = []

    def add_component(self, name):
        """
        Add component to dist. Adding an existing component is a no-op.
        """
        if name not in self.components:
            self.components[name] = set()

    def add_package(self, package, component):
        """
        Add package file to dist in component

        The component must already have been added with add_component().
        """
        arch = package.arch
        if arch != 'all' and arch not in self.architectures:
            self.architectures.append(arch)
        self.components[component].add(package)

    def make_dir(self, path):
        """
        Make a directory at `path`, if it does not already exist.
        """
        if not os.path.isdir(path):
            os.makedirs(path)

    def build_dirs(self):
        """
        Create all needed directories
        """
        for component in self.components:
            for arch in self.architectures:
                self.make_dir(os.sep.join((self.dist_dir, component, 'binary-%s' % arch)))
            self.make_dir(os.sep.join((self.pool_dir, component)))

    def copy_packages(self, link=False):
        """
        Copy all packages to dist

        With `link` set, hard links are created instead of copies. A package
        whose source already is its destination file is left untouched.
        """
        action = "Linking" if link else "Copying"
        for component, packages in self.components.items():
            for package in packages:
                print("%s:%s - %s %s" % (self.name, component, action, package.filename))
                src = package.path
                dst = package.get_dest_path(self, component)
                if os.path.exists(dst):
                    if os.path.samefile(src, dst):
                        # The source and the destination are the same file.
                        # Removing the destination here would delete the
                        # source as well (e.g. when updating a repository
                        # from itself), so there is nothing to do.
                        continue
                    if link:
                        # os.link refuses to overwrite an existing file
                        os.unlink(dst)
                if link:
                    os.link(src, dst)
                else:
                    shutil.copyfile(src, dst)
                    shutil.copystat(src, dst)

    def get_packages_data(self, component, architecture):
        """
        Get the content of the Packages file for component and architecture

        Includes the component's packages whose architecture is 'all' or
        `architecture`, ordered by file name so the output is reproducible.
        Returns UTF-8 encoded bytes.
        """
        selected = sorted(
            (pkg for pkg in self.components[component] if pkg.arch in ('all', architecture)),
            key=lambda pkg: pkg.filename,
        )
        return u'\n'.join(pkg.get_package_data(self.name, component) for pkg in selected).encode('UTF-8')

    def create_dist_file(self, path, data):
        """
        Create dist file in path with data and store metadata

        `path` is relative to the dist directory and `data` must be bytes.
        Size and checksums are recorded in self.dist_files.
        """
        with open(os.sep.join((self.dist_dir, path)), 'wb') as f:
            f.write(data)
        md5 = hashlib.md5(data).hexdigest()
        sha1 = hashlib.sha1(data).hexdigest()
        sha256 = hashlib.sha256(data).hexdigest()

        self.dist_files.append((path, len(data), md5, sha1, sha256))

    def generate_metadata(self):
        """
        Generate metadata for repo

        Writes Packages, Packages.gz and Release for every component and
        architecture, then the top-level Release file. When a signing key is
        configured, Release.gpg and InRelease are created with gpg;
        subprocess.CalledProcessError propagates if gpg fails.
        """
        for component in self.components:
            for arch in self.architectures:
                print("%s:%s:%s - Generating metadata" % (self.name, component, arch))
                arch_path = os.sep.join((component, 'binary-%s' % arch))

                # Write Packages file
                packages_path = os.sep.join((arch_path, 'Packages'))
                packages_data = self.get_packages_data(component, arch)
                self.create_dist_file(packages_path, packages_data)

                # Write Packages.gz file
                packages_gz_path = packages_path + '.gz'
                gzdata = io.BytesIO()
                gzfile = gzip.GzipFile(fileobj=gzdata, mode='wb')
                gzfile.write(packages_data)
                gzfile.close()
                packages_gz_data = gzdata.getvalue()
                self.create_dist_file(packages_gz_path, packages_gz_data)

                # Write component-arch Release file
                release_path = os.sep.join((arch_path, 'Release'))
                release_data = "Component: %s\nArchitecture: %s\nDescription: %s\n" % (component, arch, self.description)
                self.create_dist_file(release_path, release_data.encode('UTF-8'))

        # Create repo release data
        date = datetime.datetime.now(datetime.timezone.utc).strftime('%a, %d %b %Y %H:%M:%S UTC')
        release_lines = [
            'Codename: %s' % self.name,
            'Date: %s' % date,
            'Architectures: %s' % ' '.join(self.architectures),
            'Components: %s' % ' '.join(self.components),
        ]

        # One checksum section per hash; the index selects the matching
        # column of the dist_files tuples.
        for title, column in (('MD5Sum', 2), ('SHA1', 3), ('SHA256', 4)):
            release_lines.append('%s:' % title)
            for entry in self.dist_files:
                release_lines.append(' %s %s %s' % (entry[column], entry[1], entry[0]))
        release_data = '\n'.join(release_lines) + '\n'

        # Write Release as UTF-8 bytes, like every other metadata file, so
        # the result does not depend on the locale's default encoding.
        releasepath = os.sep.join((self.dist_dir, 'Release'))
        with open(releasepath, 'wb') as f:
            f.write(release_data.encode('UTF-8'))

        if self.signkey:
            subprocess.check_call([
                '/usr/bin/gpg',
                '--batch', '--yes',  # Overwrite existing file without asking
                '--local-user', self.signkey,
                '--output', os.sep.join((self.dist_dir, 'Release.gpg')),
                '--armor',
                '--detach-sign', releasepath,
            ])

            # Create an inplace signed Release file
            subprocess.check_call([
                '/usr/bin/gpg',
                '--batch', '--yes',  # Overwrite existing file without asking
                '--local-user', self.signkey,
                '--output', os.sep.join((self.dist_dir, 'InRelease')),
                '--clearsign', releasepath,
            ])

    def build(self, link=False):
        """
        Build dist repo in self.dest

        Raises RepoException when no package with a concrete architecture
        (anything other than 'all') has been added.
        """
        if not self.architectures:
            raise RepoException('No packages found for repo')
        self.build_dirs()
        self.copy_packages(link=link)
        self.generate_metadata()


class DirectoryPackageSource(object):
    """
    A simple directory package source

    Raises RepoException on construction when `directory` is not an
    existing directory.
    """
    def __init__(self, directory):
        self.directory = directory

        if not os.path.isdir(directory):
            raise RepoException("'%s' is not a directory" % self.directory)

    def find_packages(self):
        """
        Find packages in source directory

        Walks the directory recursively and yields a Package for every file
        whose name ends in '.deb' or '.udeb'.
        """
        for root, dirs, files in os.walk(self.directory):
            for filename in files:
                # Match on the real extension; a file that is merely called
                # 'deb' or 'udeb' is not a package.
                if filename.endswith(('.deb', '.udeb')):
                    yield Package.from_path(os.path.join(root, filename))


class RepositoryPackageSource(object):
    """
    A Debian repo package source

    Packages are discovered from the Packages files found below
    <directory>/dists/<codename>/<component>/binary-<arch>/ ('binary-all'
    is ignored). For each architecture directory the first existing file of
    Packages, Packages.gz, Packages.bz2 is used.

    Raises RepoException on construction when the repository directory or
    the component directory does not exist, or when no Packages file is
    found.
    """
    def __init__(self, directory, codename, component):
        self.directory = directory
        self.componentdir = '%s/dists/%s/%s' % (self.directory, codename, component)
        self.packages_files = []

        if not os.path.isdir(directory):
            raise RepoException("'%s' is not a directory" % self.directory)

        if not os.path.isdir(self.componentdir):
            raise RepoException("'%s' does not exist" % component)

        for dirname in os.listdir(self.componentdir):
            if dirname.startswith('binary-') and dirname != 'binary-all' and os.path.isdir(os.path.join(self.componentdir, dirname)):
                for ext in ('', '.gz', '.bz2'):
                    filename = '%s/%s/Packages%s' % (self.componentdir, dirname, ext)
                    if os.path.exists(filename):
                        self.packages_files.append(filename)
                        break

        if not self.packages_files:
            raise RepoException("No Packages files found in '%s'" % self.componentdir)

    def find_packages(self):
        """
        Find packages in source directory

        Yields a Package for every entry of every Packages file. The package
        path is the entry's 'Filename' field joined onto the repository
        directory. Raises RepoException when a Packages file cannot be
        opened or an entry has no 'Filename' field.
        """
        for packages_file in self.packages_files:
            try:
                # Packages files are read as UTF-8 regardless of the locale
                if packages_file.endswith('bz2'):
                    f = bz2.open(packages_file, 'rt', encoding='UTF-8')
                elif packages_file.endswith('gz'):
                    f = gzip.open(packages_file, 'rt', encoding='UTF-8')
                else:
                    f = open(packages_file, encoding='UTF-8')
            except EnvironmentError as e:
                raise RepoException("Could not open %s: %s" % (packages_file, e))

            try:
                data = ""
                eof = False
                while not eof:
                    chunk = f.read(65535)
                    data += chunk
                    if not chunk:
                        eof = True
                        # Terminate the last entry so the loop below emits it
                        if not data.endswith('\n\n'):
                            data += '\n\n'
                    # Entries are separated by blank lines; anything after
                    # the last separator stays buffered for the next chunk.
                    while '\n\n' in data:
                        attrdata, data = data.split('\n\n', 1)
                        attrs = parse_attrs(attrdata + '\n')
                        if not attrs:
                            continue
                        filename = get_attr(attrs, 'Filename')
                        if filename is None:
                            raise RepoException("Package entry without Filename field in %s" % packages_file)
                        yield Package(os.path.join(self.directory, filename), attrs)
            finally:
                f.close()


class Source(object):
    """
    A package source feeding one component of one dist.

    Every entry of `srcdirs` is either a plain directory or a repository
    reference of the form PATH::CODENAME[::COMPONENT]. When the component
    part is left out, the target `component` is used.
    """
    def __init__(self, dist, component, *srcdirs):
        self.dist = dist
        self.component = component
        self.srcdirs = srcdirs

        self.sources = []

        for srcdir in self.srcdirs:
            if '::' not in srcdir:
                # Plain directory of package files
                self.sources.append(DirectoryPackageSource(srcdir))
                continue

            # Existing repository: PATH::CODENAME[::COMPONENT]
            parts = srcdir.split('::')
            repo_path, codename = parts[:2]
            if not (repo_path and os.path.isdir(repo_path)):
                raise RepoException("Source path '%s' does not exist" % repo_path)
            src_component = component if len(parts) == 2 else parts[2]

            self.sources.append(RepositoryPackageSource(repo_path, codename, src_component))

    def find_packages(self):
        """
        Yield the packages of every underlying source, in source order.
        """
        for backend in self.sources:
            for found in backend.find_packages():
                yield found


def main(args):
    """
    Build (or update) the repository described by the parsed arguments.

    `args` is the argparse namespace; the attributes read here are repodir,
    sources, update, include, debug_packages, debug, dry_run, force,
    description, signkey and link.

    Exits with status 1 after printing a message when a source is invalid
    or a repository/dist error occurs, and with status 0 when the user
    declines to delete an existing destination directory.
    """
    dest_dir = args.repodir

    sources = []

    def add_source(dist, component, *srcdirs):
        try:
            sources.append(Source(dist, component, *srcdirs))
        except RepoException as e:
            print(e)
            sys.exit(1)

    usage_error = "Source must be of the form DIST,COMPONENT[,<DIR|REPODIR::CODENAME[::COMPONENT]>]"
    for source in args.sources:
        parts = source.split(',')
        # DIST and COMPONENT are mandatory. Check this explicitly: calling
        # add_source with too few arguments would raise a TypeError rather
        # than produce a usable error message.
        if len(parts) < 2:
            print(usage_error)
            sys.exit(1)
        try:
            add_source(*parts)
        except ValueError:
            print(usage_error)
            sys.exit(1)

    # If updating, add the target sources
    target_source = os.path.join(dest_dir, 'dists')
    if args.update and os.path.exists(target_source):
        for dist in os.listdir(target_source):
            dist_path = os.path.join(target_source, dist)
            if not os.path.isdir(dist_path):
                # Stray file next to the dist directories
                continue
            for component in os.listdir(dist_path):
                if os.path.isdir(os.path.join(dist_path, component)):
                    add_source(dist, component, '{dest}::{dist}'.format(dest=dest_dir, dist=dist))

    def parse_include_file(filename):
        # One pattern per line; blank lines and '#' comments are ignored
        patterns = []
        with open(filename) as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#'):
                    patterns.append(line)
        return patterns

    matches = []
    if args.include:
        matches = parse_include_file(args.include)

    debug_matches = []
    if args.debug_packages:
        debug_matches = parse_include_file(args.debug_packages)

    def package_matches(pkg_name, patterns):
        # True when pkg_name matches at least one shell-style pattern
        return any(fnmatch.fnmatch(pkg_name, pattern) for pattern in patterns)

    def match_package(package):
        if not args.debug and package.name.endswith('-dbg'):
            # If debug-packages is included, don't exclude anything in that file
            if debug_matches:
                return package_matches(package.name, debug_matches)
            return False

        if not matches:
            return True

        return package_matches(package.name, matches)

    if not args.update and os.path.exists(dest_dir) and not args.dry_run:
        if not args.force:
            ans = input("%s exists. Delete? (y/n) " % dest_dir)
            if not ans or ans[0].lower() != 'y':
                sys.exit(0)

        shutil.rmtree(dest_dir)

    dists = {}
    for source in sources:
        distname = source.dist
        if distname not in dists:
            dists[distname] = Dist(distname, dest_dir, args.description, signkey=args.signkey, update=args.update)
        dist = dists[distname]

        dist.add_component(source.component)

        try:
            for package in source.find_packages():
                if match_package(package):
                    if args.dry_run:
                        print(source.srcdirs, package.filename)
                    else:
                        dist.add_package(package, source.component)
        except RepoException as e:
            print(e)
            sys.exit(1)

    if args.dry_run:
        return
    try:
        for distname in dists:
            dists[distname].build(link=args.link)
    except (RepoException, DistException) as e:
        # Dist.build reports an empty dist with RepoException
        print(e)
        sys.exit(1)

    if args.signkey:
        # Export the signing key for clients to import. The key has to be
        # named after --export; without a name gpg exports every key in the
        # keyring.
        subprocess.check_call([
            '/usr/bin/gpg',
            '--batch', '--yes',  # Overwrite existing key without asking
            '--local-user', args.signkey,
            '--output', os.sep.join((args.repodir, 'public.gpg.key')),
            '--armor',
            '--export', args.signkey,
        ])

if __name__ == '__main__':
    # Command-line entry point: parse the arguments and hand them to main()
    default_dest = os.path.join(os.getcwd(), 'repo')

    # The text is the program description; passed positionally it would be
    # taken as the program name and end up in the usage line.
    parser = argparse.ArgumentParser(description="Build a Debian repo")
    parser.add_argument('sources', metavar='DIST,COMPONENT[,SOURCE]', nargs='+', help="deb package source. DIST must be the distribution codename (eg: 'wheezy'). COMPONENT may be one of 'main', 'contrib', or 'non-free'. SOURCE may be a simple directory containing deb files that will be recursively searched or an entry of the form REPODIR::CODENAME[::COMPONENT] to extract packages from an existing repository. If the COMPONENT is omitted, it will assume the target component. If no SOURCE is specified, a blank component will be created. Multiple SOURCEs may be specified separated by commas.")
    parser.add_argument('-d', '--description', default="Createdebrepo Repository", help="Description of repo")
    # --force and --update contradict each other
    group = parser.add_mutually_exclusive_group()
    group.add_argument('-f', '--force', action='store_true', help="Force creation without input, even if dist dir already exists")
    group.add_argument('-u', '--update', action='store_true', help="Update the existing repository")
    parser.add_argument('-i', '--include', help="file containing a list of packages to include. Packages that don't match will be excluded. The list supports standard shell globbing")
    parser.add_argument('-r', '--repodir', default=default_dest, help="destination directory. WARNING: directory will be cleared before creating the repo, defaults to ./repo")
    # Signing only happens when a key is given
    parser.add_argument('-s', '--signkey', help="Specify a GnuPG key to sign the repository with. If not set, the repository will not be signed")
    parser.add_argument('-l', '--link', action='store_true', help="Create hardlinks to source package files instead of copying them")
    parser.add_argument('--debug', action='store_true', help="Include all debug packages in repository")
    parser.add_argument('--dry-run', action='store_true', help="Just print a list of packages that will be included in the repository")
    parser.add_argument('-dbg', '--debug-packages', help="Include specific debug packages in repository, regardless of if --debug is used.")
    args = parser.parse_args()

    try:
        main(args)
    except KeyboardInterrupt:
        sys.exit(1)
