#!/usr/bin/env python
#
# Copyright (C) 2020 Prayush Kumar
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Perform calibration runs for ENIGMA model using another BBH model """

import time
__itime__ = time.time()

import os
import sys
import h5py
import numpy as np
import argparse
import logging
import traceback
from multiprocessing import Pool
import pandas as pd

from gwnrtools import __version__
from gwnrtools.utils import make_padded_frequency_series
from gwnrtools.stats.priors import (__all_cbc_parameters__, default_bbh_params)
from gwnrtools.stats.samplers import (get_emcee_ensemble_sampler,
                                      write_output_from_emcee_sampler)
from gwnrtools.stats.sampling import OneDRandom
from gwnrtools.stats.enigma_utils import (log_prior_enigma,
                                          log_likelihood_enigma,
                                          log_prob_enigma, __log_prob_funcs__,
                                          __order_of_sampled_params__,
                                          __ranges_of_sampled_params__)

from pycbc.psd import from_string

############################################################
# command line usage
# The module docstring doubles as the description shown by --help.
parser = argparse.ArgumentParser(
    usage="{0} [--options]".format(__file__),
    description=__doc__)

# general options
parser.add_argument("--version", action="version", version=__version__,
                    help="Prints version information.")
parser.add_argument("--verbose", action="store_true",
                    help="Print logging messages.")
parser.add_argument("--debug", action="store_true",
                    help="Print debugging messages.")

# workflow options: which job to run, where its parameters live, and
# which ENIGMA fit / reference approximant to use
parser.add_argument("--job-id", required=True, type=str)
parser.add_argument("--param-file", required=True, type=str,
                    help="Input parameter file")
parser.add_argument("--enigma-tag", required=True, type=str,
                    help="TAG that uniquely identifies the fit to be used.")
parser.add_argument("--signal-approx", type=str, default='SEOBNRv4_ROM',
                    help="Signal approximant to test against (FD)")

# MCMC and filtering options
parser.add_argument("--num-samplers", type=int, default=32,
                    help="No of MCMC walkers")
parser.add_argument("--num-mcmc-steps", type=int, default=100,
                    help="No of MCMC steps per walker")
parser.add_argument("--sample-rate", type=int, default=4096,
                    help="Sampling rate for wave gen and matches")
parser.add_argument("--time-length", type=int, default=32,
                    help="Expected max duration of waves")
parser.add_argument("--use-dilation-map-match", action="store_true",
                    help="use a nonlinear dilating map for match")

# parallelization options
parser.add_argument("--num-processes", type=int, default=2,
                    help="No of Multiprocessing processes")

# output options
parser.add_argument("--output-prefix", type=str,
                    default='results/samples_vs_',
                    help="Output file prefix to store samples in.")

# parse command line
opts = parser.parse_args()

# Let INFO-level messages through the root logger.
logging.getLogger().setLevel(logging.INFO)

# The requested fit TAG has to be a key of every lookup table imported from
# `gwnrtools.stats.enigma_utils`. The tables are checked in a fixed order and
# the first one that lacks the TAG aborts the job.
for _tag_table, _tag_error in (
        (__ranges_of_sampled_params__,
         "FIT TAG: {} not recognized for ranges"),
        (__order_of_sampled_params__,
         "FIT TAG: {} not recognized for order"),
        (__log_prob_funcs__,
         "FIT TAG: {} not recognized for logprob funcs")):
    if opts.enigma_tag not in _tag_table:
        raise IOError(_tag_error.format(opts.enigma_tag))
############################################################
# inputs
# Job-wide settings taken from the command line. Each calculation set up
# further below starts from a copy of this table. Keys are inserted one by
# one; they become the columns of the DataFrame built at the end.
job_inputs = {}

job_inputs['job_id'] = opts.job_id
job_inputs['param_file'] = opts.param_file
job_inputs['signal_approx'] = opts.signal_approx
# The output prefix is extended with the name of the signal approximant
job_inputs['output_prefix'] = opts.output_prefix + \
    '{0}'.format(opts.signal_approx)
# job_inputs['use_dilation_map_match'] = opts.use_dilation_map_match

# MCMC params
job_inputs['num_samplers'] = opts.num_samplers
job_inputs['num_steps'] = opts.num_mcmc_steps

# Filtering params
job_inputs['sample_rate'] = opts.sample_rate
job_inputs['time_length'] = opts.time_length

job_inputs['delta_t'] = 1. / job_inputs['sample_rate']
job_inputs['delta_f'] = 1. / job_inputs['time_length']
# N = sample_rate * time_length; n = N // 2 + 1.
# Floor division keeps `n` an integer on Python 3 as well, where `/` between
# two ints returns a float.
N = job_inputs['sample_rate'] * job_inputs['time_length']
n = N // 2 + 1

# Turn the settings into a single-row DataFrame: every value is wrapped in a
# 1-tuple so that it fills exactly one cell of its column.
job_inputs = pd.DataFrame.from_dict({i: (job_inputs[i], ) for i in job_inputs})

############################################################
# Setup
# Note: the dataframes populated below are meant to provide sampling
# details for all possible parameters that we might want to sample.
# This does **not** provide an exhaustive list of sampled parameters.
# That is to be provided by the version of ENIGMA fit, and is defined
# in `gwnrtools.stats.enigma_utils`.
#
# Both dictionaries are keyed by the running index of the calculation
# (its position in the `calc_info` attribute of the parameter file).
all_calc_params = {}
all_calc_inputs = {}

with h5py.File(job_inputs['param_file'][0], 'r') as fin:
    # The input file has attributes set to a list of JSON
    # files containing information related to each calibration
    # experiment this job is to run. We loop over each of those
    # files and therefore each of those experiments.
    for j_id, calc_info_file in enumerate(fin.attrs['calc_info']):
        # Maps a parameter-family name to the DataFrame parsed for it
        calc_params_from_input = {}
        with open(calc_info_file, "r") as json_fin:
            lines = json_fin.readlines()
            for l in lines:
                # Each line is split on ";": the first field is used as the
                # family name, the last field is parsed as JSON by pandas.
                # NOTE(review): recent pandas releases warn when a literal
                # JSON string (rather than a path/buffer) is passed to
                # `read_json` -- confirm against the supported version.
                split_line = l.split(";")
                calc_params_from_input[split_line[0]] = pd.read_json(
                    split_line[-1])

        if opts.verbose:
            logging.info("Preparing calc: {}".format(j_id))

        # Start populating DataFrames with parameters for the current
        # calculation
        calc_inputs = job_inputs.copy()
        calc_params = default_bbh_params.copy()
        # Append a row labelled "vartype" that marks every default column
        # as "default".
        # NOTE(review): `DataFrame.append` was deprecated in pandas 1.4 and
        # removed in pandas 2.0 -- confirm the pandas version this targets.
        calc_params = calc_params.append(
            pd.Series({c: "default"
                       for c in calc_params.columns},
                      name="vartype"))

        # Add a "vartype" row indicating the type of each parameter read
        # from the input files;
        # possible values are: default, hidden, sampling, fixed, dependent
        for param_family in calc_params_from_input:
            # Assuming name = 'something_params'
            param_type = param_family.split("_")[0]
            # Add row showing the type of parameter now
            calc_params_from_input[param_family] =\
                calc_params_from_input[param_family].append(pd.Series(
                    {c: param_type
                     for c in calc_params_from_input[param_family].columns},
                    name="vartype"))

            # Populate the main calc_params DataFrame with parameters
            # read-in from disk.
            for param in calc_params_from_input[param_family]:
                # Check if the parameter exists in the initial df
                if param in calc_params:
                    # Overwrite the existing column entry by entry, one
                    # index label of the input column at a time.
                    # NOTE(review): this is chained assignment
                    # (`df[col][idx] = ...`); whether it writes through to
                    # `calc_params` depends on the pandas version -- verify.
                    for idx in calc_params_from_input[param_family][
                            param].index:
                        calc_params[param][idx] =\
                            calc_params_from_input[param_family][param][idx]
                else:
                    # Else add a new column for this parameter!
                    # get order of indexes in calc_params first
                    # retain that order as much as possible
                    calc_params[param] = calc_params_from_input[param_family][
                        param]

        # Insert (a1-b4) calibration parameters as columns here
        ranges_of_sampled_params_for_this_fit = __ranges_of_sampled_params__[
            opts.enigma_tag]
        for p in ranges_of_sampled_params_for_this_fit:
            # In case user forces a range for any parameter that
            # is part of the requested fit, use that
            if p in calc_params:
                # Any vartype other than "default" means the column was
                # provided by the input files above; leave it untouched.
                if calc_params[p].vartype != 'default':
                    continue

            r = ranges_of_sampled_params_for_this_fit[p]

            if len(r) == 2:  # interpret this as an interval
                calc_params[p] = pd.Series(
                    {
                        "dist": "uniform",
                        "range": r,
                        "vartype": "sampling"
                    },
                    name=p)
            elif len(r) > 2 or len(
                    r) == 1:  # interpret this as discrete choices
                calc_params[p] = pd.Series(
                    {
                        "dist": "choices",
                        "range": r,
                        "vartype": "sampling"
                    },
                    name=p)
            else:
                # Only reachable for an empty range
                raise RuntimeError("Range {} for param {} not valid!".format(
                    r, p))
        # Insert an ordered list of sampled parameters: the fit's canonical
        # order, restricted to columns whose vartype is "sampling"
        ordered_sampling_params = [
            c for c in __order_of_sampled_params__[opts.enigma_tag]
            if calc_params[c].vartype == "sampling"
        ]
        calc_params['sampler_params'] = pd.Series(
            {
                "dist": "fixed",
                "range": ordered_sampling_params,
                "vartype": "fixed"
            },
            name='sampler_params')
        # The same list is stored with the per-calculation inputs, wrapped in
        # a 1-tuple like the other single-row entries of that table
        calc_inputs['sampler_params'] = (ordered_sampling_params, )
        all_calc_params[j_id] = calc_params
        all_calc_inputs[j_id] = calc_inputs

logging.info(".. input parameters read.")

############################################################
# Summarise the job configuration in the log (verbose mode only).
if opts.verbose:
    logging.info("Starting job: {}".format(job_inputs.job_id[0]))
    logging.info("Will read from: {}".format(job_inputs.param_file[0]))
    logging.info("Will compute matches against {}".format(
        job_inputs.signal_approx[0]))
    # `calc_params` is the table left behind by the input-reading loop, i.e.
    # the parameters of the last calculation read. It does not exist when the
    # parameter file listed no calculations, and a particular fit need not
    # define these columns, so an informational message is only emitted for
    # what is actually there instead of aborting the job.
    if all_calc_params:
        if 'omega_attach' in calc_params:
            logging.info("Will sample omega_att: [{0}, {1}]".format(
                *calc_params.omega_attach.range))
        if 'PNO' in calc_params:
            logging.info("Will sample PN order from: {}".format(
                calc_params.PNO.range))
    logging.info("Will take {} MCMC steps per sampler".format(
        job_inputs.num_steps[0]))
    logging.info("Will use {} samplers".format(job_inputs.num_samplers[0]))
    logging.info("Will write output with prefix: {}".format(
        job_inputs.output_prefix[0]))
    logging.info("Will filter at {0}Hz with maxT = {1}secs".format(
        job_inputs.sample_rate[0], job_inputs.time_length[0]))


############################################################
# Functions and classes
def output_file(prefix, job_id, idx):
    """Build the name of the samples file for one calculation.

    The result is `prefix` followed by `job_id`, an underscore, `idx`
    zero-padded to six digits, and the extension `.dat`.
    """
    tail = '{0}_{1:06d}.dat'.format(job_id, idx)
    return prefix + tail


def hdf_backend_file(out_file):
    """Return the name of the checkpoint (HDF) file paired with `out_file`.

    A trailing `.dat` extension is swapped for `.ckpt.hdf`. Only the
    extension is touched: a `.dat` occurring elsewhere in the path (for
    instance inside a directory name such as `runs.data/`) is left alone.
    If `out_file` does not end in `.dat`, `.ckpt.hdf` is appended, so the
    checkpoint name never coincides with `out_file` itself.
    """
    extension = '.dat'
    if out_file.endswith(extension):
        stem = out_file[:-len(extension)]
    else:
        stem = out_file
    return stem + '.ckpt.hdf'


def get_this_global_param(param):
    """Look up a setting for the calculation currently being processed.

    The module-level `calc_params` table is consulted first, taking the
    first entry of the parameter's `range`. If that lookup fails, the first
    entry of `calc_inputs[param]` is used instead.

    If both lookups fail, a message naming the parameter and the input
    file is logged and the exception from the second lookup propagates.

    Only `Exception` subclasses trigger the fallback; `KeyboardInterrupt`
    and `SystemExit` are never swallowed.
    """
    try:
        x = calc_params[param].range[0]
    except Exception:
        try:
            x = calc_inputs[param][0]
        except Exception:
            logging.info("""
Could not find {} anywhere, including in the input file {}""".format(
                param, job_inputs['param_file'][0]))
            raise
    return x


############################################################
# Create and run samplers, store output
# One worker pool is shared by all samplers. If it cannot be created the
# samplers are set up with pool=None instead.
try:
    __my_pool__ = Pool(processes=opts.num_processes)
except Exception:
    logging.warning("Could not start a pool of {0} processes:\n{1}".format(
        opts.num_processes, traceback.format_exc()))
    __my_pool__ = None

# Every sampler tuple returned by the factory is kept here, keyed by the
# calculation index.
samplers = {}
# f_lower for which `psd` was last generated (-1: not generated yet)
old_f_lower = -1

# Loop over all unique MCMC jobs
num_mcmc = len(all_calc_params)

for idx, j_id in enumerate(all_calc_params):
    # `calc_params` / `calc_inputs` are module-level on purpose:
    # `get_this_global_param` reads them.
    calc_params = all_calc_params[j_id]
    calc_inputs = all_calc_inputs[j_id]
    # Reset per calculation so that the error report below can never show an
    # undefined value or one left over from the previous calculation.
    f_lower = None
    if opts.verbose:
        # `.values` instead of `Index.get_values()`, which pandas >= 1.0
        # no longer provides
        logging.info("Params for calc {}: {}".format(
            j_id, calc_params.columns.values))

    # Any failure is logged together with the offending parameter set and
    # then re-raised, which stops the remaining calculations as well.
    # NOTE(review): the original comment here stated that one failing
    # calculation should not affect the others -- confirm the intent before
    # replacing the `raise` below with `continue`.
    try:
        # filtering related parameters
        f_lower = float(get_this_global_param('f_lower'))
        df = float(get_this_global_param('delta_f'))
        dt = float(get_this_global_param('delta_t'))
        sample_rate = 1. / dt
        # Round before converting: 1/dt/df can land just below the intended
        # integer in floating point, and int() alone would truncate it.
        N = int(round(1. / dt / df))
        # Floor division: the PSD length must stay an integer on Python 3
        n = N // 2 + 1

        if opts.verbose:
            logging.info(
                "\n\n ... starting MCMC for setting {0}/{1}: f_low={2:.0f}".
                format(idx, num_mcmc, f_lower))

        # The PSD is regenerated only when f_lower changes.
        # NOTE(review): the PSD also depends on the 'psd' name, n and df; if
        # those can differ between calculations that share f_lower, a stale
        # PSD is reused -- verify against how parameter files are generated.
        if f_lower != old_f_lower:
            # from global settings
            psd = from_string(get_this_global_param('psd'), n, df, f_lower)
            psd = make_padded_frequency_series(psd, N, df)
            old_f_lower = f_lower

        # Create a backend to checkpoint sampler output
        out_file = output_file(get_this_global_param('output_prefix'),
                               get_this_global_param('job_id'), idx)
        hdf_backend = hdf_backend_file(out_file)
        if opts.verbose:
            logging.info("... will checkpoint the MCMC sampler to {}".format(
                hdf_backend))

        # Initialize the sampler
        # Sampled parameters in the canonical order defined for this fit
        ordered_sampling_params = [
            c for c in __order_of_sampled_params__[opts.enigma_tag]
            if calc_params[c].vartype == "sampling"
        ]
        sampling_params = calc_params[ordered_sampling_params]
        if opts.verbose:
            logging.info("... initializing the MCMC sampler for {}".format(
                sampling_params.columns.values))
        samplers[j_id] = \
            get_emcee_ensemble_sampler(__log_prob_funcs__[opts.enigma_tag],
                                       sampling_params,
                                       [calc_inputs.iloc[0],
                                        f_lower, calc_params, psd,
                                        opts.enigma_tag,  # TAG for the fit to use
                                        # dilation_map: fixed to False here,
                                        # --use-dilation-map-match is not
                                        # consulted
                                        False,
                                        opts.verbose],
                                       kwargs={},
                                       backend_hdf=hdf_backend,
                                       nwalkers=get_this_global_param(
                                           'num_samplers'),
                                       burn_in=100,
                                       pool=__my_pool__,
                                       verbose=opts.verbose,
                                       debug=opts.debug)

        # RUN the sampler
        if opts.verbose:
            logging.info("... Running the MCMC sampler for {0} steps".format(
                get_this_global_param('num_steps')))

        s, state, p0 = samplers[j_id]
        num_steps = calc_inputs.num_steps[0]
        # First try to run without an explicit initial state; if the sampler
        # rejects that with a ValueError, start from `p0`.
        # NOTE(review): presumably `None` resumes from the state stored in
        # the checkpoint backend -- confirm against the emcee documentation.
        try:
            s.run_mcmc(None, num_steps)
        except ValueError:
            s.run_mcmc(p0, num_steps)

        # Write output from the sampler
        logging.info("... Writing output now..")
        write_output_from_emcee_sampler(out_file, s, ordered_sampling_params)

        logging.info("... Written output!")
        # Early stop (the threshold is high enough to be effectively disabled)
        if idx >= 1e9:
            break
    except Exception:
        logging.error("\n")
        logging.error(traceback.format_exc())
        # f_lower is still None when the failure happened before it was read
        if f_lower is None:
            f_low_text = "unknown"
        else:
            f_low_text = "{0:.0f}".format(f_lower)
        logging.error("Error in parameter set[{0}]: f_low={1}\n".format(
            j_id, f_low_text))
        raise

# The pool is None when it could not be created above
if __my_pool__ is not None:
    __my_pool__.close()
    __my_pool__.terminate()
logging.info(" .. MCMC samples written.")
logging.info("All Done in {0} seconds".format(time.time() - __itime__))
