#!/usr/bin/env python3
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import json
import kubernetes
import platform
import os
import subprocess
import sys
import time
import traceback
import uuid
from datetime import datetime
from pathlib import Path

from adaptdl_cli.proxy import service_proxy
from adaptdl_cli.pvc import get_storageclass, create_pvc, create_copy_pod
from adaptdl_cli.tensorboard import TENSORBOARD_PREFIX, \
    add_tensorboard_commands


def _find_insecure_registry():
    """Look up the AdaptDL Docker registry service in the cluster.

    Services in all namespaces are listed, filtered by the label
    ``app=docker-registry`` and the name ``adaptdl-registry``. The first port
    named ``registry`` that has a node port set is used.

    Returns:
        A ``(namespace, "<service-name>:<port-name>", node_port)`` tuple, or
        ``(None, None, None)`` when no such port is found.
    """
    core_api = kubernetes.client.CoreV1Api()
    found = core_api.list_service_for_all_namespaces(
        label_selector="app=docker-registry",
        field_selector="metadata.name=adaptdl-registry")
    # Flatten the (service, port) pairs lazily, in listing order.
    pairs = ((svc, svc_port)
             for svc in found.items
             for svc_port in svc.spec.ports)
    for svc, svc_port in pairs:
        if not svc_port.node_port or svc_port.name != "registry":
            continue
        meta = svc.metadata
        return (meta.namespace,
                meta.name + ":" + svc_port.name,
                svc_port.node_port)
    return None, None, None


def _get_insecure_registry_proxy(proxy_port):
    if platform.system() == "Darwin":
        # Docker for Mac runs in a VM and can't access the proxy which is bound
        # to a port on the host. Use a special DNS name instead.
        # https://github.com/docker/for-mac/issues/3611
        # https://docs.docker.com/docker-for-mac/networking/#use-cases-and-workarounds
        registry = f"host.docker.internal:{proxy_port}"
        # Docker doesn't automatically trust host.docker.internal as an
        # insecure registry. Need to add it manually.
        docker_dir = Path.home() / ".docker"
        daemon_path = docker_dir / "daemon.json"
        try:
            with open(daemon_path, "r") as f:
                daemon = json.load(f)
        except FileNotFoundError:
            daemon = {}
        if registry not in daemon.setdefault("insecure-registries", []):
            daemon["insecure-registries"].append(registry)
            docker_dir.mkdir(exist_ok=True)
            with open(daemon_path, "w") as f:
                json.dump(daemon, f, indent=2)
            raise SystemExit(f"Insecure registry {registry} written to "
                             f"{daemon_path}. Please restart Docker and then"
                             " submit your job again!")
        return registry
    return f"localhost:{proxy_port}"


def _build_push(project, dockerfile, proxy_port):
    """Build the project's Docker image, push it and return its repodigest.

    If the ``ADAPTDL_SUBMIT_REPO`` environment variable is set, the image is
    tagged with its value and pushed directly. Otherwise the image is pushed
    to the AdaptDL insecure registry through a temporary service proxy
    listening on ``proxy_port``.

    Args:
        project: Docker build context directory.
        dockerfile: optional path passed to ``docker build -f``.
        proxy_port: local port used for the insecure-registry proxy.

    Returns:
        ``<remote image>@<digest>``, where the remote image is either the
        external repository or ``localhost:<nodeport>/adaptdl-submit``.

    Raises:
        SystemExit: if the build or push fails, no registry service is
            found, or Docker reports no repodigest for the pushed image.
    """
    image = os.getenv("ADAPTDL_SUBMIT_REPO")
    external = image is not None
    if external:
        print(f"Using external Docker repository at {image}.")
    else:
        print("Using AdaptDL insecure registry.")
        registry = _get_insecure_registry_proxy(proxy_port)
        image = f"{registry}/adaptdl-submit"
    try:
        subprocess.check_call(["docker", "build", "-t", image, project] +
                              (["-f", dockerfile] if dockerfile else []))
    except subprocess.CalledProcessError:
        raise SystemExit(f"Error: could not build image {image}")
    if external:
        try:
            subprocess.check_call(["docker", "push", image])
        except subprocess.CalledProcessError:
            raise SystemExit(f"Error: could not push image {image}")
        remote_image = image
    else:
        namespace, service, nodeport = _find_insecure_registry()
        if None in (namespace, service, nodeport):
            raise SystemExit("Error: No Docker registry could be found!")
        with service_proxy(namespace, service, listen_port=proxy_port):
            try:
                subprocess.check_call(["docker", "push", image])
            except subprocess.CalledProcessError:
                raise SystemExit(f"Error: could not push image {image}")
        # The image name used for the remote reference goes through the
        # registry's node port rather than the temporary local proxy.
        remote_image = f"localhost:{nodeport}/adaptdl-submit"
    # RepoDigests may be reported as JSON null; treat that as "no digests".
    repodigests = json.loads(subprocess.check_output(
        ["docker", "image", "inspect", image,
         "--format={{json .RepoDigests}}"])) or []
    # Return a remote-accessible repodigest.
    for digest in repodigests:
        if digest.startswith(image):
            return remote_image + "@" + digest.split("@")[-1]
    # Fail loudly instead of implicitly returning None, which the caller
    # would otherwise write into the job spec as the container image.
    raise SystemExit(f"Error: could not find a repodigest for image {image}")


def submit(args, remaining):
    """Handle ``adaptdl submit``: build, push and create an AdaptDLJob.

    The project image is built and pushed, the job file is rendered to JSON
    with ``kubectl create --dry-run`` and then patched: the first container
    gets the pushed image and ``remaining`` as its args, an optional image
    pull secret and ``generateName`` are set, an optional tensorboard claim
    is mounted, and a fresh per-job claim is mounted for checkpoints and
    shared files. Finally the AdaptDLJob is created, followed by its
    PersistentVolumeClaim (created with the job's metadata as owner).

    Args:
        args: parsed ``submit`` arguments.
        remaining: extra command-line arguments used as the container args.
    """
    image_ref = _build_push(args.project, args.dockerfile, args.proxy_port)
    if args.jobfile:
        jobfile = args.jobfile
    else:
        jobfile = os.path.join(args.project, "adaptdljob.yaml")
    rendered = subprocess.check_output(
        ["kubectl", "create", "--dry-run", "-f", jobfile, "-o", "json"])
    resource = json.loads(rendered.decode())
    pod_spec = resource["spec"]["template"]["spec"]
    first_container = pod_spec["containers"][0]
    first_container["image"] = image_ref
    first_container["args"] = remaining

    pull_secret = os.getenv("ADAPTDL_SUBMIT_REPO_CREDS")
    if pull_secret:
        pod_spec["imagePullSecrets"] = [{"name": pull_secret}]

    if args.name is not None:
        resource["metadata"]["generateName"] = args.name + "-"

    # The second element returned is the active kube-config context.
    active_context = kubernetes.config.list_kube_config_contexts()[1]
    namespace = active_context["context"].get("namespace", "default")

    if args.tensorboard is not None:
        tb_claim = f"{TENSORBOARD_PREFIX}{args.tensorboard}"
        # Check PersistentVolumeClaim exists.
        kubernetes.client.CoreV1Api() \
            .read_namespaced_persistent_volume_claim(tb_claim, namespace)
        pod_spec.setdefault("volumes", []).append({
            "name": "adaptdl-tensorboard",
            "persistentVolumeClaim": {"claimName": tb_claim},
        })
        for container in pod_spec["containers"]:
            tb_mount = {
                "name": "adaptdl-tensorboard",
                "mountPath": "/adaptdl/tensorboard",
            }
            container.setdefault("volumeMounts", []).append(tb_mount)
            container.setdefault("env", []).append({
                "name": "ADAPTDL_TENSORBOARD_LOGDIR",
                "value": tb_mount["mountPath"],
            })

    # Every job gets its own claim for checkpoints and shared files.
    pvc_volume = "adaptdl-pvc"
    pvc_name = f"adaptdl-pvc-{uuid.uuid4()}"
    for container in pod_spec["containers"]:
        mounts = container.setdefault("volumeMounts", [])
        mounts.extend([
            {"name": pvc_volume,
             "mountPath": "/adaptdl/checkpoint",
             "subPath": "adaptdl/checkpoint"},
            {"name": pvc_volume,
             "mountPath": "/adaptdl/share",
             "subPath": "adaptdl/share"},
        ])
        env_vars = container.setdefault("env", [])
        env_vars.append({"name": "ADAPTDL_CHECKPOINT_PATH",
                         "value": "/adaptdl/checkpoint"})
        env_vars.append({"name": "ADAPTDL_SHARE_PATH",
                         "value": "/adaptdl/share"})

    pod_spec.setdefault("volumes", []).append({
        "name": pvc_volume,
        "persistentVolumeClaim": {"claimName": pvc_name},
    })

    storage_class = get_storageclass(name=args.checkpoint_storage_class)
    # The job is created first so the claim can reference its metadata.
    adaptdljob = kubernetes.client.CustomObjectsApi() \
        .create_namespaced_custom_object(
            "adaptdl.petuum.com", "v1", namespace, "adaptdljobs", resource)
    create_pvc(name=pvc_name, storage_class=storage_class,
               size=args.checkpoint_storage_size,
               owner_metadata=adaptdljob["metadata"])


def cp(args, remaining):
    """Handle ``adaptdl cp``: copy files from a job's volume to the client.

    ``args.source`` must look like ``<job-name>:<absolute-remote-path>`` and
    ``args.destination`` is the local path handed to ``kubectl cp``. The
    job's ``adaptdl-pvc`` claim is looked up, a temporary pod is created for
    it with ``create_copy_pod``, the remote path is copied from under
    ``/adaptdl_pvc`` in that pod, and the pod is deleted afterwards.
    ``remaining`` is unused.

    Raises:
        SystemExit: if the source specification is invalid, the job has no
            ``adaptdl-pvc`` volume, the copy fails, or the temporary pod
            cannot be deleted.
    """
    source = args.source
    parsed_source = source.split(":")
    # Validate with explicit errors: ``assert`` is stripped under ``-O`` and
    # indexing an empty remote path would raise IndexError.
    if len(parsed_source) != 2:
        raise SystemExit(
            f"Error: {source} is not a valid source specification")
    job_name, remote_path = parsed_source
    if not remote_path.startswith("/"):
        raise SystemExit(
            "Error: absolute path is required for the source path")
    local_path = args.destination
    cp_job_uid = str(uuid.uuid4())[:8]
    job_json = subprocess.check_output(
        ["kubectl", "get", "adaptdljobs", job_name, "-o", "json"]).decode()
    job_object = json.loads(job_json)
    # A job without any volumes is reported like a job without the claim.
    volumes = job_object["spec"]["template"]["spec"].get("volumes") or []
    for volume in volumes:
        if volume["name"] == "adaptdl-pvc":
            pvc_name = volume["persistentVolumeClaim"]["claimName"]
            break
    else:
        raise SystemExit("Error: no AdaptDL persistent volume claim for "
                         f"AdaptDL job {job_name}. "
                         "Cannot perform `adaptdl cp`")
    # Stay None until the pod exists so cleanup is skipped when pod creation
    # itself fails (previously this hit unbound names in ``finally``).
    pod_name = None
    namespace = None
    try:
        core_api = kubernetes.client.CoreV1Api()
        pod = create_copy_pod(pvc_name, cp_job_uid)
        pod_name = pod.metadata.name
        namespace = pod.metadata.namespace
        # Poll once per second until the pod leaves the Pending phase.
        while pod.status.phase == "Pending":
            time.sleep(1)
            pod = core_api.read_namespaced_pod_status(namespace=namespace,
                                                      name=pod_name)
        if pod.status.phase != "Running":
            raise RuntimeError(f"Pod {pod_name} created to copy files "
                               f"in unexpected phase {pod.status.phase}")

        print("Copying files from cluster to client machine.")
        subprocess.check_output(
            ["kubectl", "cp",
             f"{namespace}/{pod_name}:/adaptdl_pvc{remote_path}", local_path])

    except Exception:
        tb = traceback.format_exc()
        print(tb, file=sys.stderr)
        raise SystemExit("Attempt to copy files from cluster failed. "
                         "Cleaning up.")

    finally:
        if pod_name is not None:
            try:
                api = kubernetes.client.CoreV1Api()
                api.delete_namespaced_pod(namespace=namespace,
                                          name=pod_name)
            except Exception:
                tb = traceback.format_exc()
                print(tb, file=sys.stderr)
                raise SystemExit("Attempt to clean up failed. "
                                 f"Please ensure pod {pod_name} is deleted")


def logs(args, remaining):
    """Handle ``adaptdl logs``: run ``kubectl logs`` for a job in a loop.

    ``kubectl logs -l adaptdl/job=<jobname>`` is invoked with ``remaining``
    appended as extra arguments, and invoked again each time it returns.
    A failed invocation prints its traceback and is retried after two
    seconds. The loop only ends on Ctrl-C.
    """
    while True:
        try:
            selector = f"adaptdl/job={args.jobname}"
            command = ["kubectl", "logs", "-l", selector] + remaining
            subprocess.check_call(command)
        except KeyboardInterrupt:
            # Ctrl-C is the only way out of the loop.
            return
        except Exception:
            traceback.print_exc()
            print("PRESS CTRL-C TO EXIT....")
            time.sleep(2)


def ls(args, remaining):
    """Handle ``adaptdl ls``: print a table of AdaptDLJobs.

    Jobs are read with ``kubectl get adaptdljobs -o json`` and printed
    sorted by creation time with the columns Name, Status, Start(UTC),
    Runtime, Rplc (replicas) and Rtrt (taken from the status ``group``
    field). Jobs that have no status yet are skipped. ``args`` and
    ``remaining`` are unused.
    """
    # Default listing
    fmt_str = "{:<65}{:<11}{:<14}{:<9}{:<6}{:<6}"

    def parse_timestamp(value):
        """Parse an ISO-8601 timestamp into a naive UTC datetime."""
        # Fractional seconds and the UTC offset are both optional.
        formats = ("%Y-%m-%dT%H:%M:%S.%f%z", "%Y-%m-%dT%H:%M:%S%z",
                   "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S")
        for fmt in formats:
            try:
                parsed = datetime.strptime(value, fmt)
            except ValueError:
                continue
            break
        else:
            raise ValueError(f"unrecognized timestamp: {value}")
        offset = parsed.utcoffset()
        if offset is not None:
            # Normalize to UTC before dropping the timezone.
            parsed = (parsed - offset).replace(tzinfo=None)
        return parsed

    def default_ls():
        """Collect the per-job fields to display, keyed by job name."""
        all_jobs = {}
        get_out = subprocess.check_output(
            ["kubectl", "get", "adaptdljobs", "-o", "json"]).decode()
        get_out = json.loads(get_out)

        if not get_out["items"]:
            print("No adaptdljobs")

        for v in get_out["items"]:
            # A job may not have a status at all yet.
            status = v.get("status")
            if status is None:
                continue

            # fromisoformat does not handle the trailing 'Z'
            start_time = datetime.strptime(v["metadata"]["creationTimestamp"],
                                           "%Y-%m-%dT%H:%M:%SZ")
            # Running jobs have no completion time: use the current UTC time
            # directly instead of round-tripping it through a string.
            completed = status.get("completionTimestamp")
            if completed is None:
                end_time = datetime.utcnow()
            else:
                end_time = parse_timestamp(completed)

            all_jobs[v["metadata"]["name"]] = {
                "start_time": start_time,
                "run_time": end_time - start_time,
                "replicas": status.get("replicas", "N/A"),
                # None cannot be formatted with a width specifier.
                "phase": status.get("phase") or "N/A",
                "restart": status.get("group", 0),
            }
        return all_jobs

    def print_ls_header():
        print(fmt_str.format("Name", "Status", "Start(UTC)", "Runtime",
                             "Rplc", "Rtrt"))

    def print_ls_row(k, v):
        # total_seconds() includes whole days, unlike timedelta.seconds.
        minutes = max(0, int(v["run_time"].total_seconds()) // 60)
        print(fmt_str.format(
              k,
              v["phase"],
              v["start_time"].strftime("%b-%d %H:%M"),
              str(minutes) + " min",
              v["replicas"],
              v["restart"]))

    ls_values = default_ls()
    print_ls_header()
    for k, v in sorted(ls_values.items(),
                       key=lambda item: item[1]["start_time"]):
        print_ls_row(k, v)


if __name__ == "__main__":
    # Command-line entry point: build the sub-command parsers, split off the
    # arguments that follow "--", and dispatch to the selected handler.
    parser = argparse.ArgumentParser()
    # With no sub-command given, just show the help text.
    parser.set_defaults(handler=lambda args, remaining: parser.print_help())

    subparsers = parser.add_subparsers(help="sub-command help")

    parser_submit = subparsers.add_parser(
        "submit", help="submit a job to an AdaptDL cluster")
    parser_submit.add_argument(
        "project", type=str, help="path to the project directory")
    parser_submit.add_argument(
        "-f", "--jobfile", type=str,
        help="path to the AdaptDLJob resource file")
    parser_submit.add_argument(
        "-d", "--dockerfile", type=str, help="path to the Dockerfile")
    parser_submit.add_argument(
        "-n", "--name", type=str, help="name to identify the AdaptDLjob")
    parser_submit.add_argument(
        "--tensorboard", type=str, help="name of tensorboard instance")
    parser_submit.add_argument(
        "--checkpoint-storage-size", type=str,
        help="size of new temporary pvc to create", default="1Gi")
    parser_submit.add_argument(
        "--checkpoint-storage-class", type=str,
        help="storage class of new temporary pvc to create", default=None)
    parser_submit.add_argument(
        "--proxy-port", type=int, default=59283,  # Some uncommon port.
        help="temporary proxy port used for pushing to the insecure registry")
    parser_submit.set_defaults(handler=submit)

    parser_logs = subparsers.add_parser(
        "logs", help="view the logs of an AdaptDLJob")
    parser_logs.add_argument(
        "jobname", type=str, help="name of the AdaptDLJob")
    parser_logs.set_defaults(handler=logs)

    # The help text lists exactly the columns that ``ls`` prints.
    parser_ls = subparsers.add_parser(
        "ls",
        help="View the status of AdaptDLJobs in the format"
             " (Name, Status, Start Time, Runtime, Replicas (Rplc),"
             " Restarts (Rtrt))")
    parser_ls.set_defaults(handler=ls)

    parser_cp = subparsers.add_parser(
        "cp", help="copy a file from the AdaptDL cluster to the local client")
    parser_cp.add_argument(
        "source", type=str,
        help="source for the copy. use kubectl cp conventions"
             ", e.g. <job-name>:<remote-path>")
    parser_cp.add_argument(
        "destination", type=str,
        help="destination path for the copy on the client")
    parser_cp.set_defaults(handler=cp)

    parser_tensorboard = subparsers.add_parser(
        "tensorboard", help="manage tensorboard instances")
    add_tensorboard_commands(parser_tensorboard)

    if "--" in sys.argv:
        # Everything after "--" is passed through to the handler untouched;
        # the separator and what follows are removed from sys.argv before
        # argparse sees it.
        idx = sys.argv.index("--")
        remaining = sys.argv[idx + 1:]
        del sys.argv[idx:]
        args = parser.parse_args()
    else:
        args, remaining = parser.parse_known_args()

    kubernetes.config.load_kube_config()
    try:
        args.handler(args, remaining)
    except KeyboardInterrupt:
        pass
    except kubernetes.client.rest.ApiException as exc:
        # The body is normally a JSON status object, but it can be empty or
        # not JSON at all; fall back to the exception text in that case.
        try:
            result = json.loads(exc.body)
        except (TypeError, ValueError):
            result = None
        if isinstance(result, dict):
            print("{} ({}): {}".format(result.get("status"),
                                       result.get("reason"),
                                       result.get("message")),
                  file=sys.stderr)
        else:
            print(exc, file=sys.stderr)
        # sys.exit rather than the site-provided ``exit`` builtin, which is
        # not guaranteed to exist (e.g. under ``python -S``).
        sys.exit(1)
