#!/usr/bin/env python
#
# Copyright (C) 2020 Prayush Kumar
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Setup workflow to perform calibration of ENIGMA model using another BBH model """
import time
__itime__ = time.time()
import os
import sys
import itertools
import shutil
import argparse
import logging
import pandas as pd
import h5py
import numpy as np

from gwnrtools import __version__

############################################################
# command line usage
# Build the command-line interface. Every option defined here is read back
# from `opts` further down when the Condor submit file, the parameter grid
# and the DAG are generated.
parser = argparse.ArgumentParser(usage=__file__ + " [--options]",
                                 description=__doc__)
parser.add_argument("--version", action="version", version=__version__,
                    help="Prints version information.")
# NOTE(review): --verbose is accepted but never consulted; the root logger is
# set to INFO unconditionally right after parsing (see the end of this block).
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Print logging messages.")

# Parameter ranges for those we can grid calculations over
parser.add_argument("--params-to-grid-over", type=str, required=False,
                    default='f_lower,range,15.0,20.0,5.0;psd,choices,aLIGOZeroDetHighPower,flat_unity',
                    help="""\
Semi-colon separated list of parameter details that are
to be gridded over (and held fixed during MCMC). Allowed
types are either 'choices' or 'range'. E.g.

--params-to-grid-over="param1,range,lower_limit,upper_limit,step_size;param2,choices,choice1,choice2"
""")

# Parameters to sample during MCMC OR the fit tag OR both
parser.add_argument("--params-to-mcmc-over", type=str, required=False,
                    default='',
                    help="""\
Semi-colon separated list of parameter details that are
to be sampled during MCMC. E.g.

--params-to-mcmc-over="param1,uniform,lower_limit,upper_limit;param2,choices,choice1,choice2"
""")
parser.add_argument("--enigma-fit-tag", type=str, required=True,
                    help="TAG that uniquely identifies the fit to be used.")

# Parameters that are secretly sampled during MCMC
parser.add_argument("--hidden-params", type=str, required=False,
                    default='',
                    help="""\
Semi-colon separated list of parameter details that are implicitly
sampled during MCMC (i.e. they are sampled but the Markovian chains
do not actively move in the direction of higher likelihood in the
dimensions of these parameters). E.g.

--hidden-params="param1,uniform,lower_limit,upper_limit;param2,choices,choice1,choice2"
""")

# Other (dependent) parameters
parser.add_argument("--dependent-params", type=str, required=False,
                    default='omega_attach,range,0.01,0.1',
                    help="""\
Parameters that are dependent on others, might not be sampled over
but could be used for setting implicit priors etc. These parameters
will be interpreted directly by the sampling script. E.g.

--dependent-params="param1,uniform,lower_limit,upper_limit;param2,choices,choice1,choice2"
""")

# Global parameter choices
parser.add_argument("--signal-approx", type=str, required=False,
                    default='SEOBNRv4_ROM',
                    help="Signal approximant to test against (FD)")

# MCMC and filtering options
parser.add_argument("--num-samplers", type=int, default=32,
                    help="No of MCMC walkers")
parser.add_argument("--num-mcmc-steps", type=int, default=1000,
                    help="No of MCMC steps per walker")
parser.add_argument("--sample-rate", type=int, default=4096,
                    help="Sampling rate for wave gen and matches")
parser.add_argument("--time-length", type=int, default=32,
                    help="Expected max duration of waves")

# parallelization options
parser.add_argument("--num-processes", type=int, default=5,
                    help="No of Multiprocessing processes")

# output options
parser.add_argument("--run-dir", type=str, required=False, default='.',
                    help="Run directory path.")
# The help text used to read "Input parameter file", which described a
# different option; the value is forwarded to the sampling script.
parser.add_argument("--output-prefix", type=str, required=False,
                    default='results/matches_vs_',
                    help="Prefix for output file paths; forwarded to the "
                         "sampling script as --output-prefix.")


# job-splitting options
parser.add_argument("--num-calcs-per-job", type=int, default=1,
                    help="Number of calculations per job ")


# parse command line
opts = parser.parse_args()

# Progress messages in this script are emitted with logging.info(), so the
# root logger is raised to INFO to make them visible.
logging.getLogger().setLevel(logging.INFO)

############################################################
# inputs


def parse_deterministic_params(input_opt):
    """
Parse semi-colon separated list of deterministic
parameters. The required input format for each
parameter's entry is

1. param-name,range,param-lower-limit,param-upper-limit,param-stepsize
2. param-name,choices,param-value1,param-value2,...

Empty entries (e.g. from an empty option string or a trailing
semi-colon) are skipped.

Parameters
----------
input_opt : str
    The raw option string.

Returns
-------
dict
    Maps each parameter name to its grid values: the list of choice
    strings for 'choices' entries, or the array returned by
    ``np.arange`` for 'range' entries.

Raises
------
IOError
    If an entry has no type field, an unknown type, or a 'range' entry
    whose limits are not one to three plain numbers.
    """
    def _number(text):
        # Try int first so that all-integer limits still give an
        # integer-valued np.arange, as a literal "np.arange(0,6,2)" would.
        text = text.strip()
        try:
            return int(text)
        except ValueError:
            return float(text)

    params = {}
    for param_group in input_opt.split(";"):
        if not param_group:
            continue
        param_group_elems = param_group.split(',')
        if len(param_group_elems) < 2:
            raise IOError(
                "Invalid format in gridded param {}: entry '{}' has no "
                "type field".format(input_opt, param_group))
        param_name, param_type = param_group_elems[:2]
        param_values = param_group_elems[2:]
        if param_type == 'choices':
            params[param_name] = param_values
        elif param_type == 'range':
            # The limits are converted to numbers directly instead of being
            # pasted into a string that is handed to eval(), so text from
            # the command line is never executed as code.
            try:
                arange_args = [_number(v) for v in param_values]
            except ValueError:
                raise IOError(
                    "Invalid format in gridded param {}: non-numeric range "
                    "limits in '{}'".format(input_opt, param_group))
            if not 1 <= len(arange_args) <= 3:
                raise IOError(
                    "Invalid format in gridded param {}: '{}' needs between "
                    "one and three range values".format(input_opt,
                                                         param_group))
            params[param_name] = np.arange(*arange_args)
        else:
            raise IOError(
                "Invalid format in gridded param {}".format(input_opt))
    return params


def parse_variable_params(input_opt):
    """
Turn a semi-colon separated list of variable parameters into a
dictionary. Each non-empty entry is expected to look like

param-name,param-distribution,param-lower-limit,param-upper-limit

where param-distribution can be e.g. uniform.

The returned dictionary maps every param-name to a tuple
(param-distribution, limits), with limits being a flat float64 array
built from all fields after the distribution. Empty entries are
skipped.
    """
    parsed = {}
    non_empty_entries = (entry for entry in input_opt.split(";") if entry)
    for entry in non_empty_entries:
        fields = entry.split(',')
        name = fields[0]
        distribution = fields[1]
        # Everything after the distribution is converted to floats
        limits = np.float64(fields[2:]).flatten()
        parsed[name] = (distribution, limits)
    return parsed


############################################################
# temporaries
# Name of the sampling executable; the submit description below runs it as
# scripts/<this name>.
__script_name__ = 'gwnrtools_enigma_sample_calib_parameters'
# File name of the Condor submit description; write_job_to_dag() names it in
# the JOB line of every DAG node.
__submit_file_name__ = 'sampler.submit'

# Condor submit description shared by all jobs. $(macrojobid) and
# $(macroparamfile) are left untouched here; write_job_to_dag() defines them
# per node through VARS lines. The numbered {N} fields are filled in by the
# .format() call below (see the per-argument comments for the mapping).
# NOTE(review): accounting_group, accounting_group_user and the `log` path are
# hard-coded to one specific user / machine layout -- confirm before reuse.
# NOTE(review): the error/output file names are built from
# 'gwnrtools_enigma_sample_parameters', which differs from __script_name__.
__submit_text__ = '''\
universe = vanilla
initialdir = {8}
executable = scripts/{7}
arguments = " --job-id $(macrojobid) --param-file $(macroparamfile) --signal-approx {0} --num-samplers {1} --num-mcmc-steps {2} --sample-rate {3} --time-length {4} --output-prefix {5} --enigma-tag {6} --num-processes {9}"
accounting_group = ligo.dev.o3.cbc.explore.test
accounting_group_user = prayush.kumar
request_memory = {10}G
request_cpus = {9}
getenv = True
max_retries = 10
log = /usr1/prayush.kumar/tmpc8uHuQ
error = log/gwnrtools_enigma_sample_parameters-$(cluster)-$(process).err
output = log/gwnrtools_enigma_sample_parameters-$(cluster)-$(process).out
notification = never
queue 1
'''.format(
    opts.signal_approx,  # {0}: --signal-approx
    opts.num_samplers,  # {1}: --num-samplers
    opts.num_mcmc_steps,  # {2}: --num-mcmc-steps
    opts.sample_rate,  # {3}: --sample-rate
    opts.time_length,  # {4}: --time-length
    opts.output_prefix,  # {5}: --output-prefix
    opts.enigma_fit_tag,  # {6}: passed on as --enigma-tag
    __script_name__,  # {7}: executable name under scripts/
    os.path.abspath(opts.run_dir),  # {8}: initialdir (absolute run directory)
    opts.num_processes,  # {9}: --num-processes and request_cpus
    2*opts.num_processes  # {10}: request_memory in G, twice the process count
)


def job_id(job_num):
    """Return ``job_num`` as a decimal string zero-padded to six characters."""
    return format(job_num, '06d')


def dag_file():
    """Return the (relative) file name of the DAG file."""
    dag_name = 'retune_enigma.dag'
    return dag_name


def calc_info_file(job_id, calc_id):
    """Return the relative path of the JSON info file for one calculation.

    The file lives in the ``input`` directory and is named
    ``parameters_<job_id>_<calc_id>.json``.
    """
    # Join real path components; os.path.join() with a single pre-built
    # string would return it unchanged.
    return os.path.join(
        'input', 'parameters_{0}_{1}.json'.format(job_id, calc_id))


def parameter_file(job_id):
    """Return the relative path of the HDF parameter file for ``job_id``.

    The file lives in the ``input`` directory and is named
    ``parameters_<job_id>.hdf``.
    """
    # Join real path components; os.path.join() with a single pre-built
    # string would return it unchanged.
    return os.path.join('input', 'parameters_{0}.hdf'.format(job_id))


def group_name(i):
    """Return the group label for index ``i``: the integer zero-padded to six characters."""
    return format(i, '06d')


def write_parameter_file(job_id, params):
    '''Outdated.

    Store every entry of ``params`` as a dataset in
    ``input/parameters_<job_id>.hdf`` and return that file name. If the
    file already exists it is left untouched and its name is returned.
    '''
    hdf_path = os.path.join('input/parameters_{0}.hdf'.format(job_id))
    if not os.path.exists(hdf_path):
        with h5py.File(hdf_path, 'a') as hdf:
            for name in params:
                # Go through a numpy array first, then hand h5py a list
                values = list(np.array(params[name]))
                hdf.create_dataset(name, data=values)
    return hdf_path


def write_parameters_to_group(fout, group_name, params):
    """Create group ``group_name`` in ``fout`` and add one dataset per entry of ``params``.

    Each dataset is named after the key and holds ``list(params[key])``.
    """
    target = fout.create_group(group_name)
    for key in params:
        values = list(params[key])
        target.create_dataset(key, data=values)


def write_dataframe_to_group(fout, group_name, params):
    """Create group ``group_name`` in ``fout`` and add one dataset per key of ``params``.

    Iterating ``params`` yields the keys; each dataset holds
    ``list(params[key])``.
    """
    target = fout.create_group(group_name)
    for key in params:
        column_values = list(params[key])
        target.create_dataset(key, data=column_values)


def write_job_to_dag(job_id, param_file):
    """Append the node description for one job to the DAG file.

    The node is named ``TUNE<job_id>``, uses the shared submit file, and
    gets ``job_id`` and ``param_file`` as the ``macrojobid`` and
    ``macroparamfile`` variables. It is retried up to 10 times.
    """
    node_template = '''\
JOB TUNE{0} {2}
VARS TUNE{0} macrojobid="{0}"
VARS TUNE{0} macroparamfile="{1}"
RETRY TUNE{0} 10

'''
    with open(dag_file(), 'a+') as dag:
        node_text = node_template.format(
            job_id, param_file, __submit_file_name__)
        dag.write(node_text)


############################################################
# Setup

# Move to run directory (created first if it does not exist yet). All
# relative paths used from here on are relative to the run directory.
if not os.path.exists(opts.run_dir):
    os.makedirs(opts.run_dir)
os.chdir(opts.run_dir)

# Make directories
dirs_to_make = ['scripts', 'input', 'results', 'log']
for d in dirs_to_make:
    if not os.path.exists(d):
        os.makedirs(d)

# Locate the sampling executable on PATH and keep a private, executable copy
# under scripts/, which is where the submit description expects it.
with os.popen('which {0}'.format(__script_name__)) as which_output:
    sampling_prog = which_output.read().strip()
if not sampling_prog:
    # `which` prints nothing when the program is not found; fail here with a
    # clear message instead of letting shutil.copy() choke on an empty path.
    raise IOError(
        "Could not locate {0} on PATH.".format(__script_name__))
shutil.copy(sampling_prog, 'scripts/{0}'.format(__script_name__))
os.chmod('scripts/{0}'.format(__script_name__), 0o0777)

logging.info(".. run directories setup.")

############################################################
# Parse inputs
# Turn the raw option strings into dictionaries keyed by parameter name
parameters_to_grid_over = parse_deterministic_params(opts.params_to_grid_over)
parameters_to_mcmc_over = parse_variable_params(opts.params_to_mcmc_over)
parameters_hidden = parse_variable_params(opts.hidden_params)
parameters_dependent = parse_variable_params(opts.dependent_params)

# Make the grid: the Cartesian product of all gridded parameters' values,
# with tuple entries ordered like the keys of parameters_to_grid_over
grid_choices = list(itertools.product(*parameters_to_grid_over.values()))
logging.info(" .. parameter grid constructed.")


# Group the grid points into jobs.
# `jobs` maps a job id string to the list of calculations in that job; each
# calculation is a dict of four pandas DataFrames (see `curr_p` below).
jobs = {}

# Number of calculations in the job currently being filled. It starts at a
# huge sentinel so that the very first grid point fails the
# `< opts.num_calcs_per_job` test below and opens a new job.
num_mcmc_in_job = 1e9
# Running count of grid points handled so far; incremented in BOTH branches
# below, so the id of a job is the index of its first grid point and ids are
# not consecutive when --num-calcs-per-job is larger than 1.
job_num = 0
# Calculations collected for the job currently being filled (None until the
# first job has been opened).
job_params = None

for p_vector in grid_choices:
    curr_p = {}
    # Choose parameters for this MCMC exploration
    # 1. Use parameters from the current grid choice
    #    (one column per gridded parameter, holding 'fixed' and [value];
    #    p_vector[i] belongs to the i-th key of parameters_to_grid_over)
    curr_p['fixed_params'] = pd.DataFrame.from_dict({p: ('fixed', [p_vector[i]])
                                                     for i, p in enumerate(parameters_to_grid_over)})
    # 2. Store sampling parameters
    curr_p['sampling_params'] = pd.DataFrame.from_dict(parameters_to_mcmc_over)
    # 3. Store hidden parameters
    curr_p['hidden_params'] = pd.DataFrame.from_dict(parameters_hidden)
    # 4. Store dependent parameters
    curr_p['dependent_params'] = pd.DataFrame.from_dict(parameters_dependent)

    # Label the two rows of every frame as 'dist' and 'range'. Frames that do
    # not have exactly two rows (e.g. one built from an empty dictionary) are
    # left as they are. Only existing keys are reassigned, so the dict is
    # safe to iterate while doing this.
    for param_type in curr_p:
        if len(curr_p[param_type]) != 2:
            continue
        curr_p[param_type] = curr_p[param_type].set_index(
            pd.Index(['dist', 'range']))

    # Set job info for DAG formation
    if num_mcmc_in_job < opts.num_calcs_per_job:
        # add to current job
        job_params.append(curr_p)

        num_mcmc_in_job += 1
        job_num += 1
    else:
        # close current job
        # (`curr_job_id` is first assigned a few lines below; `job_params` is
        # still None on the first pass, so it is never read before that)
        if job_params:
            jobs[curr_job_id] = job_params

        # start new job
        curr_job_id = job_id(job_num)
        job_params = [curr_p]

        num_mcmc_in_job = 1
        job_num += 1

# Leftover jobs
# (the job being filled when the loop ended has not been stored yet)
if job_params:
    jobs[curr_job_id] = job_params
    job_params = None

logging.info(" .. job details setup.")
############################################################
# Write submit file
# The submit description is (re)written from scratch on every run
with open(__submit_file_name__, "w") as submit_file:
    submit_file.write(__submit_text__)

logging.info(" .. condor submission file written.")
############################################################


# Design DAG
# write_job_to_dag() only ever appends, so a DAG file left over from an
# earlier run in this directory would end up with every node defined twice.
# Remove it first so the DAG always describes exactly the jobs set up above.
if os.path.exists(dag_file()):
    os.remove(dag_file())

# Iterate over the jobs, and add a DAG node for each!
for _job_id in jobs:
    job_params = jobs[_job_id]

    calc_file_names = []
    for i, jp in enumerate(job_params):
        this_calc_file_name = calc_info_file(_job_id, i)
        # One line per parameter category: "<category>;<DataFrame as JSON>".
        # Opened with "w" (like the HDF file below) so that a re-run replaces
        # the file instead of appending a second copy of every line.
        with open(this_calc_file_name, "w") as fout:
            for param_type in jp:
                fout.write(param_type + ";")
                jp[param_type].to_json(fout)
                fout.write("\n")
        calc_file_names.append(this_calc_file_name)

    # The per-job HDF file only records which calc-info files belong to it
    parameter_file_name = parameter_file(_job_id)
    with h5py.File(parameter_file_name, "w") as fout:
        fout.attrs['calc_info'] = calc_file_names

    write_job_to_dag(_job_id, parameter_file_name)

logging.info(" .. DAG written.")
logging.info("All Done in {0} seconds".format(time.time() - __itime__))
