#!/usr/bin/env python
# Copyright (c) 2020 The Regents of the University of Michigan
# All rights reserved.
# This software is licensed under the BSD 3-Clause License.
"""A simple scheduler implementation that works with signac-flow.

The package signac-flow includes the ``simple-scheduler`` script as a simple
model of a cluster job scheduler. The ``simple-scheduler`` script is designed
primarily for testing and demonstration.
"""
import argparse
import json
import logging
import os
import shutil
import sys
import threading
import time
import uuid
from contextlib import contextmanager

from signac import Collection

from flow.scheduling.base import JobStatus

logger = logging.getLogger(__name__)


# Single-letter codes that ``main_status`` prints in its plain-text listing,
# keyed by the job's ``JobStatus``.
FMT_STATUS = {
    JobStatus.inactive: "I",
    JobStatus.queued: "Q",
    JobStatus.active: "A",
}


def _get_submit_parser(parser_submit=None):
    """Return an argument parser for the options given on ``#SSCHED`` lines.

    If ``parser_submit`` is provided, the options are registered on that
    parser and the same object is returned; otherwise a fresh
    ``argparse.ArgumentParser`` is created. The recognized options are the
    mandatory ``--job-name`` and the optional ``-D``/``--chdir``.
    """
    parser = argparse.ArgumentParser() if parser_submit is None else parser_submit
    option_specs = (
        (("--job-name",), {"required": True, "help": "The name of the job."}),
        (
            ("-D", "--chdir"),
            {"help": "Change to this directory prior to execution."},
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser


def _get_args(script):
    """Yield the scheduler options embedded in a job script.

    ``script`` is an open text file (anything providing ``readlines()``).
    Every line that starts with ``#SSCHED`` contributes the
    whitespace-separated tokens found after its first eight characters;
    all other lines are ignored.
    """
    directive_lines = (
        text for text in script.readlines() if text.startswith("#SSCHED")
    )
    for text in directive_lines:
        for token in text[8:].split():
            yield token


def main_submit(args):
    """Submit a job script to the scheduler.

    The ``#SSCHED`` options of the script at ``args.filename`` are parsed
    first; the result is discarded, the parse only serves to reject scripts
    with missing or invalid options. The script is then copied into
    ``args.inbox`` under a freshly generated UUID name ending in ``.sh``.
    """
    with open(args.filename) as script:
        directives = list(_get_args(script))
    # Validation only: argparse reports the problem and exits on bad options.
    _get_submit_parser().parse_args(directives)

    inbox_name = f"{uuid.uuid4()}.sh"
    shutil.copyfile(args.filename, os.path.join(args.inbox, inbox_name))


@contextmanager
def _lock_database(args):
    """Open the scheduler database while holding an exclusive lock file.

    The lock is the file ``args.db + ".lock"``. It is created in
    exclusive-creation mode, so the attempt fails if the file already
    exists. While the lock is held, the collection stored at ``args.db`` is
    opened and yielded. The lock file is removed when the context exits,
    whether or not the body raised.

    Raises ``RuntimeError`` if the lock file already exists. Exceptions
    raised inside the ``with`` body (including ``FileExistsError``) propagate
    unchanged and never prevent the lock from being released.
    """
    fn_lock = args.db + ".lock"
    # Acquire the lock on its own, so that only a failure to create the lock
    # file is reported as a locking problem.
    try:
        lock_file = open(fn_lock, mode="x")
    except FileExistsError as error:
        raise RuntimeError(
            "Unable to lock database file '{}'. Is another process already "
            "running?".format(args.db)
        ) from error

    try:
        with lock_file:
            with Collection.open(args.db) as db:
                yield db
    finally:
        # Best effort: a failure to remove the lock must not mask an
        # exception coming from the body, but it should not go unnoticed.
        try:
            os.unlink(fn_lock)
        except OSError as error:
            logger.warning("Unable to remove lock file '%s': %s", fn_lock, error)


def _list_files_by_mtime(path):
    """Return the names of the entries in ``path``, oldest first.

    Entries are ordered by ascending modification time. The returned names
    are relative to ``path`` (as produced by ``os.listdir``).
    """

    def _modified(name):
        return os.path.getmtime(os.path.join(path, name))

    names = os.listdir(path)
    names.sort(key=_modified)
    return names


def _process_inbox_loop(inbox, queue, db, stop):
    """Move submitted scripts from the inbox into the queue until stopped.

    On every pass, each file found in ``inbox`` (ordered by modification
    time, oldest first) is moved into ``queue``, its ``#SSCHED`` options are
    parsed, and a document describing the job is inserted into ``db`` with
    status ``JobStatus.queued``. The database is flushed after each pass.

    ``stop`` is a ``threading.Event``; the loop waits on it for up to five
    seconds between passes and returns once it is set.
    """
    parser = _get_submit_parser()

    while True:
        logger.info("Processing inbox...")
        # Use the ``inbox`` argument rather than a module-level ``args``, so
        # the function also works when this module is not run as a script.
        for fn in _list_files_by_mtime(inbox):
            src = os.path.join(inbox, fn)
            dst = os.path.join(queue, fn)
            os.rename(src, dst)
            with open(dst) as script:
                submit_args = parser.parse_args(list(_get_args(script)))
            # The parsed options become the job document, extended with the
            # bookkeeping fields below.
            doc = vars(submit_args)
            doc["script"] = dst
            doc["status"] = int(JobStatus.queued)
            doc["_queued"] = time.time()
            _id = db.insert_one(doc)
            logger.info(f"Queued '{_id}'.")
        db.flush()
        # Doubles as the sleep between passes and as the shutdown check.
        if stop.wait(timeout=5):
            break


def _first_unused_filename(stem, limit=100):
    """Return the first of ``<stem>.0`` ... ``<stem>.<limit - 1>`` that does not exist.

    Candidates are checked relative to the current working directory.
    Raises ``RuntimeError`` if every candidate already exists.
    """
    for i in range(limit):
        filename = f"{stem}.{i}"
        if not os.path.exists(filename):
            return filename
    raise RuntimeError(
        f"Unable to find an unused file name for '{stem}' after {limit} attempts."
    )


def _process_queue(args, db):
    """Execute the next queued job, if there is one.

    Documents whose ``_delete_after`` time has passed are removed first. The
    queued document with the smallest ``_id`` is then marked active, and its
    script is run with ``/bin/bash`` from the job's ``chdir`` directory (the
    user's home directory if none was requested). Standard output and error
    go to ``<id>.out.<n>`` and ``<id>.err.<n>`` in that directory, where
    ``<id>`` is the ``_id`` formatted as hexadecimal and ``<n>`` is the first
    unused index.

    Regardless of the outcome, the original working directory is restored,
    the script file is deleted, and the document is marked inactive and
    scheduled for removal two minutes later.

    ``args`` is accepted for interface compatibility and is not used.

    Raises ``RuntimeError`` if no unused output file name is available.
    """
    from subprocess import CalledProcessError, check_call

    logger.info("Processing queue...")
    # Drop finished jobs whose retention period has expired.
    db.delete_many({"_delete_after.$lt": time.time()})

    queued = db.find({"status": int(JobStatus.queued)})
    doc = min(queued, key=lambda candidate: candidate["_id"], default=None)
    if doc is None:
        logger.info("No jobs...")
        return

    doc["status"] = int(JobStatus.active)
    db[doc["_id"]] = doc
    db.flush()

    cwd = os.getcwd()
    try:
        logger.info("Executing job '{}' ({})...".format(doc["job_name"], doc["_id"]))
        try:
            # "chdir" may be present but None (the submit parser's default
            # when -D/--chdir is omitted), so fall back before stripping.
            chdir = (doc.get("chdir") or "").strip('"') or os.path.expanduser("~")
            os.chdir(chdir)
        except FileNotFoundError as error:
            # The job still runs, from whatever directory is current.
            logger.warning(error)

        job_number = "{:05x}".format(int(doc["_id"], 16))
        fn_out = _first_unused_filename(job_number + ".out")
        fn_err = _first_unused_filename(job_number + ".err")

        with open(fn_out, "w") as outfile, open(fn_err, "w") as errfile:
            # Pass an argument list instead of a shell string, so script
            # paths containing spaces or shell metacharacters are safe.
            check_call(["/bin/bash", doc["script"]], stdout=outfile, stderr=errfile)
    except CalledProcessError:
        logger.warning("Error while executing job '{}'.".format(doc["_id"]))
    finally:
        os.chdir(cwd)
        os.remove(doc["script"])
        doc["status"] = int(JobStatus.inactive)
        doc["_delete_after"] = time.time() + 2 * 60  # remove after 2 mins
        db[doc["_id"]] = doc
        db.flush()


def main_run(args):
    """Run the scheduler.

    The scheduler runs multiple threads. A separate thread moves jobs from the
    inbox to a queue. The main thread processes jobs in the queue, sleeping
    for the polling period (``args.polling_period`` seconds, or 5 if the
    attribute is absent) between passes.

    The inbox and queue directories are created if necessary, and the
    database is locked for as long as the scheduler runs. The scheduler stops
    on ``KeyboardInterrupt``; the inbox thread is stopped and joined on every
    exit path, so an unexpected error cannot leave it running.
    """
    print("Start scheduler...")
    os.makedirs(args.inbox, exist_ok=True)
    os.makedirs(args.queue, exist_ok=True)
    # Coerce explicitly: a value given on the command line may arrive as a
    # string if the option was declared without a type.
    polling_period = float(getattr(args, "polling_period", 5))
    with _lock_database(args) as db:
        print("Execute this to enable environment detection for this scheduler:")
        print(f'export SIMPLE_SCHEDULER="{sys.argv[0]} --data={args.data}"')
        stop = threading.Event()
        process_inbox_thread = threading.Thread(
            target=_process_inbox_loop,
            kwargs=dict(inbox=args.inbox, queue=args.queue, db=db, stop=stop),
        )
        # Start outside the try block so that the cleanup below only ever
        # joins a thread that was actually started.
        process_inbox_thread.start()
        try:
            while True:
                _process_queue(args, db)
                time.sleep(polling_period)
        except KeyboardInterrupt:
            print("Stopping...")
        finally:
            stop.set()
            process_inbox_thread.join()


def main_status(args):
    """Print the status of all jobs known to the scheduler.

    The database at ``args.db`` is opened read-only. With ``args.json`` set,
    all documents are printed as one indented JSON object keyed by ``_id``.
    Otherwise one line per job is printed in ascending id order, containing
    the id, the job name and the single-letter status code separated by tabs.
    """
    with Collection.open(args.db, mode="r") as db:
        if not args.json:
            for _id in sorted(db.ids):
                entry = db[_id]
                status_code = FMT_STATUS[JobStatus(entry["status"])]
                print(f"{entry['_id']}\t{entry['job_name']}\t{status_code}")
            return

        by_id = {doc["_id"]: doc for doc in db}
        print(json.dumps(by_id, indent=4))


# Command-line entry point: build the parser, derive the scheduler's paths
# from --data, and dispatch to the selected subcommand.
if __name__ == "__main__":
    # The data directory defaults to $SIMPLE_SCHEDULER_DATA_HOME if set.
    default_home_data = os.environ.get(
        "SIMPLE_SCHEDULER_DATA_HOME",
        os.path.expanduser("~/.local/share/simple-scheduler"),
    )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data",
        default=default_home_data,
        help="Specify the path to the scheduler's data directory. "
        "Default: {}".format(default_home_data),
    )
    # NOTE(review): this value is parsed, but nothing in this file reads it.
    parser.add_argument(
        "--max-queue-size",
        type=int,
        default=10000,
        help="Specify the maximum number of jobs that can be queued.",
    )

    subparsers = parser.add_subparsers()

    parser_submit = subparsers.add_parser(
        "submit", description="Submit a job script for execution to the scheduler."
    )
    parser_submit.add_argument(
        "filename", help="The path to the script to submit for execution."
    )
    parser_submit.set_defaults(func=main_submit)

    parser_status = subparsers.add_parser(
        "status", description="Display the status of jobs submitted to the scheduler."
    )
    parser_status.add_argument("--json", action="store_true")
    parser_status.set_defaults(func=main_status)

    parser_run = subparsers.add_parser(
        "run", description="Execute jobs submitted to the scheduler."
    )
    # This option is registered on the top-level parser (kept for
    # compatibility), so it has to precede the subcommand on the command
    # line. The explicit type makes a user-supplied value numeric instead of
    # a string.
    parser.add_argument(
        "--polling-period",
        type=float,
        default=10,
        help="Check the database for new entries every given seconds.",
    )
    parser_run.set_defaults(func=main_run)

    args = parser.parse_args()
    # No subcommand given: there is nothing to dispatch to.
    if not hasattr(args, "func"):
        parser.print_usage()
        sys.exit(2)

    # All scheduler state lives below the data directory.
    args.inbox = os.path.join(args.data, "inbox")
    args.queue = os.path.join(args.data, "queue")
    args.db = os.path.join(args.data, "db.txt")

    logging.basicConfig(level=logging.INFO)
    args.func(args)
