#! /usr/bin/env python

import os
import sys
import shutil
import argparse
from pathlib import Path
from libensemble.version import __version__
from libensemble.tools.parse_args import parser as callscript_parser

# PSI/J is treated as optional at import time: when it is missing, installation
# instructions are printed instead of letting the ModuleNotFoundError propagate.
try:
    from psij import Job, JobSpec
    from psij.resource_spec import ResourceSpecV1
    from psij.job_attributes import JobAttributes
    from psij.serialize import Export
except ModuleNotFoundError:
    # One print call, newline-separated, emits the same text as four separate calls.
    print(
        f"*** libEnsemble {__version__} ***",
        "\nThe PSI/J Python interface is not installed. Please install it via the following:\n",
        "     git clone https://github.com/ExaWorks/psi-j-python.git",
        "     cd psi-j-python; pip install -e .\n",
        sep="\n",
    )

def default_outfile(calling_script):
    """Return the default output filename for ``calling_script``.

    Only a trailing ``.py`` is replaced by ``.json``; any other ``.py``
    occurring earlier in the path (e.g. ``dir.py/run.py``) is left alone.
    """
    suffix = ".py"
    if calling_script.endswith(suffix):
        stem = calling_script[: -len(suffix)]
    else:
        stem = calling_script
    return stem + ".json"


def script_arguments(jobargs):
    """Return the argument list handed to the Python executable.

    The list always starts with the calling script. ``--comms`` (and
    ``--nworkers`` when set) are forwarded only for local comms;
    ``--nsim_workers`` and ``--nresource_sets`` are forwarded whenever set.
    """
    arguments = [jobargs.calling_script]

    if jobargs.comms == "local":
        arguments.extend(["--comms", jobargs.comms])
        if jobargs.nworkers:
            arguments.extend(["--nworkers", str(jobargs.nworkers)])

    if jobargs.nsim_workers:
        arguments.extend(["--nsim_workers", str(jobargs.nsim_workers)])

    if jobargs.nresource_sets:
        arguments.extend(["--nresource_sets", str(jobargs.nresource_sets)])

    return arguments


def build_parser():
    """Build the ``liberegister`` command-line parser.

    The parser inherits the calling-script options from ``callscript_parser``;
    ``conflict_handler="resolve"`` lets the options defined here override any
    inherited option with the same flag.
    """
    parser = argparse.ArgumentParser(
        prog="liberegister",
        description="Produce a PSI/J representation for a libEnsemble execution.",
        epilog="Output representations can be passed to `libesubmit`",
        parents=[callscript_parser],
        conflict_handler="resolve",
    )

    parser.add_argument("calling_script", nargs="?")

    parser.add_argument(
        "-o",
        "--outfile",
        type=str,
        nargs="?",
        help="Output PSI/J representation filename. "
        "Defaults to the calling script's name with a .json extension.",
        # None (rather than a fixed name) so an explicit -o can be told apart
        # from "not given" and the script-derived name used as the fallback.
        default=None,
    )

    parser.add_argument(
        "-n", "--nnodes", type=int, nargs="?", help="Number of nodes", default=1
    )

    parser.add_argument(
        "-p",
        "--python-path",
        type=Path,
        nargs="?",
        help="Which Python to use.",
        default="python",
    )

    parser.add_argument(
        "-s",
        "--scheduler",
        choices=["cobalt", "local", "flux", "lsf", "pbspro", "rp", "slurm"],
        help="Which scheduler to use.",
        default=None,
    )

    parser.add_argument(
        "-j",
        "--jobname",
        type=str,
        nargs="?",
        help="Scheduler job name.",
        default="libe-job",
    )

    parser.add_argument(
        "-q", "--queue", type=str, nargs="?", help="Scheduler queue name.", default=None
    )

    parser.add_argument(
        "-A",
        "--project",
        type=str,
        nargs="?",
        help="Project name for billing hours.",
        default=None,
    )

    parser.add_argument(
        "-t",
        "--wallclock",
        type=int,
        nargs="?",
        help="Total wallclock for job.",
        default=30,
    )

    parser.add_argument(
        "-d",
        "--directory",
        type=Path,
        nargs="?",
        help="Working directory for job.",
        default=None,
    )

    return parser


def main():
    """Parse the command line, build a PSI/J ``JobSpec`` and export it as JSON.

    Exits with a usage message when no calling script is given, when it does
    not end in ``.py``, when PSI/J is unavailable, or when ``--nworkers`` is
    missing for non-local comms.
    """
    parser = build_parser()
    # Options this parser does not recognise are tolerated and discarded.
    jobargs, _unknown = parser.parse_known_args(sys.argv[1:])

    if not jobargs.calling_script:
        parser.print_help()
        sys.exit(
            "\nMust supply a calling script, with the --comms and --nworkers options"
        )

    if not jobargs.calling_script.endswith(".py"):
        parser.print_help()
        sys.exit("\nFirst argument doesn't appear to be a Python script.")

    # The module-level PSI/J import only prints a notice on failure, so check
    # here and exit cleanly instead of hitting a NameError further down.
    # Checked after argument validation so --help still works without PSI/J.
    if "Export" not in globals():
        sys.exit(
            "\nCannot produce a PSI/J representation: "
            "the PSI/J Python interface is not installed."
        )

    if jobargs.comms == "local":
        resources = ResourceSpecV1(node_count=jobargs.nnodes)
    else:  # any non-local comms is handled as "mpi"
        if jobargs.nworkers is None:
            # process_count below is derived from nworkers, so it is required.
            parser.print_help()
            sys.exit("\nMust supply --nworkers when --comms is not 'local'.")
        # One process per worker plus one more, one process per node.
        resources = ResourceSpecV1(
            process_count=jobargs.nworkers + 1, processes_per_node=1
        )

    jobspec = JobSpec(
        name=jobargs.jobname,
        executable=str(jobargs.python_path),
        arguments=script_arguments(jobargs),
        directory=jobargs.directory,
        environment={"PYTHONNOUSERSITE": "1"},
        resources=resources,
        attributes=JobAttributes(
            # NOTE(review): wallclock is passed through as a plain int with no
            # units stated here - confirm what JobAttributes.duration expects.
            duration=jobargs.wallclock,
            queue_name=jobargs.queue,
            project_name=jobargs.project,
        ),
    )

    # Honour an explicit -o/--outfile; otherwise derive the name from the script.
    outfile = jobargs.outfile or default_outfile(jobargs.calling_script)

    Export().export(obj=jobspec, dest=outfile)
    print(f"*** libEnsemble {__version__} ***")
    print(
        f"Exported PSI/J serialization: {outfile}\nOptionally adjust any fields, or specify job attributes on submission to `libesubmit`."
    )


if __name__ == "__main__":
    main()
