#!/usr/bin/env python

import argparse
import configparser
import os
import pathlib
import shlex
import subprocess
import sys
from importlib.util import module_from_spec, spec_from_file_location


# File name of the todo list.
# NOTE(review): this constant is not referenced in the visible part of this
# file; the '--todo-list' default in _mk_argparser repeats the same literal.
# Confirm whether it is still needed.
_TODO_YAML = 'todo.yaml'


def _parse_todo_list(fname):
    """Load the job list from a YAML todo file.

    Parameters
    ----------
    fname : str or path-like
        Path of the YAML file. Its top level must be a mapping that has a
        ``jobs`` entry.

    Returns
    -------
    The value stored under the top-level ``jobs`` key. Callers iterate over
    it and treat every element as a dict with at least a ``jobmodule`` key.

    Raises
    ------
    ValueError
        If the document is empty, is not a mapping, or has no ``jobs`` entry.
    """
    # imported lazily: only the subcommands that read a todo list need PyYAML
    import yaml

    with open(fname, 'r') as f:
        # NOTE(review): yaml.safe_load is probably sufficient if todo files
        # only contain plain mappings, lists and scalars -- confirm before
        # switching loaders.
        todo_list = yaml.full_load(f)

    # an empty file parses to None and would otherwise fail with an obscure
    # "'NoneType' object is not subscriptable"
    if not isinstance(todo_list, dict) or 'jobs' not in todo_list:
        raise ValueError(f'Todo list {fname} must be a YAML mapping with a '
                         f'top-level "jobs" entry')
    return todo_list['jobs']


def run(args):
    """Submit every job-array listed in the todo file.

    Each todo item is translated into a ``submit`` command line: its
    ``jobmodule`` entry becomes the positional argument and every other
    key/value pair becomes a ``--key value`` option. The global options
    ``--path-prefix``, ``--[no-]overwrite-output`` and ``--[no-]test`` of
    this invocation are forwarded to each submission.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; reads ``todo_list``, ``path_prefix``,
        ``overwrite_output`` and ``test``.
    """
    todo_list = _parse_todo_list(args.todo_list)

    global_argv = []
    if args.path_prefix:
        global_argv += ['--path-prefix', args.path_prefix]

    if args.overwrite_output is not None:
        maybe_no = '' if args.overwrite_output else 'no-'
        global_argv.append(f'--{maybe_no}overwrite-output')

    if args.test is not None:
        maybe_no = '' if args.test else 'no-'
        global_argv.append(f'--{maybe_no}test')

    # the parser does not depend on the todo item, so build it only once
    parser = _mk_argparser()

    for item in todo_list:
        # work on a copy so the loaded todo list is left untouched
        options = dict(item)
        argv = ['submit', str(options.pop('jobmodule'))]
        for key, value in options.items():
            if value is None:
                # key without a value in the YAML file: keep the default
                continue
            # YAML scalars may be ints/floats/bools, argparse needs strings
            argv += [f'--{key}', str(value)]

        submit_args = parser.parse_args(global_argv + argv)
        print(f'[jor] Submitting job: {argv[1]}')
        submit_args.subcommand(submit_args)

        # in testing mode only submit first job-array
        if args.test and (len(todo_list) > 1):
            print(f'[jor] Not submitting {len(todo_list) - 1} job-array(s) in test-mode')
            break


def submit(args):
    """Submit the job-array described by ``args.jobmodule`` / ``args.jobargs``.

    For every array index a shell command is built that re-invokes this
    script with the ``exec`` subcommand, optionally prefixed with a conda
    activation and/or ``singularity exec``. The commands are run one after
    another for the ``local`` scheduler, or handed to ``sbatch --wrap`` as a
    single array job for ``slurm``. Nothing is submitted when
    ``jobs.is_output_complete`` is true; with ``args.test`` set only the
    first job is submitted.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments of the ``submit`` subcommand.

    Raises
    ------
    ValueError
        If ``args.scheduler`` is neither ``"local"`` nor ``"slurm"``.
    """
    jobs = _init_jobs(args)

    if jobs.is_output_complete:
        print('[jor] All output files exist, nothing to do.')
        return

    # mk output folder if it doesn't exist yet
    pathlib.Path(jobs._get_output_folder()).mkdir(parents=True, exist_ok=True)

    # Every interpolated value is shell-quoted, so paths containing spaces
    # and job arguments containing quotes, "$" or braces reach "exec"
    # unchanged. The array index is spliced in separately because for slurm
    # it must stay an unquoted shell variable.
    before_index = [shlex.quote(__file__)]
    if args.path_prefix:
        before_index += ['--path-prefix', shlex.quote(args.path_prefix)]
    if args.overwrite_output is not None:
        maybe_no = '' if args.overwrite_output else 'no-'
        before_index.append(f'--{maybe_no}overwrite-output')
    before_index += ['exec', shlex.quote(args.jobmodule)]
    after_index = ['--jobargs', shlex.quote(args.jobargs or '')]

    wrapper = ''
    if args.condaenv:
        wrapper = ('source `conda info --base`/etc/profile.d/conda.sh; '
                   f'conda activate {shlex.quote(args.condaenv)}; ')
    if args.sif:
        # NOTE(review): when a conda environment is given as well, the
        # resulting string is "singularity exec ... source ...; conda ...",
        # i.e. only the first ";"-separated command is passed to
        # singularity. Confirm whether that combination is meant to work.
        wrapper = f'singularity exec --cleanenv {shlex.quote(args.sif)} {wrapper}'

    def command_for(index):
        # index: text placed where "exec" expects the array id
        return wrapper + ' '.join(before_index + [index] + after_index)

    n_jobs = len(jobs)
    # in testing mode only submit first job
    if args.test and (n_jobs > 1):
        print(f'[jor] Not executing {n_jobs - 1} job(s) in test-mode')
        n_jobs = 1

    if args.scheduler == 'local':
        for i in range(n_jobs):
            subprocess.run(command_for(str(i)), shell=True)

    elif args.scheduler == 'slurm':
        subprocess.run([
            'sbatch',
            '--job-name', jobs.name,
            '--cpus-per-task', str(jobs.cpus_per_task),
            # str() so that numeric resource specs are accepted as well
            '--mem', str(jobs.mem),
            '--time', str(jobs.time),
            '--partition', args.partition,
            '--array', f'0-{n_jobs - 1}',
            '--wrap', command_for('"$SLURM_ARRAY_TASK_ID"')
        ])

    else:
        raise ValueError(f'Invalid scheduler: {args.scheduler}')


def exec(args):
    """Execute job number ``args.i`` of a job-array.

    The job is skipped when its output file already exists and
    ``args.overwrite_output`` is false. Otherwise an empty placeholder file
    is created at the output path before ``jobs.execute`` is called (unless
    an output already exists), and removed again if execution fails. A
    pre-existing output that was meant to be overwritten is never deleted
    on failure.

    Note that this function shadows the builtin ``exec`` inside this module.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments of the ``exec`` subcommand.
    """
    jobs = _init_jobs(args)

    # check if job needs to be run or if output already exists:
    job_kwargs = jobs._jobs[args.i]
    output_path = pathlib.Path(jobs._get_output_path(**job_kwargs))
    output_exists = output_path.exists()

    if output_exists and not args.overwrite_output:
        print(f'[jor] STOPPING: Output file exists and overwriting not '
              f'selected: {output_path}')
        return

    # otherwise:
    try:
        if not output_exists:
            # note that a particular job is running, but don't update
            # modification time of an already existing output file
            output_path.touch()
        jobs.execute(args.i)
    except BaseException:
        # BaseException on purpose: also clean up on KeyboardInterrupt
        if not output_exists:  # i.e. "touch" just created an empty file
            try:
                output_path.unlink()
            except FileNotFoundError:
                # placeholder was never created (touch itself failed) or is
                # already gone; don't let that mask the original error
                pass
        # else: in case of error don't delete an already existing output
        # that was supposed to be overwritten
        raise

    if not output_path.exists():
        print(f'[jor] Execution finished but output file missing: {output_path}')


def collect(args):
    """Collect the outputs of one job-array or of every entry in the todo list.

    When ``args.jobmodule`` is the sentinel ``'[parse todo.yaml]'`` (the
    default of the ``collect`` subcommand) the todo list is read; otherwise
    only the job-array given on the command line is collected. For each
    job-array ``jobs.collect()`` is called after checking for missing
    output files.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments of the ``collect`` subcommand. ``args.jobmodule``
        and ``args.jobargs`` are overwritten for each processed item.

    Raises
    ------
    FileNotFoundError
        If outputs are missing and ``args.missing_output`` is ``"raise"``,
        or if no output exists at all.
    ValueError
        If outputs are missing and ``args.missing_output`` is neither
        ``"raise"`` nor ``"ignore"``.
    """
    if args.jobmodule == '[parse todo.yaml]':
        todo_list = _parse_todo_list(args.todo_list)
    else:
        todo_list = [dict(jobmodule=args.jobmodule, jobargs=args.jobargs)]

    for item in todo_list:
        jobmodule = item['jobmodule']
        # "jobargs" is optional in a todo item; _init_jobs accepts None
        jobargs = item.get('jobargs')

        print(f"[jor] Collecting outputs for {jobmodule} ({jobargs})")

        args.jobmodule = jobmodule
        args.jobargs = jobargs
        jobs = _init_jobs(args)

        # check and handle missing output files
        outputs = jobs.output_paths
        existing_outputs = [output for output in outputs if os.path.exists(output)]
        if len(existing_outputs) < len(outputs):
            missing_outputs = set(outputs).difference(existing_outputs)
            if args.missing_output == 'raise':
                raise FileNotFoundError("Some output files couldn't be found: "
                                        f"{missing_outputs}")
            if args.missing_output != 'ignore':
                raise ValueError(f'Invalid argument for "missing-outputs", must '
                                 f'be "raise" or "ignore", got '
                                 f'{args.missing_output}')
            # "ignore" still needs at least one output to collect
            if not existing_outputs:
                raise FileNotFoundError(f"No output file found. Expected "
                                        f"outputs: {outputs}")
            print(f'[jor] Continuing without the following missing output '
                  f'files: {missing_outputs}')

        # collect outputs
        jobs.collect()


def status(args):
    """Print a status report for every job-array in the todo list.

    Output files are classified as missing (path does not exist), "running
    or dead" (file exists but is empty) or completed (non-empty file).
    Job-arrays whose outputs are all completed are not printed.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; ``args.jobmodule`` and ``args.jobargs`` are
        overwritten for each todo item.
    """
    todo_list = _parse_todo_list(args.todo_list)
    for item in todo_list:

        args.jobmodule = item['jobmodule']
        # "jobargs" is optional in a todo item; _init_jobs accepts None
        args.jobargs = item.get('jobargs')
        jobs = _init_jobs(args)

        outputs = jobs.output_paths
        missing_outputs, empty_outputs, complete_outputs = [], [], []
        for output in outputs:
            path = pathlib.Path(output)
            if not path.exists():
                missing_outputs.append(path.name)
            elif path.stat().st_size == 0:
                empty_outputs.append(path.name)
            else:
                # assume non-empty files are complete
                complete_outputs.append(path.name)

        # to not clutter output, don't print status of completely completed jobs
        if not missing_outputs and not empty_outputs:
            continue

        n_outputs = len(outputs)
        print(f'[jor] {args.jobmodule} :: {args.jobargs}')
        print(f'      -> {jobs._get_output_folder()}')
        print(f'      {len(missing_outputs): >6} / {n_outputs} missing')
        print(f'      {len(empty_outputs): >6} / {n_outputs} running or dead')
        print(f'      {len(complete_outputs): >6} / {n_outputs} completed')
        if missing_outputs:
            print(f'      Missing outputs: {missing_outputs}')
        print('')


def _init_jobs(args):
    """Import the job module and instantiate its ``Jobs`` class.

    The directory of ``args.jobmodule`` (or the current working directory
    if the path has no directory part) is put at the front of ``sys.path``
    so the job module can import its neighbours. The module is then loaded
    from its file path and ``Jobs(<jobargs>, path_prefix=<path_prefix>)`` is
    evaluated, where ``<jobargs>`` is the text of ``args.jobargs``.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``jobmodule`` (path of the Python file, with or without the
        ``.py`` suffix), ``jobargs`` (str or None) and ``path_prefix``
        (str or None; ``None`` means the keyword is not passed at all).

    Returns
    -------
    The ``Jobs`` instance created by the job module.
    """
    # strip first so that whitespace-only job arguments count as "none"
    jobargs = (args.jobargs or '').strip()
    if jobargs and not jobargs.endswith(','):
        jobargs += ','

    if args.path_prefix is None:
        path_prefix_arg = ''
    else:
        # repr() yields a valid string literal even if the path contains
        # quotes or backslashes
        path_prefix_arg = f'path_prefix={str(args.path_prefix)!r}'

    module_dir = os.path.dirname(args.jobmodule) or os.getcwd()
    sys.path.insert(0, module_dir)

    module_spec = spec_from_file_location('jobs_module', args.jobmodule)
    if module_spec is None:
        # no loader matches the file name as given, e.g. because the module
        # was named without its ".py" suffix
        module_spec = spec_from_file_location('jobs_module',
                                              args.jobmodule + '.py')
    jobs_module = module_from_spec(module_spec)
    module_spec.loader.exec_module(jobs_module)

    # NOTE(review): jobargs is evaluated as Python source. That is how job
    # arguments are specified, but it means jobargs must come from a
    # trusted source.
    jobs = eval(f"jobs_module.Jobs({jobargs}{path_prefix_arg})")
    return jobs


def _mk_argparser():
    """Build the ``jor`` command-line parser.

    Values found in a ``jor.cfg`` file in the current working directory are
    used as defaults for the matching options. Every subparser stores its
    handler function under ``subcommand`` in the parsed namespace.
    """
    # --- 1: optional config file; its values become argparse defaults ---
    config = configparser.ConfigParser()
    if os.path.exists('jor.cfg'):
        config.read('jor.cfg')

    def from_config(section, option, fallback):
        return config.get(section, option, fallback=fallback)

    # --- 2: top-level parser with the global options ---
    show_defaults = argparse.ArgumentDefaultsHelpFormatter
    parser = argparse.ArgumentParser('jor', formatter_class=show_defaults)

    parser.add_argument(
        '--path-prefix',
        type=_parse_path,
        default=from_config('global', 'path-prefix', None),
        help='path prefix for output files')
    parser.add_argument(
        '--todo-list',
        default=from_config('run', 'todo-list', 'todo.yaml'),
        help='file containing todo list')

    # paired on/off switches that share one destination
    parser.add_argument(
        '--overwrite-output', dest='overwrite_output', action='store_true',
        help='overwrite existing output file')
    parser.add_argument(
        '--no-overwrite-output', dest='overwrite_output',
        action='store_false')
    overwrite_default = _str2bool(
        from_config('global', 'overwrite-output', False))
    parser.set_defaults(overwrite_output=overwrite_default)

    parser.add_argument(
        '--test', dest='test', action='store_true',
        help='only execute first job')
    parser.add_argument('--no-test', dest='test', action='store_false')
    parser.set_defaults(test=False)

    # --- 3: one subparser per subcommand ---
    subcommands = parser.add_subparsers(help='available subcommands')

    def new_subcommand(name, summary):
        return subcommands.add_parser(
            name, help=summary, formatter_class=show_defaults)

    jobmodule_help = 'Python file (without ".py") containing a class "Jobs"'
    jobargs_help = 'arguments for "Jobs" class'

    # run command
    run_parser = new_subcommand('run', 'run todo list')
    run_parser.set_defaults(subcommand=run)

    # submit command
    submit_parser = new_subcommand('submit', 'submit job array')
    submit_parser.add_argument('jobmodule', help=jobmodule_help)
    submit_parser.add_argument('--jobargs', help=jobargs_help, default="")
    submit_parser.add_argument(
        '--scheduler',
        default=from_config('submit', 'scheduler', 'local'),
        help='"local" or "slurm"')
    submit_parser.add_argument(
        '--partition',
        default=from_config('submit', 'partition', 'day'),
        help='slurm partition')
    submit_parser.add_argument(
        '--sif',
        default=from_config('submit', 'sif', None),
        help='singularity file image if any')
    submit_parser.add_argument(
        '--condaenv',
        default=from_config('submit', 'condaenv', None),
        help='conda environment if any')
    submit_parser.set_defaults(subcommand=submit)

    # exec command
    exec_parser = new_subcommand('exec', 'execute a job')
    exec_parser.add_argument('jobmodule', help=jobmodule_help)
    exec_parser.add_argument('i', type=int, help='array id')
    exec_parser.add_argument('--jobargs', help=jobargs_help)
    exec_parser.set_defaults(subcommand=exec)

    # collect command
    collect_parser = new_subcommand(
        'collect', 'collect outputs from job array')
    collect_parser.add_argument(
        'jobmodule', nargs='?', default="[parse todo.yaml]",
        help=jobmodule_help)
    collect_parser.add_argument('--jobargs', help=jobargs_help)
    collect_parser.add_argument(
        '--missing-output',
        default=from_config('collect', 'missing-output', 'ignore'),
        help='either "ignore" a missing output file or "raise" an error')
    collect_parser.set_defaults(subcommand=collect)

    # status command
    status_parser = new_subcommand('status', 'show status of jobs')
    status_parser.set_defaults(subcommand=status)

    return parser


def _parse_path(s):
    """Return ``s`` with a leading ``~`` or ``~user`` expanded."""
    expanded = os.path.expanduser(s)
    return expanded


def _str2bool(s):
    """Convert a config/CLI value to ``bool``.

    Parameters
    ----------
    s : bool or str
        ``bool`` values are returned unchanged. Any other value is
        converted with ``str()`` and compared case-insensitively, ignoring
        surrounding whitespace, against ``0/false/no/off`` and
        ``1/true/yes/on``.

    Raises
    ------
    ValueError
        If the value is none of the accepted spellings.
    """
    if isinstance(s, bool):
        return s

    # str() so that non-string input ends in ValueError, not AttributeError
    text = str(s).strip().lower()
    if text in ('0', 'false', 'no', 'off'):
        return False
    if text in ('1', 'true', 'yes', 'on'):
        return True
    raise ValueError(f"Can't convert {s} to bool")


if __name__ == '__main__':

    # script entry point: run the handler of the chosen subcommand, or show
    # the help text when no subcommand was given
    parser = _mk_argparser()
    args = parser.parse_args()
    handler = getattr(args, 'subcommand', None)
    if handler is None:
        parser.print_help()
    else:
        handler(args)
