#!/usr/bin/env python3
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import json
import kubernetes
import os
import subprocess
import sys
import time
import traceback
import uuid
from datetime import datetime

from adaptdl_cli.pvc import get_storageclass, create_pvc, create_copy_pod
from adaptdl_cli.tensorboard import TENSORBOARD_PREFIX, \
    add_tensorboard_commands
from adaptdl_cli.registry import fix_local_docker, ADAPTDL_REGISTRY_URL, \
    ADAPTDL_REGISTRY_CREDS


def _get_registry_info():
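    """Return (repository, pull-secret name, is_internal) for image pushes.

    An external repository is used when both ADAPTDL_SUBMIT_REPO and
    ADAPTDL_SUBMIT_REPO_CREDS are set in the environment; otherwise the
    AdaptDL internal registry is used.
    """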
    if 'ADAPTDL_SUBMIT_REPO' in os.environ and \
            'ADAPTDL_SUBMIT_REPO_CREDS' in os.environ:
        print(f"Using external repo at "
              f"{os.environ['ADAPTDL_SUBMIT_REPO']}")
        return os.environ['ADAPTDL_SUBMIT_REPO'], \
            os.environ['ADAPTDL_SUBMIT_REPO_CREDS'], False
    else:
        print("Using AdaptDL internal registry")
        return ADAPTDL_REGISTRY_URL, ADAPTDL_REGISTRY_CREDS, True


def _prefix(image):
    # RepoDigests are suffixed with a sha256 digest rather than a user tag,
    # so strip the tag from the image name before comparing prefixes.
    parts = image.split('/')
    return '/'.join(parts[:-1] + [parts[-1].split(':')[0]])


def submit(args, remaining):
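    """Build and push the project's Docker image, then create an AdaptDLJob.

    The job template from the jobfile is patched to use the pushed image
    digest, the registry pull secret, and a per-job PersistentVolumeClaim
    for checkpoints before being submitted to the cluster.
    """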
    repo, secret, internal = _get_registry_info()
    if internal:
        # Use a fixed repository path inside the internal registry
        image = f"{repo}/dev/adaptdl-submit:latest"
    else:
        # Expect users to provide the full repository name
        image = repo
    subprocess.check_call(["docker", "build", "-t", image, args.project] +
                          ([] if not args.dockerfile else
                              ["-f", args.dockerfile]))
    try:
        subprocess.check_call(["docker", "push", image])
    except subprocess.CalledProcessError:
        if not internal:
            raise
        # Try to fix the local docker client, assuming the internal docker
        # registry is up and running. This includes:
        # 1. Adding the well-known registry hostname to /etc/hosts
        # 2. Making the local docker daemon accept the insecure AdaptDL
        #    registry
        # 3. docker login to the AdaptDL registry
        print("Fixing local docker client.")
        fix_local_docker()
        subprocess.check_call(["docker", "push", image])
    repodigests = json.loads(subprocess.check_output(
        ["docker", "image", "inspect", image,
         "--format={{json .RepoDigests}}"]))
    # Find the digest that matches the repository the image was pushed to
    repodigest = None
    for digest in repodigests:
        if digest.startswith(_prefix(image)):
            repodigest = digest
            break
    assert repodigest, f"no repository digest found for image {image}"
    jobfile = args.jobfile or os.path.join(args.project, "adaptdljob.yaml")
    resource = subprocess.check_output(
        ["kubectl", "create", "--dry-run=client", "-f", jobfile,
         "-o", "json"])
    resource = json.loads(resource.decode())
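    # Point the job's container at the immutable image digest, forward any
    # arguments given after "--" to the container, and attach the registry
    # pull secret.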
    resource["spec"]["template"]["spec"]["containers"][0]["image"] = \
        repodigest
    resource["spec"]["template"]["spec"]["containers"][0]["args"] = remaining
    resource["spec"]["template"]["spec"]["imagePullSecrets"] = \
            [{"name": secret}]

    if args.name is not None:
        resource["metadata"]["generateName"] = args.name + "-"

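    # Submit into the namespace of the active kubeconfig context.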
    context = kubernetes.config.list_kube_config_contexts()[1]
    namespace = context["context"].get("namespace", "default")

    if args.tensorboard is not None:
        # Check PersistentVolumeClaim exists.
        claim_name = "{}{}".format(TENSORBOARD_PREFIX, args.tensorboard)
        core_api = kubernetes.client.CoreV1Api()
        core_api.read_namespaced_persistent_volume_claim(
            claim_name, namespace)
        volumes = resource["spec"]["template"]["spec"].setdefault("volumes",
                                                                  [])
        volumes.append({
            "name": "adaptdl-tensorboard",
            "persistentVolumeClaim": {
                "claimName": claim_name,
            },
        })
        for container in resource["spec"]["template"]["spec"]["containers"]:
            volume_mounts = container.setdefault("volumeMounts", [])
            volume_mounts.append({
                "name": "adaptdl-tensorboard",
                "mountPath": "/adaptdl/tensorboard",
            })
            env = container.setdefault("env", [])
            env.append({
                "name": "ADAPTDL_TENSORBOARD_LOGDIR",
                "value": volume_mounts[-1]["mountPath"],
            })

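    # Create a fresh PersistentVolumeClaim for this job and mount it into
    # every container for checkpoints (/adaptdl/checkpoint) and shared files
    # (/adaptdl/share).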
    volume_name = "adaptdl-pvc"
    pvc_name = "adaptdl-pvc-{}".format(str(uuid.uuid4()))
    volume = {
        "name": volume_name,
        "persistentVolumeClaim": {
            "claimName": pvc_name,
        }
    }
    for container in resource["spec"]["template"]["spec"]["containers"]:
        volume_mounts = container.setdefault("volumeMounts", [])
        volume_mounts.append({
            "name": volume_name,
            "mountPath": "/adaptdl/checkpoint",
            "subPath": "adaptdl/checkpoint"
        })
        volume_mounts.append({
            "name": volume_name,
            "mountPath": "/adaptdl/share",
            "subPath": "adaptdl/share"
        })
        env = container.setdefault("env", [])
        env.extend([{"name": "ADAPTDL_CHECKPOINT_PATH",
                     "value": "/adaptdl/checkpoint"},
                    {"name": "ADAPTDL_SHARE_PATH",
                     "value": "/adaptdl/share"}])

    volumes = resource["spec"]["template"]["spec"].setdefault("volumes", [])
    volumes.append(volume)

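    # Create the AdaptDLJob first so the PVC can list it as its owner and be
    # cleaned up along with the job.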
    storage_class = get_storageclass(name=args.checkpoint_storage_class)
    custom_object_api = kubernetes.client.CustomObjectsApi()
    adaptdljob = custom_object_api.create_namespaced_custom_object(
        "adaptdl.petuum.com", "v1", namespace, "adaptdljobs", resource)
    create_pvc(name=pvc_name, storage_class=storage_class,
               size=args.checkpoint_storage_size,
               owner_metadata=adaptdljob["metadata"])


def cp(args, remaining):
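    """Copy a file from a job's persistent volume to the local machine.

    A short-lived helper pod is created that mounts the job's PVC, and
    `kubectl cp` is used to copy the requested path out of it. The helper
    pod is deleted afterwards.
    """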
    source = args.source
    parsed_source = source.split(":")
    assert len(parsed_source) == 2, \
        f"{source} is not a valid source specification"
    job_name, remote_path = parsed_source
    assert remote_path[0] == '/', \
        "an absolute path is required for the source path"
    local_path = args.destination
    cp_job_uid = str(uuid.uuid4())[:8]
    job_json = subprocess.check_output(
        ["kubectl", "get", "adaptdljobs", job_name, "-o", "json"]).decode()
    job_object = json.loads(job_json)
    for volume in job_object["spec"]["template"]["spec"]["volumes"]:
        if volume["name"] == "adaptdl-pvc":
            pvc_name = volume["persistentVolumeClaim"]["claimName"]
            break
    else:
        raise SystemExit("Error: no AdaptDL persistent volume claim for "
                         f"AdaptDL job {job_name}."
                         "Cannot perform `adaptdl cp`")
    pod_name = namespace = None
    try:
        core_api = kubernetes.client.CoreV1Api()
        pod = create_copy_pod(pvc_name, cp_job_uid)
        pod_name = pod.metadata.name
        namespace = pod.metadata.namespace
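        # Wait for the helper pod to leave the Pending phase before copying.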
        while pod.status.phase == "Pending":
            time.sleep(1)
            pod = core_api.read_namespaced_pod_status(namespace=namespace,
                                                      name=pod_name)
        if pod.status.phase != "Running":
            raise RuntimeError(f"Pod {pod_name} created to copy files is "
                               f"in unexpected phase {pod.status.phase}")

        print("Copying files from cluster to client machine.")
        subprocess.check_output(
            ["kubectl", "cp",
             f"{namespace}/{pod_name}:/adaptdl_pvc{remote_path}", local_path])

    except Exception:
        tb = traceback.format_exc()
        print(tb, file=sys.stderr)
        raise SystemExit("Attempt to copy files from cluster failed. "
                         "Cleaning up.")

    finally:
        # Only clean up if the helper pod was actually created.
        if pod_name is not None:
            try:
                api = kubernetes.client.CoreV1Api()
                api.delete_namespaced_pod(namespace=namespace,
                                          name=pod_name)
            except Exception:
                tb = traceback.format_exc()
                print(tb, file=sys.stderr)
                raise SystemExit("Attempt to clean up failed. "
                                 f"Please ensure pod {pod_name} is deleted")


def logs(args, remaining):
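    """Print the logs of the pods belonging to an AdaptDLJob.

    Extra arguments after the job name are passed through to `kubectl logs`.
    The call is retried in a loop so logs keep being shown across job
    restarts; press Ctrl-C to exit.
    """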
    while True:
        try:
            subprocess.check_call(["kubectl",
                                   "logs", "-l",
                                   f"adaptdl/job={args.jobname}"] + remaining)
        except KeyboardInterrupt:
            return
        except Exception:
            traceback.print_exc()
            print("PRESS CTRL-C TO EXIT....")
            time.sleep(2)


def ls(args, remaining):
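    """List AdaptDLJobs in the current namespace with their status.

    Prints one row per job: name, phase, start time (UTC), runtime,
    number of replicas, and restart count.
    """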
    def default_ls():
        all_jobs = {}
        get_out = subprocess.check_output(
            ["kubectl", "get", "adaptdljobs", "-o", "json"]).decode()
        get_out = json.loads(get_out)

        if len(get_out["items"]) == 0:
            print("No adaptdljobs")

        for v in get_out["items"]:
            if v.get("status") is None:
                continue

            # fromisoformat does not handle the trailing 'Z'
            start_time = datetime.strptime(v["metadata"]["creationTimestamp"],
                                           "%Y-%m-%dT%H:%M:%SZ")
            if "completionTimestamp" in v["status"]:
                # Completed jobs report an ISO-8601 timestamp with an offset
                current_timestamp = datetime.strptime(
                    v["status"]["completionTimestamp"],
                    "%Y-%m-%dT%H:%M:%S.%f%z").replace(tzinfo=None)
            else:
                # Still running: measure runtime against the current time
                current_timestamp = datetime.utcnow()
            run_time = current_timestamp - start_time

            ls_data = {}
            ls_data["start_time"] = start_time
            ls_data["run_time"] = run_time
            ls_data["replicas"] = v["status"].get("replicas", "N/A")
            ls_data["phase"] = v["status"].get("phase")
            ls_data["restart"] = v["status"].get("group", 0)

            all_jobs[v["metadata"]["name"]] = ls_data
        return all_jobs
    # Default listing
    fmt_str = "{:<65}{:<11}{:<14}{:<9}{:<6}{:<6}"

    def print_ls_header():
        print(fmt_str.format("Name", "Status", "Start(UTC)", "Runtime",
                             "Rplc", "Rtrt"))

    def print_ls_row(k, v):
        print(fmt_str.format(
              k,
              v["phase"],
              str(v["start_time"].strftime("%b-%d %H:%S")),
              str(v["run_time"].seconds//60) + " min",
              v["replicas"],
              v["restart"]))

    ls_values = default_ls()
    print_ls_header()
    for k, v in sorted(ls_values.items(),
                       key=lambda item: item[1]["start_time"]):
        print_ls_row(k, v)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.set_defaults(handler=lambda args, remaining: parser.print_help())

    subparsers = parser.add_subparsers(help="sub-command help")

    parser_submit = subparsers.add_parser(
        "submit", help="submit a job to an AdaptDL cluster")
    parser_submit.add_argument(
        "project", type=str, help="path to the project directory")
    parser_submit.add_argument(
        "-f", "--jobfile", type=str,
        help="path to the AdaptDLJob resource file")
    parser_submit.add_argument(
        "-d", "--dockerfile", type=str, help="path to the Dockerfile")
    parser_submit.add_argument(
        "-n", "--name", type=str, help="name to identify the AdaptDLjob")
    parser_submit.add_argument(
        "--tensorboard", type=str, help="name of tensorboard instance")
    parser_submit.add_argument(
        "--checkpoint-storage-size", type=str,
        help="size of new temporary pvc to create", default="1Gi")
    parser_submit.add_argument(
        "--checkpoint-storage-class", type=str,
        help="storage class of new temporary pvc to create", default=None)
    parser_submit.set_defaults(handler=submit)

    parser_logs = subparsers.add_parser(
        "logs", help="view the logs of an AdaptDLJob")
    parser_logs.add_argument(
        "jobname", type=str, help="name of the AdaptDLJob")
    parser_logs.set_defaults(handler=logs)

    parser_ls = subparsers.add_parser(
        "ls",
        help="View the status of AdaptDLJobs in the format"
             " (Name, Status, Start Time, Runtime, Replicas (Rplc),"
             "  Number of nodes (Nodes), Speedup (Spdp), Restarts (Rtrt)")
    parser_ls.set_defaults(handler=ls)

    parser_cp = subparsers.add_parser(
        "cp", help="copy a file from the AdaptDL cluster to the local client")
    parser_cp.add_argument(
        "source", type=str,
        help="source for the copy. use kubectl cp conventions"
             ", e.g. <job-name>:<remote-path>")
    parser_cp.add_argument(
        "destination", type=str,
        help="destination path for the copy on the client")
    parser_cp.set_defaults(handler=cp)

    parser_tensorboard = subparsers.add_parser(
        "tensorboard", help="manage tensorboard instances")
    add_tensorboard_commands(parser_tensorboard)

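    # Everything after a literal "--" is passed through unparsed; for
    # `submit` these become the container's command-line arguments.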
    if "--" in sys.argv:
        remaining = []
        idx = sys.argv.index("--")
        sys.argv.pop(idx)
        while len(sys.argv) > idx:
            remaining.append(sys.argv.pop(idx))
        args = parser.parse_args()
    else:
        args, remaining = parser.parse_known_args()

    kubernetes.config.load_kube_config()
    try:
        args.handler(args, remaining)
    except KeyboardInterrupt:
        pass
    except kubernetes.client.rest.ApiException as exc:
        result = json.loads(exc.body)
        print("{} ({}): {}".format(result.get("status"), result.get("reason"),
                                   result.get("message")), file=sys.stderr)
        sys.exit(1)
