#! /usr/bin/env python

import os
import sys
import time
import shutil
import argparse
from pathlib import Path

from libensemble.version import __version__
from libensemble.resources import node_resources

# tqdm is optional: this script only uses it for the progress bars shown while
# searching the filesystem for a calling script that was not found.
# NOTE: when the import fails, a hint is printed and the name `tqdm` is left
# unbound; execution continues.
try:
    from tqdm.auto import tqdm
except ModuleNotFoundError:
    print(f"*** libEnsemble {__version__} ***")
    print("\ntqdm is not installed, but this only matters if libesubmit can't find your calling script.\n")
    print("\ntqdm can be installed via:\n")
    print("     pip install tqdm")

# PSI/J provides the classes used below to load, modify, re-serialize and
# submit the Job description.
# NOTE: when the import fails, installation instructions are printed but the
# script is not stopped here; the PSI/J names are simply left unbound.
try:
    from psij import JobExecutor, Import, Export, JobSpec, Job
    from psij.resource_spec import ResourceSpecV1
    from psij.job_attributes import JobAttributes
except ModuleNotFoundError:
    print(f"*** libEnsemble {__version__} ***")
    print("\nThe PSI/J Python interface is not installed. Please install it via the following:\n")
    print("     git clone https://github.com/ExaWorks/psi-j-python.git")
    print("     cd psi-j-python; pip install -e .\n")

def parse_selection(text, count):
    """Convert a 1-based menu answer into a 0-based index.

    Args:
        text: Raw text typed by the user.
        count: Number of entries in the menu.

    Returns:
        The 0-based index, or ``None`` when ``text`` is not an integer or
        lies outside ``1..count`` (so ``0`` never wraps around to the last
        entry).
    """
    try:
        number = int(text)
    except (TypeError, ValueError):
        return None
    if 1 <= number <= count:
        return number - 1
    return None


def prompt_selection(prompt, options):
    """Keep asking until the user picks a valid entry of ``options``.

    ``options`` must be non-empty. Ctrl+C / EOF propagate to the caller.
    """
    while True:
        index = parse_selection(input(prompt), len(options))
        if index is not None:
            return options[index]
        print(f"Please enter a number between 1 and {len(options)}.")


def derive_output_paths(script, prefix):
    """Build the re-serialization, stdout and stderr paths for a submission.

    Args:
        script: Path of the input ``.json`` file, as given on the command line.
        prefix: Short job identifier used to tag the generated files.

    Returns:
        ``(reserialdest, stdout_path, stderr_path)`` as strings. The prefix is
        applied to the file name only (directories are preserved) and only the
        file suffix is swapped for ``.out`` / ``.err``. ``reserialdest`` is
        ``script`` itself when it already contains ``prefix``.
    """
    source = Path(script)
    tagged = source.with_name(f"{prefix}.{source.name}")
    reserialdest = script if prefix in script else str(tagged)
    return (
        reserialdest,
        str(tagged.with_suffix(".out")),
        str(tagged.with_suffix(".err")),
    )


def walkdir(folder):
    """Yield the absolute path of every file found beneath ``folder``."""
    for dirpath, _dirs, files in os.walk(folder, topdown=True):
        for filename in files:
            yield os.path.abspath(os.path.join(dirpath, filename))


def progress_tools():
    """Return ``(iterable wrapper, line writer)`` for progress reporting.

    Uses tqdm when it was imported successfully; otherwise falls back to a
    pass-through wrapper and ``print`` so the search still works without it.
    """
    bar = globals().get("tqdm")
    if bar is None:
        return (lambda iterable, **_kwargs: iterable), print
    return bar, bar.write


def search_for_script(callscript):
    """Interactively search the user's home tree for ``callscript``.

    Returns:
        The absolute path of the file the user selected.

    Exits the program when the user declines the search, when the home
    directory has no candidate sub-directories, or when nothing matches.
    """
    answer = input("Check somewhere else? (Y/N): ")
    if answer.upper() != "Y":
        print("Exiting")
        sys.exit()

    home = os.path.expanduser("~")
    # Only offer visible directories whose names contain no dot.
    check_dirs = [
        entry
        for entry in os.listdir(home)
        if os.path.isdir(os.path.join(home, entry)) and "." not in entry
    ]
    if not check_dirs:
        sys.exit(f"No candidate directories found under {home}.")

    print(home + ":")
    for number, entry in enumerate(check_dirs, start=1):
        print(f"   {number}. /{entry}")

    choice = os.path.join(
        home, prompt_selection("Specify a starting directory: ", check_dirs)
    )
    progress, write = progress_tools()

    # First pass only counts files so the second pass can show a total.
    print("preparing... ctrl+c to abort.")
    filescount = sum(1 for _ in progress(walkdir(choice)))

    print("detecting... ctrl+c to abort.")
    print(home + ":")
    target = os.path.basename(callscript)
    candidate_script_paths = []
    interrupted = False
    try:
        for filepath in progress(walkdir(choice), total=filescount):
            if os.path.basename(filepath) == target:
                candidate_script_paths.append(filepath)
                write(f"   {len(candidate_script_paths)}. {filepath.split(choice)[1]}")
    except KeyboardInterrupt:
        # Let the user pick from whatever was detected before the interrupt.
        interrupted = True

    if not candidate_script_paths:
        sys.exit(f"No file named {target} was found under {choice}.")

    if interrupted:
        prompt = "detection interrupted. ctrl+c again to exit, or specify a detected script: "
    else:
        prompt = "Specify a detected script: "
    return prompt_selection(prompt, candidate_script_paths)


def main():
    """Parameterize, re-serialize and (unless ``--dry``) submit a PSI/J Job.

    The first command-line argument must be the ``.json`` PSI/J serialization;
    the remaining options may override values stored in that file.
    """
    parser = argparse.ArgumentParser(
        prog="libesubmit",
        description="Submit a libEnsemble PSI/J job representation for execution. Additional options may overwrite the input file.",
        conflict_handler="resolve",
    )

    # Scheduler name -> MPI launcher used when the Job is launched via MPI.
    choices = {
        "cobalt": "aprun",
        "local": "mpirun",
        "flux": "mpirun",
        "lsf": "jsrun",
        "pbspro": "mpirun",
        "rp": "mpirun",
        "slurm": "srun",
    }

    parser.add_argument("-s", "--scheduler", choices=list(choices), required=True)

    parser.add_argument(
        "-w",
        "--wait",
        action="store_true",
        help="Wait for Job to complete before exiting.",
    )

    parser.add_argument(
        "--dry",
        action="store_true",
        help="Parameterize and re-serialize a Job, without submitting.",
    )

    parser.add_argument(
        "-n", "--nnodes", type=int, nargs="?", help="Number of nodes", default=1
    )

    parser.add_argument(
        "-p",
        "--python-path",
        type=Path,
        nargs="?",
        help="Which Python to use. Default is current Python.",
        default=sys.executable,
    )

    parser.add_argument(
        "-q", "--queue", type=str, nargs="?", help="Scheduler queue name.", default=None
    )

    parser.add_argument(
        "-A",
        "--project",
        type=str,
        nargs="?",
        help="Scheduler project name.",
        default=None,
    )

    parser.add_argument(
        "-t",
        "--wallclock",
        type=int,
        nargs="?",
        help="Total wallclock for job. Default is 30 minutes.",
        default=30,
    )

    parser.add_argument(
        "-d",
        "--directory",
        type=Path,
        nargs="?",
        help="Working directory for job. Default is current directory.",
        default=os.getcwd(),
    )

    # The .json file is not a declared argument, so tolerate unknown arguments.
    jobargs, _unknown = parser.parse_known_args(sys.argv[1:])

    script = sys.argv[1]
    if not script.endswith(".json"):
        parser.print_help()
        sys.exit("First argument doesn't appear to be a .json file.")

    print(f"*** libEnsemble {__version__} ***")
    print(f"Imported PSI/J serialization: {script}. Preparing submission...")

    jobspec = Import().load(script)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(jobspec, JobSpec):
        sys.exit("Invalid input file.")

    job_dir = str(jobargs.directory)
    jobspec.directory = job_dir
    jobspec.attributes.project_name = jobargs.project
    jobspec.attributes.queue_name = jobargs.queue
    if jobspec.executable == "python":
        jobspec.executable = str(jobargs.python_path)
    # NOTE(review): PSI/J documents `duration` as a timedelta, while
    # --wallclock is an integer number of minutes - confirm that the installed
    # psij accepts an int here.
    jobspec.attributes.duration = jobargs.wallclock
    # Only override the node count when the file still holds the value 1.
    if jobspec.resources["node_count"] == 1:
        jobspec.resources["node_count"] = jobargs.nnodes

    # liberegister is expected to have enforced a Python calling script, but
    # fail cleanly if the serialization does not contain one.
    py_args = [arg for arg in jobspec.arguments if str(arg).endswith(".py")]
    if not py_args:
        sys.exit("No Python calling script found among the Job arguments.")
    callscript = py_args[0]
    print(f"Calling script: {callscript}")

    if os.path.isfile(os.path.join(job_dir, callscript)) or os.path.isfile(callscript):
        print("...found! Proceeding.")
    else:
        print("... not found in Job working directory!")
        new_callscript = search_for_script(callscript)
        jobspec.arguments[jobspec.arguments.index(callscript)] = new_callscript

    # The resources are rebuilt as a ResourceSpecV1 before re-serializing.
    if not jobspec.resources["node_count"]:
        # running with MPI - need corresponding launcher
        # NOTE(review): cpu_cores_per_process=64 is hard-coded; presumably
        # machine-specific - confirm.
        jobspec.resources = ResourceSpecV1(
            process_count=jobspec.resources["process_count"],
            processes_per_node=1,
            cpu_cores_per_process=64,
        )
        jobspec.launcher = choices[jobargs.scheduler]
    else:
        jobspec.resources = ResourceSpecV1(node_count=jobspec.resources["node_count"])

    jex = JobExecutor.get_instance(jobargs.scheduler)
    job = Job()

    # Tag generated files with the first segment of the new Job's id.
    prefix = job.id.split("-")[0]
    reserialdest, stdout_path, stderr_path = derive_output_paths(script, prefix)
    jobspec.stdout_path = stdout_path
    jobspec.stderr_path = stderr_path

    Export().export(obj=jobspec, dest=reserialdest)

    job.spec = jobspec

    if jobargs.dry:
        # Waiting on a Job that was never submitted would block forever.
        if jobargs.wait:
            print("Dry run: Job was not submitted, nothing to wait for.")
        return

    print("Submitting Job!:", job)
    jex.submit(job)

    if jobargs.wait:
        print("Waiting on Job completion...")
        job.wait()


if __name__ == "__main__":
    main()
