"""Monitor SLURM jobs and status"""

import sys
import time
import argparse
import datetime
from typing import List
import numpy

from .cli import common as common_cli
from .cli import status as status_cli
from ..client.script import SlurmScriptApi


def main(argv=None):
    """Entry point of ``slurm-monitor``.

    Parses ``argv`` (program name first, like ``sys.argv``; defaults to
    ``sys.argv``) and runs the selected sub-command. When no sub-command is
    given, the help text is printed instead.

    :returns: always 0.
    """
    if argv is None:
        argv = sys.argv

    parser = argparse.ArgumentParser(
        description="Slurm Job Monitor", prog="slurm-monitor"
    )
    subparsers = parser.add_subparsers(help="Commands", dest="command")

    check_parser = subparsers.add_parser("check", help="Check API")
    common_cli.add_parameters(check_parser)

    status_parser = subparsers.add_parser("status", help="Job status")
    common_cli.add_parameters(status_parser)
    status_cli.add_parameters(status_parser)

    # Sub-command name -> handler; anything else (None) falls back to help.
    handlers = {"status": command_status, "check": command_check}

    args = parser.parse_args(argv[1:])
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
    else:
        handler(args)
    return 0


def command_status(args):
    """Print the job table, optionally refreshing it every ``args.interval`` s.

    With ``args.jobid`` set, only that job is listed and
    ``api.print_stdout_stderr`` is called for it after each table.
    Ctrl-C ends the command quietly, whether it arrives while sleeping
    between refreshes or while a query/print is in progress.
    """
    common_cli.apply_parameters(args)
    status_cli.apply_parameters(args)

    with SlurmScriptApi(
        args.url,
        args.user_name,
        args.token,
        log_directory=args.log_directory,
    ) as api:
        try:
            for _ in _monitor_loop(args.interval):
                _print_jobs(api, args.jobid, args.all)
                if args.jobid:
                    api.print_stdout_stderr(args.jobid)
        except KeyboardInterrupt:
            # _monitor_loop only intercepts Ctrl-C raised inside the generator
            # (i.e. during its sleep); an interrupt during the loop body is
            # raised here, so handle it the same way instead of tracebacking.
            pass


def command_check(args):
    """Check that the server exposes the REST API this client expects.

    :raises AssertionError: when ``api.server_has_api()`` is falsy.
    """
    common_cli.apply_parameters(args)

    with SlurmScriptApi(
        args.url,
        args.user_name,
        args.token,
        log_directory=args.log_directory,
    ) as api:
        # Explicit raise instead of ``assert``: assert statements are removed
        # under ``python -O``, which would make this check a silent no-op.
        # AssertionError is kept so existing callers see the same type.
        if not api.server_has_api():
            raise AssertionError("Wrong Rest API version")


def _monitor_loop(interval):
    try:
        if not interval:
            yield
            return
        while True:
            yield
            time.sleep(interval)
    except KeyboardInterrupt:
        pass


def _print_jobs(api, jobid, all_users):
    """Print a table of jobs; print nothing when there are no jobs.

    :param api: client used to fetch job properties.
    :param jobid: when truthy, only this job is fetched and shown.
    :param all_users: when truthy (and no ``jobid``), pass
        ``{"user_name": None}`` as the filter instead of ``None``.
    """
    if jobid:
        jobs = [api.get_job_properties(jobid)]
    else:
        # NOTE(review): a None user_name presumably lifts the per-user
        # restriction in SlurmScriptApi -- confirm against the client.
        job_filter = {"user_name": None} if all_users else None
        jobs = api.get_all_job_properties(filter=job_filter)

    # (job property key, column title, cell renderer) in display order; one
    # table keeps titles and renderers from drifting apart.
    columns = (
        ("job_id", "Job ID", _passthrough),
        ("name", "Name", _passthrough),
        ("job_state", "State", _passthrough),
        ("user_name", "User", _passthrough),
        ("submit_time", "Submit time", _duration),
        ("start_time", "Run time", _duration),
        ("tres_alloc_str", "Resources", _passthrough),
    )
    rows = [[render(info[key]) for key, _, render in columns] for info in jobs]
    if not rows:
        return
    print(_format_info([title for _, title, _ in columns], rows))


def _passthrough(x):
    return str(x)


def _duration(x):
    if x == 0:
        return "-"
    duration = datetime.datetime.now() - datetime.datetime.fromtimestamp(x)
    if duration.total_seconds() < 0:
        return "-"
    return str(duration)


def _format_info(titles: List[str], rows: List[List[str]]):
    lengths = numpy.array([[len(s) for s in row] for row in rows])
    fmt = "   ".join(["{{:<{}}}".format(n) for n in lengths.max(axis=0)])
    infostr = "\n "
    infostr += fmt.format(*titles)
    infostr += "\n "
    infostr += "\n ".join([fmt.format(*row) for row in rows])
    return infostr


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
