#!/usr/bin/env python
#
# Copyright (C) 2020 Prayush Kumar
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Setup workflow to perform Bayesian parameter estimation runs on a
custom set of public gravitational-wave events using open data"""

import os
import argparse
import logging
import tempfile
import numpy

from glue.pipeline import CondorDAGNode, CondorDAG
import pycbc
from pycbc import fft, opt, scheme
from pycbc.workflow import configuration

from gwnrtools.workflow.pycbc_inference import PycbcInferenceOnEventBatch
from gwnrtools.workflow.condor import InferenceJob
from gwnrtools.utils.support import (mkdir, rmdir)
from gwnrtools import __version__

############################################################
# Command-line interface
parser = argparse.ArgumentParser(description=__doc__,
                                 usage=__file__ + " [--options]")

# general options
parser.add_argument("--version", action="version", version=__version__,
                    help="Prints version information.")
parser.add_argument("--verbose", action="store_true",
                    help="Print logging messages.")

# output options
parser.add_argument("--output-dir", default='',
                    help="Output directory path.")
parser.add_argument("--force", action="store_true",
                    help="If the output-dir already exists, overwrite it. "
                    "Otherwise, an OSError is raised.")
parser.add_argument("--do-not-fetch-data", action="store_true",
                    help="Don't fetch GWOSC data.")

# parallelization options
parser.add_argument("--nprocesses", type=int, default=0,
                    help="Number of processes to use. If not given then only "
                    "a single core will be used.")
parser.add_argument("--use-mpi", action='store_true',
                    help="Use MPI to parallelize the sampler")
parser.add_argument("--seed", type=int, default=0,
                    help="Seed to use for the random number generator that "
                    "initially distributes the walkers. Default is 0.")

# option groups contributed by pycbc: workflow configuration first, then the
# FFT, optimization and processing-scheme groups
configuration.add_workflow_command_line_group(parser)
fft.insert_fft_option_group(parser)
opt.insert_optimization_option_group(parser)
scheme.insert_processing_option_group(parser)

# parse the command line into the module-level options namespace
opts = parser.parse_args()

# Logging setup: in MPI mode only the rank-0 process is allowed to keep
# verbose output switched on, so messages are not repeated by every rank
if opts.use_mpi:
    from mpi4py import MPI
    rank = MPI.COMM_WORLD.Get_rank()
    opts.verbose = opts.verbose and rank == 0
pycbc.init_logging(opts.verbose)

# Let each pycbc option group validate the parsed options, in the same order
# in which the groups were verified originally (fft, optimization, scheme)
for _verify_options in (fft.verify_fft_options,
                        opt.verify_optimization_options,
                        scheme.verify_processing_options):
    _verify_options(opts, parser)

# Build the workflow configuration from the parsed command-line options
confs = configuration.WorkflowConfigParser.from_cli(opts)

# Fallback values for [workflow] options; each one is applied only when the
# configuration does not already define that option
workflow_defaults = (
    ('log-path', './'),
    ('psd-estimation', 'download'),
    ('inference-request-memory', ''),
    ('inference-request-cpus', '10'),
    ('plot-request-memory', ''),
    ('plot-request-cpus', '1'),
)
for option_name, fallback_value in workflow_defaults:
    if confs.has_option('workflow', option_name):
        continue
    confs.set('workflow', option_name, fallback_value)

# Forward command-line choices to the [inference] section of the
# configuration. The seed is always forwarded; nprocesses only when at least
# one process was requested; use-mpi only when the flag was given.
override_opts = [('seed', str(opts.seed))]
if opts.nprocesses >= 1:
    override_opts.append(('nprocesses', str(opts.nprocesses)))
if opts.use_mpi:
    override_opts.append(('use-mpi', ''))

# Each option is added on its own so that one failure does not prevent the
# remaining options from being set. overwrite_options=False is kept, so a
# value already present in the configuration is not replaced.
# NOTE(review): add_options_to_section presumably raises when the section is
# missing or the option already exists -- confirm against pycbc. This stays
# best-effort, but the reason is now logged instead of silently discarded,
# and KeyboardInterrupt/SystemExit are no longer swallowed.
for override_opt in override_opts:
    try:
        confs.add_options_to_section('inference', [override_opt],
                                     overwrite_options=False)
    except Exception as err:
        logging.info("Not setting [inference] %s from the command line: %s",
                     override_opt[0], err)

# numpy's divide-by-zero and invalid-value warnings are benign here and only
# make the logging output confusing, so they are switched off
numpy.seterr(invalid='ignore', divide='ignore')

# seed numpy's global random number generator with the requested seed
numpy.random.seed(opts.seed)
logging.info("Using seed %i", opts.seed)

# build the processing-scheme context from the parsed options, then hand the
# same options to the fft module
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

with ctx:

    # Directory in which all analyses are set up: --output-dir when given,
    # otherwise the current working directory.
    # NOTE(review): the current working directory always exists, so running
    # without --output-dir requires --force -- confirm this is intended.
    analyses_dir = opts.output_dir or os.getcwd()
    if not os.path.exists(analyses_dir):
        mkdir(analyses_dir)
    elif not opts.force:
        raise IOError(
            "Output directory {} exists. Use --force to overwrite it.".
            format(analyses_dir))
    os.chdir(analyses_dir)
    logging.info("Will setup analyses in %s", analyses_dir)

    # Workspace dirs
    logging.info("Making workspace directories")

    # Log directory; a relative log-path is resolved against analyses_dir
    # because the working directory was changed just above
    log_path = confs.get("workflow", 'log-path')
    if not os.path.exists(log_path):
        mkdir(log_path)

    # Pick a unique name for the DAG log file inside the log directory.
    # The module-level tempfile settings are kept as before, but the prefix
    # and directory are also passed explicitly: mktemp() binds its default
    # prefix when the tempfile module is imported, so assigning
    # tempfile.template alone does not change the generated name.
    # Only a path is needed here (CondorDAG is handed the name, not an open
    # file), hence mktemp rather than mkstemp.
    dag_log_prefix = 'pycbc_inference_events.dag.log.'
    tempfile.tempdir = log_path
    tempfile.template = dag_log_prefix
    logfile = tempfile.mktemp(prefix=dag_log_prefix, dir=log_path)

    # The accounting group is optional; without it the jobs are still
    # configured, but the user is warned
    try:
        accounting_group = confs.get('workflow', 'accounting-group')
    except Exception:
        accounting_group = None
        logging.warning(
            'Warning: accounting-group not specified, LDG clusters may'
            ' reject this workflow!')

    logging.info("Creating DAG")
    dag = CondorDAG(logfile)

    dag.set_dag_file("pycbc_inference_events")
    dag.set_dax_file("pycbc_inference_events")

    analyses = PycbcInferenceOnEventBatch(confs,
                                          analyses_dir,
                                          verbose=opts.verbose)
    all_event_runs = analyses.get_runs()

    # Names of events whose data/PSDs have already been fetched, as we do
    # not want to download either of them multiple times
    events_already_setup = []

    for tag in all_event_runs:
        run = all_event_runs[tag]

        # Setup analysis directory
        run.setup()

        # Procure GWOSC data (and PSDs when psd-estimation is 'download'),
        # once per event
        if not opts.do_not_fetch_data:
            event_name = run.get_event_name()
            if event_name not in events_already_setup:
                run.fetch_all_data()
                if confs.get('workflow', 'psd-estimation') == 'download':
                    run.fetch_all_psds()
                events_already_setup.append(event_name)

        # Configure inference job
        req_mem = confs.get('workflow', 'inference-request-memory')
        req_cpus = int(confs.get('workflow', 'inference-request-cpus'))
        inf_job = InferenceJob('log',
                               run.get_inf_exe_path(),
                               None,
                               None,
                               accounting_group=accounting_group,
                               request_memory=req_mem,
                               request_cpus=req_cpus)
        inf_node = CondorDAGNode(inf_job)
        dag.add_node(inf_node)

        # Configure plotting job(s); each one runs after the inference job
        # of the same run
        req_mem = confs.get('workflow', 'plot-request-memory')
        req_cpus = int(confs.get('workflow', 'plot-request-cpus'))
        if confs.has_option('executables', 'plot'):
            # We allow for multiple plotting jobs, one per [plot] subsection
            for ss in confs.get_subsections('plot'):
                plt_job = InferenceJob('log',
                                       run.get_plt_exe_path() +
                                       '_{0}'.format(ss),
                                       None,
                                       None,
                                       accounting_group=accounting_group,
                                       request_memory=req_mem,
                                       request_cpus=req_cpus)
                plt_node = CondorDAGNode(plt_job)
                plt_node.add_parent(inf_node)
                dag.add_node(plt_node)

    # Write out the DAG
    dag.write_sub_files()
    dag.write_script()
    dag.write_concrete_dag()

# NOTE(review): rmdir comes from gwnrtools.utils.support and its semantics
# are not visible here -- confirm it does not remove a log directory that
# the submitted DAG still needs.
rmdir(log_path)
logging.info('Done')
