#!/usr/bin/env python

"""
To run all steps in the workflow:

scrinet_make_workflow --config-file config.ini

To make a dag that only fits and evaluates, pass the corresponding
--[exename]-exe flags:

scrinet_make_workflow --config-file config.ini --fit-exe --evaluate-exe
"""

from configparser import ConfigParser, ExtendedInterpolation
import os
from os.path import abspath, join
import argparse
import shutil
from scrinet.workflow import condor
from scrinet.workflow.pipe_utils import init_logger

# Names of the argparse attributes that select which workflow stages are
# added to the dag.
EXE_FLAGS = ("genwf_exe", "rb_exe", "ts_exe",
             "fit_exe", "evaluate_exe", "webpage_exe")


def build_parser():
    """Return the command-line parser for the workflow dag generator."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Condor dag generator. If no '--*-exe' options are
        passed then every stage is added to the dag. If a subset is passed
        then only those stages are added.
        """
    )

    parser.add_argument("--config-file", type=str, required=True,
                        help="path to workflow ini file")

    parser.add_argument("--genwf-exe", help="if passed then add waveform generation to dag",
                        action="store_true")
    parser.add_argument("--rb-exe", help="if passed then add reduced basis computation to dag",
                        action="store_true")
    parser.add_argument("--ts-exe", help="if passed then add training set computation to dag",
                        action="store_true")
    parser.add_argument("--fit-exe", help="if passed then add fit computation to dag",
                        action="store_true")
    parser.add_argument("--evaluate-exe", help="if passed then add evaluation computation to dag",
                        action="store_true")
    parser.add_argument("--webpage-exe", help="if passed then add webpage construction to dag",
                        action="store_true")

    parser.add_argument("--plot-dag", help="make a plot of the dags",
                        action="store_true")

    parser.add_argument("--force", help="force make directories",
                        action="store_true")
    return parser


def get_datas(config, required=("data1", "data2"),
              optional=("data3", "data4", "data5")):
    """Return the data names to model, read from the ``[workflow]`` section.

    Parameters
    ----------
    config : configparser.ConfigParser
        Parsed workflow configuration.
    required : sequence of str
        Option names that must be present. ``configparser`` raises
        ``NoSectionError`` / ``NoOptionError`` when one is missing.
    optional : sequence of str
        Option names that are appended, in order, only when present.

    Returns
    -------
    list of str
    """
    datas = [config.get("workflow", key) for key in required]
    # Check presence explicitly rather than catching every exception, so that
    # genuine errors (e.g. bad interpolation) in an optional entry surface.
    datas.extend(
        config.get("workflow", key)
        for key in optional
        if config.has_option("workflow", key)
    )
    return datas


def copy_config(config_file, dest_dir):
    """Copy *config_file* into *dest_dir* under its base name.

    Using only the base name means a config given as ``configs/run.ini`` or as
    an absolute path lands directly in *dest_dir*. Joining the full path would
    point at a missing subdirectory, or back onto the source file itself.
    If the destination already is the source file, nothing is copied.

    Returns
    -------
    str
        Path of the copy inside *dest_dir*.
    """
    dest = join(dest_dir, os.path.basename(config_file))
    if os.path.exists(dest) and os.path.samefile(config_file, dest):
        return dest
    shutil.copyfile(config_file, dest)
    return dest


def main():
    """Parse the CLI, build the condor dag described by the config, write it."""
    parser = build_parser()
    args = parser.parse_args()

    logger = init_logger()

    # With no stage selected explicitly, run the whole workflow.
    if not any(getattr(args, flag) for flag in EXE_FLAGS):
        for flag in EXE_FLAGS:
            setattr(args, flag, True)

    logger.info("workflow exe status:")
    logger.info(f"genwf: {args.genwf_exe}")
    logger.info(f"rb: {args.rb_exe}")
    logger.info(f"ts: {args.ts_exe}")
    logger.info(f"fit: {args.fit_exe}")
    logger.info(f"evaluate: {args.evaluate_exe}")
    logger.info(f"webpage_exe: {args.webpage_exe}")

    exist_ok = args.force
    logger.info(f"exist_ok = {exist_ok}")

    config = ConfigParser(interpolation=ExtendedInterpolation())
    config.optionxform = str  # keep option names case-sensitive
    # ConfigParser.read silently skips files it cannot open; fail loudly.
    if not config.read(args.config_file):
        parser.error(f"could not read config file: {args.config_file}")

    logger.info("determining what data to model")
    datas = get_datas(config)
    logger.info(f"datas: {datas}")

    root_data_output_dir = config.get("workflow", "root-data-output-dir")
    logger.info(f"root_data_output_dir: {root_data_output_dir}")
    os.makedirs(root_data_output_dir, exist_ok=exist_ok)
    logger.info(f"copying config-file to {root_data_output_dir}")
    copy_config(args.config_file, root_data_output_dir)

    log_dir = config.get("workflow", "log_dir")
    logger.info(f"Making log_dir: {log_dir}")
    os.makedirs(log_dir, exist_ok=exist_ok)
    logger.info(f"copying config-file to {log_dir}")
    copy_config(args.config_file, log_dir)
    submit = abspath(join(log_dir, 'submit'))

    logger.info("setting up maindag")
    dagman = condor.Dagman(name='maindag', submit=submit)

    if args.genwf_exe:

        logger.info("setting up genwf-seed")
        name = 'subdag-genwf-seed'
        logger.info(f"creating subdag: {name}")
        subdag_wfgen_seed = condor.Dagman(
            name=name, submit=submit, dag=dagman
        )
        subdag_wfgen_seed.add_job(
            condor.GenWaveformJob(config, "genwf", "seed"))

        subdag_wfgen_dict = {}
        for set_name in ("train", "validation", "test"):
            name = f'subdag-genwf-{set_name}'
            logger.info(f"creating subdag: {name}")
            subdag = condor.Dagman(name=name, submit=submit, dag=dagman)
            subdag_wfgen_dict[set_name] = subdag

            combine_set_job = condor.CombineWaveformJob(
                config, "wfcombine", set_name)
            subdag.add_job(combine_set_job)

            n_jobs = config.getint(f"genwf-{set_name}-split", "n_jobs")
            leading_zeros = len(str(n_jobs))

            for n in range(n_jobs):
                # zero-pad indices to the width of n_jobs so they sort nicely
                zstr = str(n).zfill(leading_zeros)

                genwf_set_job = condor.GenWaveformJob(
                    config, "genwf", set_name, idx=zstr)
                subdag.add_job(genwf_set_job)
                # the combine job runs after every split generation job
                combine_set_job.add_parent(genwf_set_job)

    if args.rb_exe:

        logger.info("setting up rb (reduced basis)")

        subdag_rb_dict = {}
        for data in datas:
            name = f'subdag-rb-{data}'
            logger.info(f"creating subdag: {name}")
            subdag_rb_dict[data] = condor.Dagman(
                name=name, submit=submit, dag=dagman
            )
            subdag_rb_dict[data].add_job(
                condor.BuildReducedBasisJob(config, "rb", data))

            # when waveform generation is part of this dag, the reduced basis
            # waits on the seed and training waveform subdags
            if args.genwf_exe:
                subdag_rb_dict[data].add_parents(
                    [subdag_wfgen_seed, subdag_wfgen_dict['train']])

    if args.ts_exe:

        logger.info("setting up ts (training set)")
        subdag_ts_dict = {}
        for data in datas:
            subdag_ts_dict[data] = {}
            for set_name in ("train", "val"):
                name = f'subdag-ts-{data}-{set_name}'
                logger.info(f"creating subdag: {name}")
                subdag = condor.Dagman(name=name, submit=submit, dag=dagman)
                subdag_ts_dict[data][set_name] = subdag
                subdag.add_job(
                    condor.GenTrainingSetJob(config, "ts", data, set_name))

                # both sets depend on the reduced basis for this data ...
                if args.rb_exe:
                    subdag.add_parent(subdag_rb_dict[data])
                # ... and the 'val' set also on the validation waveforms
                if set_name == "val" and args.genwf_exe:
                    subdag.add_parent(subdag_wfgen_dict["validation"])

    if args.fit_exe:

        logger.info("setting up fit")
        subdag_fit_dict = {}
        for data in datas:
            name = f'subdag-fit-{data}'
            logger.info(f"creating subdag: {name}")
            subdag_fit_dict[data] = condor.Dagman(
                name=name, submit=submit, dag=dagman
            )
            subdag_fit_dict[data].add_job(condor.FitJob(config, "fit", data))
            if args.ts_exe:
                subdag_fit_dict[data].add_parents(
                    [subdag_ts_dict[data]['train'], subdag_ts_dict[data]['val']])

    if args.evaluate_exe:

        logger.info("setting up evaluate")
        subdag_evaluate_dict = {}
        for set_name in ("train", "validation", "test"):
            name = f'subdag-eval-{set_name}'
            logger.info(f"creating subdag: {name}")
            subdag = condor.Dagman(name=name, submit=submit, dag=dagman)
            subdag_evaluate_dict[set_name] = subdag
            subdag.add_job(condor.EvaluateJob(config, "evaluate", set_name))

            # every evaluation waits on all fits; 'test' also needs the
            # test waveforms
            if args.fit_exe:
                subdag.add_parents([subdag_fit_dict[data] for data in datas])
            if set_name == "test" and args.genwf_exe:
                subdag.add_parent(subdag_wfgen_dict["test"])

    if args.webpage_exe:

        logger.info("setting up webpage")
        name = 'subdag-webpage'
        logger.info(f"creating subdag: {name}")
        subdag_webpage = condor.Dagman(
            name=name, submit=submit, dag=dagman
        )
        subdag_webpage.add_job(condor.MakeWebpageJob(config, "webpage"))
        if args.evaluate_exe:
            for subdag in subdag_evaluate_dict.values():
                subdag_webpage.add_parent(subdag)

    if args.plot_dag:
        logger.info("plotting dags")
        dag_plot_dir = join(log_dir, "dag_plot")
        logger.info(f"Making output plot directory: {dag_plot_dir}")
        os.makedirs(dag_plot_dir, exist_ok=exist_ok)
        dagman.visualize(join(dag_plot_dir, f'{dagman.name}-workflow.png'))
        for node in dagman.nodes:
            node.visualize(join(dag_plot_dir, f'{node.name}-workflow.png'))

    logger.info("building dagman")
    dagman.build()

    logger.info("to submit dag:")
    logger.info(f"condor_submit_dag {dagman.submit_file}")


if __name__ == "__main__":
    main()
