#!/usr/bin/env python3
# Copyright (c) 2021 Julien Floret
# Copyright (c) 2021 Robin Jarry
# SPDX-License-Identifier: BSD-3-Clause
# pylint: disable=consider-using-f-string,invalid-name

"""
Interact with a dlrepo server. If the server requires a password, it will be prompted
unless the DLREPO_PASSWORD environment variable is specified. By default the user is
your own UNIX user or DLREPO_USER if specified.
"""

import argparse
import base64
import configparser
import getpass
import hashlib
import json
import os
import re
import sys
from urllib.error import HTTPError
from urllib.parse import quote, quote_plus, urlencode, urljoin
from urllib.request import BaseHandler, Request, build_opener


# Built-in defaults for the "dlrepo-cli" section. They are loaded first so that any
# value found in ~/.config/dlrepo-cli (read last) overrides them.
CONF = configparser.ConfigParser()
CONF["dlrepo-cli"] = {
    "user": os.getenv("DLREPO_USER", getpass.getuser()),
    "password": "",
    "url": "http://127.0.0.1:1337",
    # %(user)s is expanded by configparser interpolation when the option is read
    "prefix": "/~%(user)s/",
}
CONF.read(os.path.expanduser("~/.config/dlrepo-cli"))


# --------------------------------------------------------------------------------------
def main():
    """
    Parse the command line, run the selected sub-command and return an exit status.

    Global option defaults come from the DLREPO_URL / DLREPO_PREFIX environment
    variables, falling back to the "dlrepo-cli" section of CONF. One sub-parser is
    created per function registered in ``sub_command.commands``; the function
    docstring is used as both its description and its help text.

    Returns 0 on success and 1 when the command is interrupted or raises.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    default_url = os.getenv("DLREPO_URL", CONF.get("dlrepo-cli", "url"))
    # The defaults are rendered with argparse's own %(default)s placeholder instead of
    # being %-formatted into the help text: argparse %-formats help strings itself, so
    # a default containing "%" (e.g. a percent-encoded URL) would otherwise produce an
    # invalid format string.
    parser.add_argument(
        "-U",
        "--url",
        default=default_url,
        help="""
        The root URL of the dlrepo server. (default: %(default)s)
        """,
    )
    default_prefix = os.getenv("DLREPO_PREFIX", CONF.get("dlrepo-cli", "prefix"))
    parser.add_argument(
        "-p",
        "--prefix",
        default=default_prefix,
        help="""
        URL prefix. Use --prefix=/ to work at the repository root and not in your own
        user prefix. Attention, depending on the ACLs configuration, you may not have
        permissions to modify anything outside of your own user space.
        (default: %(default)s)
        """,
    )
    parser.add_argument(
        "-j",
        "--raw-json",
        action="store_true",
        help="""
        Do not convert the returned JSON data to human readable text.
        Only applies to non-binary data.
        """,
    )
    sub = parser.add_subparsers(title="sub-command help", metavar="SUB_COMMAND")
    sub.required = True

    for cmd in sub_command.commands:
        subparser = sub.add_parser(
            cmd.cmd_name, description=cmd.__doc__, help=cmd.__doc__
        )
        for arg in cmd.sub_args:
            subparser.add_argument(*arg.args, **arg.kwargs)
        subparser.set_defaults(callback=cmd)

    args = parser.parse_args()
    # from now on, args.url is the base URL that includes the prefix
    args.url = urljoin(args.url, args.prefix)

    try:
        args.callback(args)
    except KeyboardInterrupt:
        # terminate the line that was interrupted (progress output, ^C echo)
        print()
        return 1
    except HTTPError as e:
        # prefer the message sent by the server in the response body, if any
        with e:
            msg = e.read().decode("utf-8").strip()
        if not msg:
            msg = str(e)
        print("error: remote: %s" % (msg,), file=sys.stderr)
        return 1
    except Exception as e:
        # top-level boundary: report any other failure as a one-line error
        print("error: %s" % (e,), file=sys.stderr)
        return 1

    return 0


# --------------------------------------------------------------------------------------
class Arg:
    """
    Deferred ``add_argument()`` call.

    Stores the positional and keyword arguments unchanged in ``args`` and ``kwargs``
    so that they can be forwarded to an argparse parser later on.
    """

    def __init__(self, *names, **options):
        self.args, self.kwargs = names, options


def sub_command(*args):
    """
    Decorator factory that registers a function as a CLI sub-command.

    The decorated function is returned unchanged apart from two new attributes:
    ``cmd_name`` (its name with underscores replaced by dashes) and ``sub_args``
    (the positional arguments given to this factory). It is also appended to the
    ``sub_command.commands`` list, which is created on first use.
    """

    def decorator(func):
        # the registry lives on the sub_command function object itself
        registry = sub_command.__dict__.setdefault("commands", [])
        func.sub_args = args
        func.cmd_name = func.__name__.replace("_", "-")
        registry.append(func)
        return func

    return decorator


# --------------------------------------------------------------------------------------
def local_path(value):
    """
    Argument type checker: return *value* unchanged when it names an existing file
    or directory, raise argparse.ArgumentTypeError otherwise.
    """
    if os.path.isfile(value) or os.path.isdir(value):
        return value
    raise argparse.ArgumentTypeError("%r: No such file or directory" % (value,))


# --------------------------------------------------------------------------------------
def job_param(value):
    """
    Argument type checker: split a ``PARAM=VALUE`` string into a ``(param, value)``
    tuple.

    PARAM must be made of one or more word characters; VALUE may be empty.
    Raises argparse.ArgumentTypeError when *value* does not have that form.
    """
    parsed = re.match(r"^(\w+)=(.*)$", value)
    if parsed is None:
        raise argparse.ArgumentTypeError("%r: Invalid job parameter" % (value,))
    name, text = parsed.groups()
    return (name, text)


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg("job", metavar="JOB", help="the job name"),
    Arg("format", metavar="FORMAT", help="the artifact format"),
    Arg("target", metavar="DIR", help="the target directory"),
)
def get(args):
    """
    Download files from the specified branch into a local directory.
    """
    # NB: the docstring above is displayed as the command help, keep it user oriented.
    #
    # The file list is fetched from the artifact format URL, then every listed file is
    # downloaded under args.target (created if needed) using the same relative path.
    # Raises ValueError if the server lists a file name that would be written outside
    # of args.target.
    client = HttpClient(args.url)
    os.makedirs(args.target, exist_ok=True)
    url = os.path.join("branches", args.branch, args.tag, args.job, args.format) + "/"
    data = client.get(url)

    # File names come from the server response. Validate all of them before writing
    # anything: an absolute name or one containing ".." must not escape the target.
    target_root = os.path.abspath(args.target)
    downloads = []
    for f in data["artifact_format"]["files"]:
        dest = os.path.join(args.target, f)
        if os.path.commonpath([target_root, os.path.abspath(dest)]) != target_root:
            raise ValueError(
                "%r: refusing to write outside of %s" % (f, args.target)
            )
        downloads.append((os.path.join(url, f), dest))

    for remote, dest in downloads:
        client.get_file(remote, dest)


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg("job", metavar="JOB", help="the job name"),
    Arg("format", metavar="FORMAT", help="the artifact format"),
    Arg(
        "paths",
        metavar="FILE_OR_DIR",
        nargs="+",
        type=local_path,
        help="a file/folder to upload",
    ),
)
def upload(args):
    """
    Upload files and/or folders to the specified branch, tag and job under the
    specified artifact format.
    """
    # NB: the docstring above is displayed as the command help.
    client = HttpClient(args.url)
    url = os.path.join("branches", args.branch, args.tag, args.job, args.format)

    def iter_uploads():
        # Yield (local file, remote url) pairs lazily, in command line order.
        for path in args.paths:
            if os.path.isdir(path):
                # files keep their path relative to the uploaded folder
                for root, _, files in os.walk(path):
                    for fname in files:
                        local = os.path.join(root, fname)
                        yield local, os.path.join(url, os.path.relpath(local, path))
            elif os.path.isfile(path):
                # single files land directly under the artifact format
                yield path, os.path.join(url, os.path.basename(path))

    for local, remote in iter_uploads():
        client.put_file(local, remote)

    client.patch(url + "/", None)  # clear the dirty flag

    browse_url = urljoin(client.baseurl, url).lower()
    print("Browse/Download: %s/" % browse_url)


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg("job", metavar="JOB", help="the job name"),
)
def get_info(args):
    """
    Download metadata for the specified job.
    """
    # NB: the docstring above is displayed as the command help.
    job_url = os.path.join("branches", args.branch, args.tag, args.job) + "/"
    data = HttpClient(args.url).get(job_url)
    if args.raw_json:
        print(json.dumps(data, indent=2))
        return
    # human readable output: one PARAM=VALUE line per metadata entry,
    # list values are flattened into a space separated string
    for key, value in data["job"].items():
        text = " ".join(value) if isinstance(value, list) else value
        print("%s=%s" % (key, text))


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg("job", metavar="JOB", help="the job name"),
    Arg(
        "info",
        nargs="+",
        metavar="PARAM=VALUE",
        type=job_param,
        help="set a metadata parameter for the specified job (PARAM= to unset)",
    ),
)
def set_info(args):
    """
    Update metadata for the specified job.
    """
    # NB: the docstring above is displayed as the command help.
    client = HttpClient(args.url)
    job_url = os.path.join("branches", args.branch, args.tag, args.job) + "/"
    client.patch(job_url, {"job": dict(args.info)})

    # Read the metadata back: when the four product related keys are all present,
    # print the URL built from them under "products".
    job = client.get(job_url)["job"]
    product_keys = ("product", "product_variant", "product_branch", "version")
    if not all(key in job for key in product_keys):
        return
    product_url = os.path.join("products", *(str(job[key]) for key in product_keys))
    print("Browse/Download: %s/" % urljoin(client.baseurl, product_url).lower())


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg("job", metavar="JOB", nargs="?", help="the job name"),
    Arg(
        "-u",
        "--unlock",
        action="store_true",
        help="unlock instead of locking",
    ),
)
def lock(args):
    """
    Lock a job to prevent further modifications. Or lock a tag so that it is never
    deleted (even by automatic cleanup operations).
    """
    # NB: the docstring above is displayed as the command help.
    client = HttpClient(args.url)
    state = {"locked": not args.unlock}
    # the request targets the job when one is given, the whole tag otherwise
    if args.job:
        kind = "job"
        components = ("branches", args.branch, args.tag, args.job)
    else:
        kind = "tag"
        components = ("branches", args.branch, args.tag)
    client.put(os.path.join(*components) + "/", {kind: state})


# --------------------------------------------------------------------------------------
@sub_command()
def branches(args):
    """
    List the available branches on the repo.
    """
    # NB: the docstring above is displayed as the command help.
    data = HttpClient(args.url).get("/branches/")
    if args.raw_json:
        print(json.dumps(data, indent=2))
        return
    # human readable output: one branch name per line
    for entry in data["branches"]:
        print(entry["name"])


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("-r", "--released-only", action="store_true", help="only show released tags"),
)
def tags(args):
    """
    List the available tags in the specified branch.
    """
    # NB: the docstring above is displayed as the command help.
    client = HttpClient(args.url)
    # the "released" query parameter is only sent when filtering was requested
    params = {"released": 1} if args.released_only else {}
    branch_url = os.path.join("branches", args.branch) + "/"
    data = client.get(branch_url, params)
    if args.raw_json:
        print(json.dumps(data, indent=2))
        return
    # human readable output: one tag name per line
    for entry in data["branch"]["tags"]:
        print(entry["name"])


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg(
        "filters",
        nargs="*",
        metavar="PARAM=VALUE",
        type=job_param,
        help="filter jobs that have this metadata parameter value",
    ),
)
def jobs(args):
    """
    List the available jobs in the specified branch and tag.
    """
    # NB: the docstring above is displayed as the command help.
    client = HttpClient(args.url)
    # filters are (param, value) pairs sent as query parameters; when the same
    # param is given several times, the last value wins
    params = dict(args.filters)
    tag_url = os.path.join("branches", args.branch, args.tag) + "/"
    data = client.get(tag_url, params)
    if args.raw_json:
        print(json.dumps(data, indent=2))
        return
    # human readable output: one job name per line
    for entry in data["tag"]["jobs"]:
        print(entry["name"])


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
)
def status(args):
    """
    Display the current status of a tag.
    """
    # NB: the docstring above is displayed as the command help.
    tag_url = os.path.join("branches", args.branch, args.tag) + "/"
    data = HttpClient(args.url).get(tag_url)
    if args.raw_json:
        print(json.dumps(data, indent=2))
        return
    print("released=%s" % data["tag"]["released"])
    print("locked=%s" % data["tag"]["locked"])
    # publish_status may be absent from the reply, print an empty value then
    print("publish_status=%s" % data["tag"].get("publish_status", ""))


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg(
        "-u",
        "--unset",
        action="store_true",
        help="unset the 'released' status from the tag instead of setting it",
    ),
)
def release(args):
    """
    Set or unset the 'released' status on a tag. If a remote publish url has been
    configured on the server, setting (or unsetting) the 'released' status will trigger
    an asynchronous modification of the remote server.
    The current publish status of a tag can be displayed with the 'status' command.
    """
    # NB: the docstring above is displayed as the command help.
    tag_url = os.path.join("branches", args.branch, args.tag) + "/"
    payload = {"tag": {"released": not args.unset}}
    HttpClient(args.url).put(tag_url, payload)


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", help="the tag name"),
    Arg("job", metavar="JOB", help="the job name"),
    Arg(
        "-u",
        "--unset",
        action="store_true",
        help="unset the 'internal' status from the job instead of setting it",
    ),
)
def internal(args):
    """
    Set or unset the 'internal' status on a job. An internal job is never
    published to a remote server.
    """
    # NB: the docstring above is displayed as the command help.
    job_url = os.path.join("branches", args.branch, args.tag, args.job) + "/"
    payload = {"job": {"internal": not args.unset}}
    HttpClient(args.url).put(job_url, payload)


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg("tag", metavar="TAG", nargs="?", help="the tag name"),
    Arg("job", metavar="JOB", nargs="?", help="the job name"),
    Arg(
        "-f",
        "--force",
        action="store_true",
        help="force the deletion of 'released' tags",
    ),
)
def delete(args):
    """
    Delete a job, a tag or a branch and all its tags recursively.
    """
    # NB: the docstring above is displayed as the command help.
    client = HttpClient(args.url)
    params = {}
    # the URL gets deeper with every optional argument: branch, then tag, then job
    components = ["branches", args.branch]
    if args.tag:
        components.append(args.tag)
        if args.job:
            components.append(args.job)
        elif args.force:
            # the force parameter is only sent when the target is a tag
            params["force"] = "true"
    client.delete(os.path.join(*components) + "/", params)


# --------------------------------------------------------------------------------------
@sub_command()
def flush(args):
    """
    Delete all user data.
    """
    # NB: the docstring above is displayed as the command help.
    # safety net: only prefixes containing "~" are accepted
    if "~" not in args.prefix:
        raise ValueError("%s is not a user space, refusing to flush" % args.url)
    client = HttpClient(args.url)
    listing = client.get("/branches/")
    print("deleting all branches in %s" % args.url)
    for name in (entry["name"] for entry in listing["branches"]):
        client.delete(os.path.join("branches", name) + "/", {})


# --------------------------------------------------------------------------------------
@sub_command(
    Arg("branch", metavar="BRANCH", help="the branch name"),
    Arg(
        "-d",
        "--max-daily",
        type=int,
        help="maximum non-released tags to keep",
    ),
    Arg(
        "-r",
        "--max-released",
        type=int,
        help="maximum released tags to keep",
    ),
)
def cleanup_policy(args):
    """
    Get or set a branch cleanup policy.
    """
    # NB: the docstring above is displayed as the command help.
    client = HttpClient(args.url)
    branch_url = os.path.join("branches", args.branch) + "/"

    # only the limits given on the command line are part of the update
    options = (
        ("max_daily_tags", args.max_daily),
        ("max_released_tags", args.max_released),
    )
    policy = {key: value for key, value in options if value is not None}
    if policy:
        client.put(branch_url, {"branch": policy})
        return

    # no limit given: display the current policy instead
    data = client.get(branch_url)
    if args.raw_json:
        print(json.dumps(data, indent=2))
        return
    for item in data["branch"]["cleanup_policy"].items():
        print("%s=%s" % item)


# --------------------------------------------------------------------------------------
class AuthHandler(BaseHandler):
    """
    urllib handler adding HTTP basic authentication to every request.

    The user name comes from CONF, the password from the DLREPO_PASSWORD environment
    variable or from CONF. When the server answers 401 with a WWW-Authenticate header,
    the request is retried once, after prompting for a password if stdin is a
    terminal.
    """

    handler_order = 500

    def __init__(self):
        section = "dlrepo-cli"
        # set once a 401 has been handled, so that the request is retried only once
        self.prompted_auth = False
        self.user = CONF.get(section, "user")
        self.password = os.getenv("DLREPO_PASSWORD", CONF.get(section, "password"))

    def http_error_401(self, request, response, code, msg, headers):
        """
        Retry *request* once after a 401 response.

        Returns None (leaving the error to other handlers) when a retry already
        took place or when the response has no WWW-Authenticate header.
        """
        if self.prompted_auth:
            return None
        if headers.get("WWW-Authenticate") is None:
            return None
        body = request.data
        if hasattr(body, "seek"):
            # rewind file-like bodies so that the retry sends them from the start
            body.seek(0)
        retry = Request(
            url=request.get_full_url(),
            data=body,
            headers=request.headers,
            origin_req_host=request.origin_req_host,
            method=request.method,
        )
        self.prompted_auth = True
        if sys.stdin.isatty():
            sys.stderr.write("\n")
            sys.stderr.flush()
            prompt = "%s@%s's password: " % (self.user, request.origin_req_host)
            self.password = getpass.getpass(prompt)
        return self.parent.open(retry)

    https_error_401 = http_error_401

    def http_request(self, request):
        """
        Add the basic Authorization header to *request* when a user is configured.
        """
        if not self.user:
            return request
        credentials = "%s:%s" % (self.user, self.password or "")
        token = base64.standard_b64encode(credentials.encode("utf-8"))
        request.add_header("Authorization", "Basic " + token.decode("ascii"))
        return request

    https_request = http_request


# --------------------------------------------------------------------------------------
class HttpClient:
    """
    Small HTTP client for the dlrepo server, built on urllib.

    Relative URLs given to the methods are resolved against ``baseurl``. All requests
    are sent through an opener that includes AuthHandler.
    """

    def __init__(self, baseurl):
        self.baseurl = baseurl
        self.opener = build_opener(AuthHandler)

    def make_url(self, url, params=None):
        """
        Return the absolute URL for *url* (relative to the base URL).

        Leading slashes are dropped from *url* and its path is percent-encoded.
        *params*, when not empty, is appended as a query string.
        """
        # quote() and not quote_plus(): in a URL path "+" is a literal plus sign, only
        # query strings use it for spaces. Spaces must be sent as %20.
        url = quote(url.lstrip("/"), safe="/", encoding="utf-8")
        url = urljoin(self.baseurl, url)
        if params:
            # no dangling "?" when there are no parameters at all
            url += "?" + urlencode(params)
        return url

    def get(self, url, params=None, headers=None):
        """Send a GET request and return the decoded response (see _send)."""
        url = self.make_url(url, params)
        request = Request(url, method="GET")
        return self._send(request, headers)

    @classmethod
    def human_readable(cls, value):
        """
        Format a number with a decimal (1000 based) K, M, G or T suffix.

        One decimal is kept below 100 units, none above.
        """
        if value == 0:
            return "0"
        units = ("K", "M", "G", "T")
        i = 0
        unit = ""
        while value >= 1000 and i < len(units):
            unit = units[i]
            value /= 1000
            i += 1
        if value < 100:
            return "{:.1f}{}".format(value, unit)
        return "{:.0f}{}".format(value, unit)

    @classmethod
    def log(cls, msg, *args, end="\n"):
        """Write ``msg % args`` followed by *end* to stderr and flush it."""
        sys.stderr.write(msg % args)
        sys.stderr.write(end)
        sys.stderr.flush()

    @classmethod
    def progress(cls, nbytes, total_size, filename):
        """
        Log a transfer progress line for *filename* (only its base name is shown).

        *total_size* may be None or 0 when unknown: the total and the percentage are
        then left blank.
        """
        n = cls.human_readable(nbytes)
        if total_size:
            total = cls.human_readable(total_size)
            percent = "{:.0%}".format(nbytes / total_size)
        else:
            percent = total = ""
        # on a terminal, the same line is rewritten over and over
        end = "\r" if sys.stderr.isatty() else "\n"
        cls.log(
            "%5s / %5s %6s  %s", n, total, percent, os.path.basename(filename), end=end
        )

    def get_file(self, url, filepath):
        """
        Download *url* into the local file *filepath*, creating parent folders.

        When the response has a ``Digest: ALGO:HEXDIGEST`` header, the downloaded data
        is checked against it. On mismatch, the file is removed and OSError is raised.
        """
        url = self.make_url(url)
        request = Request(url, method="GET")
        buf = bytearray(256 * 1024)
        view = memoryview(buf)

        with self.opener.open(request) as response:
            size = response.getheader("Content-Length")
            if size is not None:
                size = int(size)
            digest = response.getheader("Digest")
            if digest is not None:
                # only the first ":" separates the algorithm from the digest value
                algo, digest = digest.split(":", 1)
                hasher = hashlib.new(algo)
            else:
                algo = hasher = None
            try:
                self.progress(0, size, filepath)
                filedir = os.path.dirname(filepath)
                if filedir:
                    os.makedirs(filedir, exist_ok=True)
                i = total = 0
                with open(filepath, "wb") as f:
                    n = response.readinto(buf)
                    while n:
                        f.write(view[:n])
                        if hasher is not None:
                            hasher.update(view[:n])
                        total += n
                        # refresh the progress line every 64 chunks, terminals only
                        if i % 64 == 0 and sys.stderr.isatty():
                            self.progress(total, size, filepath)
                        i += 1
                        n = response.readinto(buf)
                if hasher is not None and hasher.hexdigest() != digest:
                    os.unlink(filepath)
                    try:
                        # best effort removal of the now empty parent folders
                        os.removedirs(os.path.dirname(filepath))
                    except OSError:
                        pass
                    raise OSError(
                        "downloaded data does not match digest: %s:%s" % (algo, digest)
                    )
                self.progress(total, total, filepath)
            finally:
                if sys.stderr.isatty():
                    # terminate the progress line that was rewritten with "\r"
                    self.log("")

    def delete(self, url, params=None, headers=None):
        """Send a DELETE request and return the decoded response (see _send)."""
        url = self.make_url(url, params)
        request = Request(url, method="DELETE")
        return self._send(request, headers)

    def _encode_body(self, body, headers=None):
        """
        Encode *body* and return ``(body, headers)`` with a matching Content-Type.

        bytes are sent as is, list/tuple/dict are serialized to JSON, str is encoded
        to UTF-8. Anything else (e.g. None) is returned untouched without
        Content-Type. *headers* is never modified: a new dict is returned.
        """
        # work on a copy, the caller's dict must not be modified
        headers = dict(headers or {})
        if isinstance(body, bytes):
            headers["Content-Type"] = "application/octet-stream"
        elif isinstance(body, (list, tuple, dict)):
            body = json.dumps(body).encode("utf-8")
            headers["Content-Type"] = "application/json"
        elif isinstance(body, str):
            body = body.encode("utf-8")
            headers["Content-Type"] = "text/plain"
        return body, headers

    def put(self, url, body, headers=None):
        """Send a PUT request with *body* and return the decoded response."""
        url = self.make_url(url)
        body, headers = self._encode_body(body, headers)
        request = Request(url, body, method="PUT")
        return self._send(request, headers)

    class FileReader:
        """
        File-like request body that reports upload progress while being read.

        Must be used as a context manager: the file is opened on enter and closed on
        exit.
        """

        def __init__(self, filepath, size):
            self.filepath = filepath
            self.size = size
            self.i = 0  # number of read() calls, used to throttle progress output
            self.uploaded = 0  # number of bytes handed out so far
            self.fileobj = None

        def __enter__(self):
            self.fileobj = open(self.filepath, "rb")
            return self

        def seek(self, offset, whence=os.SEEK_SET):
            """Move the read position and return it, like regular file objects."""
            # use the position reported by the file: offset alone is only the
            # position when whence is SEEK_SET
            self.uploaded = self.fileobj.seek(offset, whence)
            return self.uploaded

        def read(self, n):
            buf = self.fileobj.read(n)
            self.i += 1
            self.uploaded += len(buf)
            if self.i % 64 == 0 and sys.stderr.isatty():
                HttpClient.progress(self.uploaded, self.size, self.filepath)
            return buf

        def __exit__(self, *args):
            if self.fileobj is not None:
                self.fileobj.close()

    def head(self, url, headers=None):
        """Send a HEAD request and return the raw response body (see _send)."""
        url = self.make_url(url)
        request = Request(url, method="HEAD")
        return self._send(request, headers)

    def put_file(self, filepath, url):
        """
        Upload the local file *filepath* to *url*.

        The sha256 digest of the file is computed first and sent in a HEAD request.
        If that request succeeds, only a link request is sent instead of the data.
        If it fails with 404, the file contents are uploaded. Other HTTP errors are
        propagated.
        """
        size = os.path.getsize(filepath)
        hasher = hashlib.sha256()
        buf = bytearray(256 * 1024)
        view = memoryview(buf)
        with open(filepath, "rb") as f:
            n = f.readinto(buf)
            while n:
                hasher.update(view[:n])
                n = f.readinto(buf)
        digest = "sha256:" + hasher.hexdigest()
        try:
            try:
                self.head(url, headers={"Digest": digest})
                # file digest already present on the server, do not upload the data again
                self.progress(0, size, filepath + " (deduplicated)")
                self.put(
                    url, body=None, headers={"Digest": digest, "X-Dlrepo-Link": digest}
                )
                self.progress(size, size, filepath + " (deduplicated)")
                return
            except HTTPError as e:
                if e.code != 404:
                    raise
                # file digest not on server, proceed with upload

            url = self.make_url(url)
            self.progress(0, size, filepath)

            with self.FileReader(filepath, size) as f:
                request = Request(url, f, method="PUT")
                request.add_header("User-Agent", "dlrepo-cli")
                request.add_header("Content-Size", str(size))
                if sys.version_info < (3, 6):
                    request.add_header("Content-Length", str(size))
                request.add_header("Content-Type", "application/octet-stream")
                request.add_header("Digest", digest)
                with self.opener.open(request):
                    pass
            self.progress(size, size, filepath)
        finally:
            if sys.stderr.isatty():
                # terminate the progress line that was rewritten with "\r"
                self.log("")

    def patch(self, url, body, headers=None):
        """Send a PATCH request with *body* and return the decoded response."""
        url = self.make_url(url)
        body, headers = self._encode_body(body, headers)
        request = Request(url, body, method="PATCH")
        return self._send(request, headers)

    def _send(self, request, headers=None):
        """
        Add the common and extra *headers* to *request*, send it and return the
        response body: parsed JSON when the response Content-Type mentions json (and
        the method is not HEAD), raw bytes otherwise.
        """
        request.add_header("Accept", "application/json,application/octet-stream")
        request.add_header("User-Agent", "dlrepo-cli")
        request.add_header("Referer", request.get_full_url())

        if headers is not None:
            for key, val in headers.items():
                request.add_header(key, val)

        with self.opener.open(request) as response:
            content_type = response.headers.get("Content-Type", "")
            if request.method == "HEAD" or "json" not in content_type:
                return response.read()
            return json.loads(response.read().decode("utf-8"))


# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    # the value returned by main() becomes the exit status of the process
    exit_status = main()
    sys.exit(exit_status)
