#!/usr/bin/env python
#
# Copyright (C) 2020 Prayush Kumar
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Setup workflow to perform Bayesian parameter estimation runs on a
custom set of simulated signals"""

import os
import argparse
import logging
import tempfile
import numpy

from glue.pipeline import CondorDAGNode, CondorDAG
import pycbc
from pycbc import fft, opt, scheme
from pycbc.workflow import configuration

from gwnrtools.utils import mkdir
from gwnrtools.workflow.pycbc_inference import PycbcInferenceOnInjectionBatch
from gwnrtools.workflow.condor import InferenceJob
from gwnrtools import __version__

############################################################
# command line usage
# Build the command-line interface. The module docstring doubles as the
# program description shown by --help.
# NOTE(review): --skip-creating-injections, --save-backup, --nprocesses and
# --samples-file are parsed here but are not read anywhere else in this
# script as visible; confirm whether they are still needed.
parser = argparse.ArgumentParser(usage=__file__ + " [--options]",
                                 description=__doc__)
parser.add_argument("--version",
                    action="version",
                    version=__version__,
                    help="Prints version information.")
parser.add_argument("--verbose",
                    action="store_true",
                    default=False,
                    help="Print logging messages.")

# workflow options
parser.add_argument("--skip-creating-injections",
                    action="store_true",
                    help="Skip calling lalapps_inspinj and assume "
                    "injections already exist",
                    default=False)

# output options
parser.add_argument("--output-dir",
                    type=str,
                    required=False,
                    default='',
                    help="Output directory path.")
parser.add_argument("--force",
                    action="store_true",
                    default=False,
                    help="If the output-dir already exists, overwrite it. "
                    "Otherwise, an OSError is raised.")
parser.add_argument("--save-backup",
                    action="store_true",
                    default=False,
                    help="Don't delete the backup file after the run has "
                    "completed.")

# parallelization options
parser.add_argument("--nprocesses",
                    type=int,
                    default=1,
                    help="Number of processes to use. If not given then only "
                    "a single core will be used.")
parser.add_argument("--use-mpi",
                    action='store_true',
                    default=False,
                    help="Use MPI to parallelize the sampler")
# The help text previously read "must allow encompass"; corrected wording.
parser.add_argument("--samples-file",
                    default=None,
                    help="Use an iteration from an InferenceFile as the "
                    "initial proposal distribution. The same "
                    "number of walkers and the same [variable_params] "
                    "section in the configuration file should be used. "
                    "The priors must encompass the initial "
                    "positions from the InferenceFile being read.")
parser.add_argument("--seed",
                    type=int,
                    default=0,
                    help="Seed to use for the random number generator that "
                    "initially distributes the walkers. Default is 0.")

# Option groups contributed by pycbc: workflow configuration-file handling,
# then the FFT, optimization and processing-scheme options.
configuration.add_workflow_command_line_group(parser)
fft.insert_fft_option_group(parser)
opt.insert_optimization_option_group(parser)
scheme.insert_processing_option_group(parser)

# Turn the command line into the options namespace used by the rest of the
# script.
opts = parser.parse_args()

# Logging setup. In MPI mode verbose output is kept only on rank 0, so that
# only the parent process prints.
if opts.use_mpi:
    from mpi4py import MPI
    rank = MPI.COMM_WORLD.Get_rank()
    opts.verbose = opts.verbose & (rank == 0)
pycbc.init_logging(opts.verbose)

# Let each pycbc option group validate its own options; they receive the
# parser as well as the parsed options.
for check_options in (fft.verify_fft_options,
                      opt.verify_optimization_options,
                      scheme.verify_processing_options):
    check_options(opts, parser)

# numpy warnings are benign here and only clutter the logging output, so
# divide-by-zero and invalid-value warnings are switched off.
numpy.seterr(divide='ignore', invalid='ignore')

# Seed numpy's global random number generator from --seed.
numpy.random.seed(opts.seed)
logging.info("Using seed %i", opts.seed)

# Sort out the directory where all analyses are to be run: --output-dir when
# given, otherwise the current working directory.
# NOTE(review): analyses_dir is kept exactly as given (possibly a relative
# path) even though the working directory changes below; confirm that the
# later consumers of analyses_dir expect that.
analyses_dir = opts.output_dir if opts.output_dir else os.getcwd()
if os.path.exists(analyses_dir):
    # Refuse to reuse an existing directory unless explicitly told to.
    if not opts.force:
        raise IOError(
            "Output directory {} exists. Use --force to overwrite it.".format(
                analyses_dir))
else:
    # os.chdir below requires the directory to exist, so a fresh output
    # directory (and any missing parents) is created first.
    os.makedirs(analyses_dir)
os.chdir(analyses_dir)
logging.info("Will setup analyses in %s", analyses_dir)

def _workflow_option(confs, option, default=None):
    """Return the value of `option` in the [workflow] section of `confs`.

    `default` is returned when the option (or the whole section) is absent.
    Any other error raised while reading the value propagates to the caller.
    """
    if confs.has_option('workflow', option):
        return confs.get('workflow', option)
    return default


def _resource_requests(confs, prefix, default_cpus):
    """Return ``(request_memory, request_cpus)`` for one type of job.

    The values are read from the ``<prefix>-request-memory`` and
    ``<prefix>-request-cpus`` options of the [workflow] section. Memory
    falls back to None and the CPU count (always an int) to `default_cpus`.
    """
    req_mem = _workflow_option(confs, prefix + '-request-memory')
    req_cpus = int(
        _workflow_option(confs, prefix + '-request-cpus', default_cpus))
    return req_mem, req_cpus


def _add_job_node(dag, exe_path, accounting_group, resources, parent=None):
    """Add a node running `exe_path` to `dag` and return the new node.

    `resources` is the ``(request_memory, request_cpus)`` pair for the job.
    If `parent` is given, it is registered as the parent of the new node.
    """
    req_mem, req_cpus = resources
    job = InferenceJob('log',
                       exe_path,
                       None,
                       None,
                       accounting_group=accounting_group,
                       request_memory=req_mem,
                       request_cpus=req_cpus)
    node = CondorDAGNode(job)
    if parent is not None:
        node.add_parent(parent)
    dag.add_node(node)
    return node


# get scheme
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

with ctx:

    # read configuration file
    confs = configuration.WorkflowConfigParser.from_cli(opts)

    # Workspace dirs
    logging.info("Making workspace directories")
    for d in ['scripts', 'log', 'plots']:
        mkdir(d)

    # Get logdir; the current directory is used when none is configured.
    log_path = _workflow_option(confs, 'log-path', './')

    # Kept so that later tempfile calls in this process still default to
    # the log directory.
    tempfile.tempdir = log_path
    # Create the uniquely named DAG log file. The prefix and directory are
    # passed explicitly: tempfile.mktemp() binds its default prefix at
    # definition time, so assigning tempfile.template never affected it, and
    # mktemp() only returns a name without creating the file (race-prone).
    # mkstemp raises if log_path does not exist.
    log_fd, logfile = tempfile.mkstemp(
        prefix='pycbc_inference_injections.dag.log.', dir=log_path)
    os.close(log_fd)

    accounting_group = _workflow_option(confs, 'accounting-group')
    if accounting_group is None:
        logging.warning(
            'Warning: accounting-group not specified, LDG clusters may'
            ' reject this workflow!')

    logging.info("Creating DAG")
    dag = CondorDAG(logfile)

    dag.set_dag_file("pycbc_inference_injections")
    dag.set_dax_file("pycbc_inference_injections")

    analyses = PycbcInferenceOnInjectionBatch(confs,
                                              analyses_dir,
                                              verbose=True)
    all_injection_runs = analyses.get_runs()

    # Each run contributes a chain: injection -> inference -> plot(s).
    # NOTE(review): the injection node is added even when
    # --skip-creating-injections is given; confirm this is intended.
    for tag in all_injection_runs:
        run = all_injection_runs[tag]

        # Setup analysis directory
        run.setup()

        # Injection job (defaults to 1 CPU)
        inj_node = _add_job_node(dag,
                                 run.get_inj_exe_path(),
                                 accounting_group,
                                 _resource_requests(confs, 'inspinj', 1))

        # Inference job (defaults to 10 CPUs); runs after the injection job
        inf_node = _add_job_node(dag,
                                 run.get_inf_exe_path(),
                                 accounting_group,
                                 _resource_requests(confs, 'inference', 10),
                                 parent=inj_node)

        # Plotting jobs (default to 1 CPU); each runs after the inference job
        plot_resources = _resource_requests(confs, 'plot', 1)
        if confs.has_option('executables', 'plot'):
            # We allow for multiple plotting jobs, one per [plot-*] subsection
            for ss in confs.get_subsections('plot'):
                _add_job_node(dag,
                              run.get_plt_exe_path() + '_{0}'.format(ss),
                              accounting_group,
                              plot_resources,
                              parent=inf_node)

    # Write out the DAG
    dag.write_sub_files()
    dag.write_script()
    dag.write_concrete_dag()

logging.info('Done')
