#!/usr/bin/python3
# Copyright (C) 2018 Jelmer Vernooij <jelmer@jelmer.uk>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

import asyncio
import contextlib
import json
import logging
import os
import subprocess
import tempfile
from typing import List

from urllib.request import urlopen
from urllib.error import HTTPError

import breezy.bzr
import breezy.git
from breezy.trace import note

try:
    from breezy.transport import NoSuchFile
except ImportError:
    from breezy.errors import NoSuchFile
from breezy.workingtree import WorkingTree

from debian.changelog import Changelog, Version

from breezy.plugins.debian.info import versions_dict
from breezy.plugins.debian.upstream import PackageVersionNotPresent
from breezy.plugins.debian.import_dsc import (
    DistributionBranch,
    DistributionBranchSet,
)
from breezy.plugins.debian.util import MissingChangelogError
from breezy.plugins.debian.apt_repo import (
    LocalApt,
    RemoteApt,
    NoAptSources,
    AptSourceError,
)

# NOTE(review): not referenced by any code in this file; presumably read by
# external tooling as the branch name to use for the resulting change --
# confirm before renaming or removing.
BRANCH_NAME = "missing-commits"


def connect_udd_mirror():
    """Open a PostgreSQL connection to the ``udd`` database.

    The connection targets the host ``udd-mirror.debian.net`` using the
    ``udd-mirror`` user and password hard-coded below.

    psycopg2 is imported lazily so that the rest of the module can be used
    without it being installed.

    Returns:
      Whatever ``psycopg2.connect`` returns for these parameters.
    """
    from psycopg2 import connect

    connection_params = {
        "database": "udd",
        "user": "udd-mirror",
        "password": "udd-mirror",
        "host": "udd-mirror.debian.net",
    }
    return connect(**connection_params)


def select_vcswatch_packages():
    """List source packages in sid whose vcswatch status is OLD or UNREL.

    Queries the UDD mirror (via ``connect_udd_mirror``) for rows joining
    ``vcswatch`` and ``sources``.  The query also selects the VCS URL, but
    only the package names are returned.

    Returns:
      A list of source package names, in the order returned by the database.
    """
    query = """\
    SELECT sources.source, vcswatch.url
    FROM vcswatch JOIN sources ON sources.source = vcswatch.source
    WHERE
     vcswatch.status IN ('OLD', 'UNREL') AND
     sources.release = 'sid'
"""
    # closing() is used rather than "with conn:" because a psycopg2
    # connection used as a context manager only ends the transaction; it
    # does not close the connection.
    with contextlib.closing(connect_udd_mirror()) as conn, \
            contextlib.closing(conn.cursor()) as cursor:
        cursor.execute(query)
        return [package for package, _vcs_url in cursor.fetchall()]


class SnapshotDownloadError(Exception):
    """Raised when downloading a file from the snapshot service fails.

    Attributes:
      url: the URL whose download failed
      inner: the underlying exception that caused the failure
    """

    def __init__(self, url, e):
        self.url = url
        # The parameter is called ``e``; the previous code read an undefined
        # name here, so constructing this exception raised NameError.
        self.inner = e
        super().__init__("Downloading %s failed: %s" % (url, e))


def download_snapshot(package, version, output_dir, no_preparation=False):
    """Download and unpack one source package version from snapshot.debian.org.

    The list of source files is fetched from the snapshot "srcfiles" API,
    every file is stored in ``output_dir`` under its base name, and the
    package is then unpacked there with ``dpkg-source -x``.

    Args:
      package: source package name
      version: version to download; interpolated into URLs and the .dsc name
      output_dir: existing directory that receives the files and the
        unpacked source tree
      no_preparation: if true, pass ``--no-preparation`` to dpkg-source

    Raises:
      SnapshotDownloadError: a file download answered with HTTP 504
      urllib.error.HTTPError: any other HTTP failure
      subprocess.CalledProcessError: dpkg-source exited non-zero
    """
    note("Downloading %s %s", package, version)
    srcfiles_url = (
        "https://snapshot.debian.org/mr/package/%s/%s/"
        "srcfiles?fileinfo=1" % (package, version)
    )
    # Close the HTTP response once the index has been parsed.
    with urlopen(srcfiles_url) as response:
        fileinfo = json.load(response)["fileinfo"]
    # Map file name -> hash; a later entry with the same name wins.
    files = {}
    for hsh, entries in fileinfo.items():
        for entry in entries:
            files[entry["name"]] = hsh
    for filename, hsh in files.items():
        local_path = os.path.join(output_dir, os.path.basename(filename))
        with open(local_path, "wb") as f:
            url = "https://snapshot.debian.org/file/%s" % hsh
            note('.. Downloading %s', url)
            try:
                with urlopen(url) as g:
                    f.write(g.read())
            except HTTPError as e:
                # ``code`` is the documented status attribute and is
                # available on every Python 3 version.
                if e.code == 504:
                    # Gateway timeouts get a dedicated exception type so
                    # callers can report them separately.
                    raise SnapshotDownloadError(url, e) from e
                raise
    args = []
    if no_preparation:
        args.append("--no-preparation")
    subprocess.check_call(
        ["dpkg-source"] + args + ["-x", "%s_%s.dsc" % (package, version)],
        cwd=output_dir,
    )


class NoMissingVersions(Exception):
    """Raised when the archive holds no versions that the VCS lacks.

    Attributes:
      vcs_version: version found in the VCS tree
      archive_version: version found in the archive
    """

    def __init__(self, vcs_version, archive_version):
        message = (
            "No missing versions after all. Archive has %s, VCS has %s"
            % (archive_version, vcs_version)
        )
        super().__init__(message)
        self.vcs_version = vcs_version
        self.archive_version = archive_version


class TreeVersionNotInArchiveChangelog(Exception):
    """Raised when the tree's version is absent from the archive changelog.

    Attributes:
      tree_version: the version found in the tree
    """

    def __init__(self, tree_version):
        message = (
            "tree version %s does not appear in archive changelog" %
            tree_version
        )
        super().__init__(message)
        self.tree_version = tree_version


class TreeUpstreamVersionMissing(Exception):
    """Raised when an upstream version could not be found.

    Attributes:
      upstream_version: the upstream version that was looked for
    """

    def __init__(self, upstream_version):
        message = "unable to find upstream version %r" % upstream_version
        super().__init__(message)
        self.upstream_version = upstream_version


def import_uncommitted(tree, subpath, apt):
    """Import archive uploads that are not yet present in the tree.

    The archive changelog (from the source package retrieved through
    ``apt``) is compared with the tree's ``debian/changelog``.  Every
    archive version listed before the tree's current version is downloaded
    with ``download_snapshot`` and imported into the tree's branch.

    Args:
      tree: working tree containing the packaging
      subpath: path inside ``tree`` under which ``debian/`` lives
      apt: apt source object; entered as a context manager and asked for the
        source package via ``retrieve_source``

    Returns:
      A list of ``(tag_name, version, revision)`` tuples, one per imported
      version, in import order (the reverse of archive changelog order).

    Raises:
      MissingChangelogError: the tree has no changelog at the expected path
      TreeVersionNotInArchiveChangelog: the tree's version does not occur in
        the archive changelog
      NoMissingVersions: the archive changelog has no versions ahead of the
        tree's version
      TreeUpstreamVersionMissing: the upstream version of the tree's current
        version is not available from the pristine upstream source
    """
    changelog_path = os.path.join(subpath, "debian/changelog")
    try:
        with tree.get_file(changelog_path) as changelog_file:
            tree_cl = Changelog(changelog_file)
            package_name = tree_cl.package
    except NoSuchFile:
        raise MissingChangelogError([changelog_path])

    with contextlib.ExitStack() as stack:
        stack.enter_context(apt)
        archive_source = stack.enter_context(tempfile.TemporaryDirectory())
        apt.retrieve_source(package_name, archive_source)
        # Exactly one .dsc is expected; the unpacking assignment raises
        # ValueError otherwise.
        [dsc] = [
            entry.name
            for entry in os.scandir(archive_source)
            if entry.name.endswith('.dsc')
        ]
        note("Unpacking source")
        subprocess.check_output(['dpkg-source', '-x', dsc], cwd=archive_source)
        # Likewise, exactly one unpacked directory is expected.
        [subdir] = [
            entry.path
            for entry in os.scandir(archive_source)
            if entry.is_dir()
        ]
        archive_changelog_path = os.path.join(subdir, "debian", "changelog")
        with open(archive_changelog_path, "r") as archive_file:
            archive_cl = Changelog(archive_file)

        # Collect archive versions until the tree's own version is reached.
        missing_versions: List[Version] = []
        for block in archive_cl:
            if block.version == tree_cl.version:
                break
            missing_versions.append(block.version)
        else:
            # The loop ran to completion without meeting the tree version.
            raise TreeVersionNotInArchiveChangelog(tree_cl.version)
        if not missing_versions:
            raise NoMissingVersions(tree_cl.version, archive_cl.version)
        note(
            "Missing versions: %s",
            ", ".join(str(missing) for missing in missing_versions))

        imported = []
        branch_set = DistributionBranchSet()
        dist_branch = DistributionBranch(tree.branch, tree.branch, tree=tree)
        branch_set.add_branch(dist_branch)
        if tree_cl.version.debian_revision:
            note(
                "Extracting upstream version %s.",
                tree_cl.version.upstream_version
            )
            upstream_dir = stack.enter_context(tempfile.TemporaryDirectory())
            upstream_source = dist_branch.pristine_upstream_source
            try:
                upstream_tips = upstream_source.version_as_revisions(
                    tree_cl.package, tree_cl.version.upstream_version
                )
            except PackageVersionNotPresent:
                # TODO(jelmer): Should we import it instead?
                raise TreeUpstreamVersionMissing(tree_cl.version.upstream_version)
            dist_branch.extract_upstream_tree(upstream_tips, upstream_dir)

        no_preparation = not tree.has_filename(".pc/applied-patches")

        # Download everything first, then import.
        version_path = {}
        for version in missing_versions:
            download_dir = stack.enter_context(tempfile.TemporaryDirectory())
            download_snapshot(
                package_name, version, download_dir,
                no_preparation=no_preparation
            )
            version_path[version] = download_dir

        # Import in the reverse of changelog order.
        for version in reversed(missing_versions):
            note("Importing %s", version)
            dsc_name = "%s_%s.dsc" % (package_name, version)
            dsc_path = os.path.join(version_path[version], dsc_name)
            tag_name = dist_branch.import_package(dsc_path)
            revision = dist_branch.branch.tags.lookup_tag(tag_name)
            imported.append((tag_name, version, revision))
    return imported


def report_fatal(code, description):
    """Report a fatal outcome, to the SVP result file and to the log.

    When the ``SVP_API`` environment variable equals ``"1"``, a JSON document
    with the result code, the description and ``versions_dict()`` is written
    to the file named by the ``SVP_RESULT`` environment variable.  The
    description is always logged through ``logging.fatal``.

    Args:
      code: short machine-readable result code
      description: human-readable explanation
    """
    svp_enabled = os.environ.get('SVP_API') == '1'
    if svp_enabled:
        with open(os.environ['SVP_RESULT'], 'w') as result_file:
            payload = {
                'versions': versions_dict(),
                'result_code': code,
                'description': description,
            }
            json.dump(payload, result_file)
    logging.fatal('%s', description)


async def main():
    """Command-line entry point: import missing uploads into the tree in '.'.

    Parses the APT repository options (defaults come from the
    ``APT_REPOSITORY``/``REPOSITORIES`` and ``APT_REPOSITORY_KEY``
    environment variables), runs ``import_uncommitted`` on the working tree
    containing the current directory and reports the outcome.

    Returns:
      1 for the handled failures, 0 when there was nothing to import, and
      ``None`` after a successful import.
    """
    import argparse

    default_repository = (
        os.environ.get('APT_REPOSITORY') or os.environ.get('REPOSITORIES'))
    default_repository_key = os.environ.get('APT_REPOSITORY_KEY')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--apt-repository', type=str,
        help='APT repository to use. Defaults to locally configured.',
        default=default_repository)
    parser.add_argument(
        '--apt-repository-key', type=str,
        help=('APT repository key to use for validation, '
              'if --apt-repository is set.'),
        default=default_repository_key)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format='%(message)s')

    apt = (
        RemoteApt.from_string(args.apt_repository, args.apt_repository_key)
        if args.apt_repository
        else LocalApt()
    )
    local_tree, subpath = WorkingTree.open_containing('.')

    # The handler order is kept as-is because the relationships between
    # these exception classes are not visible here.
    try:
        imported = import_uncommitted(local_tree, subpath, apt)
    except AptSourceError as e:
        # A list-valued reason is reported by its last element only.
        reason = e.reason[-1] if isinstance(e.reason, list) else e.reason
        report_fatal("apt-source-error", reason)
        return 1
    except MissingChangelogError as e:
        report_fatal(
            "missing-changelog",
            "Missing changelog: %s" % e.location[0])
        return 1
    except NoAptSources:
        report_fatal(
            "no-apt-sources",
            "No sources configured in /etc/apt/sources.list")
        return 1
    except TreeUpstreamVersionMissing as e:
        report_fatal("tree-upstream-version-missing", str(e))
        return 1
    except TreeVersionNotInArchiveChangelog as e:
        report_fatal("tree-version-not-in-archive-changelog", str(e))
        return 1
    except NoMissingVersions as e:
        # Nothing to import is reported but is not an error exit.
        report_fatal("nothing-to-do", str(e))
        return 0
    except SnapshotDownloadError as e:
        report_fatal(
            'snapshot-download-failed',
            'Downloading %s failed: %s' % (e.url, e.inner))
        return 1

    if os.environ.get('SVP_API') == '1':
        with open(os.environ['SVP_RESULT'], 'w') as result_file:
            version_list = ", ".join(
                str(version) for _tag, version, _revision in imported)
            tags = [
                (tag_name, str(version))
                for tag_name, version, _revision in imported
            ]
            result = {
                'description': 'Import archive changes missing from the VCS.',
                'versions': versions_dict(),
                'commit-message': "Import missing uploads: %s." % version_list,
                'context': {
                    'tags': tags,
                },
            }
            json.dump(result, result_file)

    note('Imported uploads: %s.', [entry[1] for entry in imported])


if __name__ == "__main__":
    import sys

    # Run the async entry point and use its return value as the exit status.
    exit_status = asyncio.run(main())
    sys.exit(exit_status)
