#!/usr/bin/env python

"""
Generates waveform data that will be fed into scrinet_build_rb.

generates 7d data for precessing single spin systems in co-precessing frame.

parameters

q
spin1_mag, spin1_theta, spin1_phi (that get transformed into chi1x, chi1y, chi1z)
spin2_mag, spin2_theta, spin2_phi (that get transformed into chi2x, chi2y, chi2z)

example:

# generate seed data
scrinet_gen_wf_data_7d_prec_single_spin_coprec --grid regular --npts 2 -v --n-cores 4 --output-dir seed_wf_data

# used to build the basis and in fitting coeffs
scrinet_gen_wf_data_7d_prec_single_spin_coprec --grid random --npts 100 -v --n-cores 4 --output-dir train_wf_data

# used to validate the fits of coeffs
scrinet_gen_wf_data_7d_prec_single_spin_coprec --grid random --npts 100 -v --n-cores 4 --output-dir validation_wf_data

# final test set, not seen before
scrinet_gen_wf_data_7d_prec_single_spin_coprec --grid random --npts 100 -v --n-cores 4 --output-dir test_wf_data

"""
import numpy as np
import time
import sys
import os
import lalsimulation as lalsim
import argparse
import h5py
import pandas

from scrinet.workflow.pipe_utils import init_logger
from scrinet.workflow.generators import gen_7d_prec_single_spin_coprec_data, polar_to_cart, cart_to_polar

if __name__ == "__main__":
    # ---- command-line interface -------------------------------------------
    # Angle bounds are in radians (their defaults are multiples of np.pi).
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--data-to-save", type=str, nargs='+',
                        default=['time', 'amp', 'phase',
                                 'alpha', 'beta', 'gamma'],
                        help="list of data to save. coords and times are "
                             "always saved.",
                        choices=['time', 'amp', 'phase', 'freq', 'real', 'imag', 'alpha', 'beta', 'gamma'])

    parser.add_argument("--output-dir", type=str, default='seed_wf_data',
                        help="directory to save data")

    # Bounds of the 7-D parameter space: mass ratio, then for each spin its
    # magnitude, polar angle theta (0..pi) and azimuthal angle phi (0..2pi).
    parser.add_argument("--q-min", type=float, default=1,
                        help="minimum mass-ratio")
    parser.add_argument("--q-max", type=float, default=8,
                        help="maximum mass-ratio")
    parser.add_argument("--chi1-min", type=float, default=0,
                        help="minimum chi1 magnitude")
    parser.add_argument("--chi1-max", type=float, default=0.99,
                        help="maximum chi1 magnitude")
    parser.add_argument("--theta1-min", type=float, default=0,
                        help="minimum theta1 polar angle")
    parser.add_argument("--theta1-max", type=float, default=np.pi,
                        help="maximum theta1 polar angle")
    parser.add_argument("--phi1-min", type=float, default=0,
                        help="minimum phi1 azimuthal angle")
    parser.add_argument("--phi1-max", type=float, default=2*np.pi,
                        help="maximum phi1 azimuthal angle")

    parser.add_argument("--chi2-min", type=float, default=0,
                        help="minimum chi2 magnitude")
    parser.add_argument("--chi2-max", type=float, default=0.99,
                        help="maximum chi2 magnitude")
    parser.add_argument("--theta2-min", type=float, default=0,
                        help="minimum theta2 polar angle")
    parser.add_argument("--theta2-max", type=float, default=np.pi,
                        help="maximum theta2 polar angle")
    parser.add_argument("--phi2-min", type=float, default=0,
                        help="minimum phi2 azimuthal angle")
    parser.add_argument("--phi2-max", type=float, default=2*np.pi,
                        help="maximum phi2 azimuthal angle")

    # Waveform-generation settings passed through to the generator.
    parser.add_argument("--total-mass", type=float, default=60,
                        help="total mass in Msun")
    parser.add_argument("--sample-rate", type=float, default=1024.,
                        help="deltaT will be 1./sample_rate")
    parser.add_argument("--approx", type=str, default="SEOBNRv4PHM",
                        help="Select waveform model")

    parser.add_argument("--f-min", type=float, default=10,
                        help="Start frequency in Hz for waveform generation")

    parser.add_argument("--t-min", type=float, default=-10000,
                        help="Time in M before peak to use")
    parser.add_argument("--t-max", type=float, default=100,
                        help="Time in M after peak to use")
    parser.add_argument("--t-npts", type=int, default=5000,
                        help="Number of points to use for time grid")

    # How the sample coordinates are chosen.
    parser.add_argument("--random-seed", type=int, default=None,
                        help="Set random numpy seed. Used in --grid random")
    parser.add_argument("--grid", type=str, required=True,
                        choices=['regular', 'random', 'file'],
                        help="choose how to generate data")
    # With --grid regular the grid has npts**7 points before de-duplication.
    parser.add_argument("--npts", type=int, default=10,
                        help="""
                        if --grid regular then --npts is the number of points
                        in each dimension.
                        if --grid random then --npts is the total number of
                        points.""")
    parser.add_argument("--coord-file", type=str,
                        help="""ASCII file with seven columns: mass-ratio,
                        chi1, theta1, phi1, chi2, theta2, phi2 (one row per
                        point).
                        if provided then will generate data at these points.
                        Required if --grid file
                        """)

    parser.add_argument("--n-cores", type=int, default=1,
                        help="number of cores to use in wf generation")
    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    args = parser.parse_args()

    # ---- logging, threading, RNG and output directory -----------------------
    logger = init_logger()

    if args.verbose:
        # Echo the run configuration once verbosity is on.
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")
        logger.info(f"using {args.n_cores} cores")
        logger.info(f"using grid option: {args.grid}")

    # Pin OpenMP to a single thread, presumably so the --n-cores workers do
    # not oversubscribe the CPU. NOTE(review): numpy and lalsimulation are
    # imported before this line, so an OpenMP/BLAS runtime they already
    # initialised may ignore the new value; confirm, and consider setting it
    # before the imports.
    os.environ['OMP_NUM_THREADS'] = '1'
    if args.verbose:
        logger.info("OMP_NUM_THREADS: 1")

    # Reseed numpy's global RNG only when asked; without a seed, --grid
    # random draws a different sample on every run.
    seed = args.random_seed
    if seed is not None:
        np.random.seed(seed)
        if args.verbose:
            logger.info(f"setting random seed: {seed}")

    if args.verbose:
        logger.info(f"Making output directory: {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    # ---- sample coordinates ---------------------------------------------------
    # Every branch defines the same seven 1-D arrays: mass ratio qq and, for
    # each spin, magnitude chi, polar angle theta and azimuthal angle phi.
    if args.grid == 'file':
        # argparse cannot express "required only with --grid file", so check
        # it here. parser.error prints usage and exits; unlike assert, it is
        # not stripped when Python runs with -O.
        if args.coord_file is None:
            parser.error("--grid file specified but not --coord-file")
        if args.verbose:
            logger.info(f"found coord file: {args.coord_file}")
        # Columns: q, chi1, theta1, phi1, chi2, theta2, phi2.
        # ndmin=2 keeps a one-row file two-dimensional, so every unpacked
        # column is a 1-element array rather than a scalar (len(qq) is used
        # below).
        qq, chi1s, theta1s, phi1s, chi2s, theta2s, phi2s = np.loadtxt(
            args.coord_file, unpack=True, ndmin=2)
    elif args.grid == 'regular':
        if args.verbose:
            logger.info("Generating data on regular grid.")
        # One evenly spaced axis per dimension, giving npts**7 points before
        # de-duplication.
        axes_1d = [
            np.linspace(lo, hi, args.npts)
            for lo, hi in (
                (args.q_min, args.q_max),
                (args.chi1_min, args.chi1_max),
                (args.theta1_min, args.theta1_max),
                (args.phi1_min, args.phi1_max),
                (args.chi2_min, args.chi2_max),
                (args.theta2_min, args.theta2_max),
                (args.phi2_min, args.phi2_max),
            )
        ]
        (qq_tmp, chi1s_tmp, theta1s_tmp, phi1s_tmp,
         chi2s_tmp, theta2s_tmp, phi2s_tmp) = [
            axis.ravel() for axis in np.meshgrid(*axes_1d)]

        if args.verbose:
            logger.info(
                "THIS IS A VERY ROUND ABOUT WAY OF MAKING A REGULAR SPHERICAL GRID WITHOUT DUPLICATES AT chi=0. PLEASE FIX ME IN FUTURE")
            logger.info("remove duplicate entries")
            logger.info("convert to cartesian")
        # Different (theta, phi) pairs give the same spin vector when chi == 0
        # (and presumably also at theta = 0 or pi). Map to Cartesian
        # components and drop repeated rows. Values are rounded to 5 decimals
        # first so floating-point noise does not hide duplicates. The
        # surviving q and spin components are therefore the rounded values.
        chi1x_array, chi1y_array, chi1z_array = polar_to_cart(
            chi1s_tmp, theta1s_tmp, phi1s_tmp)
        chi2x_array, chi2y_array, chi2z_array = polar_to_cart(
            chi2s_tmp, theta2s_tmp, phi2s_tmp)
        params = np.column_stack(
            (qq_tmp, chi1x_array, chi1y_array, chi1z_array,
             chi2x_array, chi2y_array, chi2z_array))
        df = pandas.DataFrame(data=params, columns=[
            'q', 'chi1x', 'chi1y', 'chi1z', 'chi2x', 'chi2y', 'chi2z'])
        qq, chi1x, chi1y, chi1z, chi2x, chi2y, chi2z = \
            df.round(5).drop_duplicates().to_numpy().T
        if args.verbose:
            logger.info("now convert back to spherical polar")
        chi1s, theta1s, phi1s = cart_to_polar(chi1x, chi1y, chi1z)
        chi2s, theta2s, phi2s = cart_to_polar(chi2x, chi2y, chi2z)

    elif args.grid == 'random':
        if args.verbose:
            logger.info("Generating data randomly with uniform distribution.")
        # Each parameter is drawn independently and uniformly. Keep this draw
        # order so a given --random-seed reproduces the same points. Note that
        # uniform theta is not uniform over directions on the sphere.
        qq = np.random.uniform(args.q_min, args.q_max, args.npts)
        chi1s = np.random.uniform(args.chi1_min, args.chi1_max, args.npts)
        theta1s = np.random.uniform(
            args.theta1_min, args.theta1_max, args.npts)
        phi1s = np.random.uniform(args.phi1_min, args.phi1_max, args.npts)
        chi2s = np.random.uniform(args.chi2_min, args.chi2_max, args.npts)
        theta2s = np.random.uniform(
            args.theta2_min, args.theta2_max, args.npts)
        phi2s = np.random.uniform(args.phi2_min, args.phi2_max, args.npts)

    if args.verbose:
        logger.info(f"total number of waveforms to generate: {len(qq)}")

    # ---- waveform generation --------------------------------------------------
    # Sample spacing handed to the generator.
    deltaT = 1./args.sample_rate
    # Translate the approximant name (e.g. "SEOBNRv4PHM") into LAL's
    # approximant identifier.
    lal_approx = lalsim.SimInspiralGetApproximantFromString(args.approx)

    if args.verbose:
        logger.info(f"data to save: {args.data_to_save}")
        logger.info("Generating waveforms")

    # The generator returns ten arrays. They are unpacked in this fixed order
    # and written to disk further down.
    wall_start = time.time()
    (ts_x, ts_amp, ts_phase, ts_freq, ts_hreal, ts_himag,
     ts_alpha, ts_beta, ts_gamma, ts_coords) = gen_7d_prec_single_spin_coprec_data(
        q_array=qq,
        chi1_array=chi1s,
        theta1_array=theta1s,
        phi1_array=phi1s,
        chi2_array=chi2s,
        theta2_array=theta2s,
        phi2_array=phi2s,
        M=args.total_mass,
        n_cores=args.n_cores,
        deltaT=deltaT,
        f_min=args.f_min,
        approximant=lal_approx,
        t_min=args.t_min,
        t_max=args.t_max,
        npts=args.t_npts,
    )
    dur = time.time() - wall_start

    if args.verbose:
        logger.info(f"WF generation took: {dur:.3f} s")
        logger.info("Finished generating waveforms")

    # ---- save results ---------------------------------------------------------
    # Each quantity goes to its own HDF5 file inside args.output_dir.
    if args.verbose:
        logger.info(f"Saving data to {args.output_dir}")

    # Entries are (file stem, dataset name inside the file, array), and each
    # is written to f"{stem}.h5".
    # - coords and times are always written, so the 'time' choice of
    #   --data-to-save has no additional effect.
    # - times.h5 stores its array under "times"; every other file uses "data".
    always_saved = [
        ('coords', 'data', ts_coords),
        ('times', 'times', ts_x),
    ]
    optionally_saved = [
        ('amp', 'data', ts_amp),
        ('phase', 'data', ts_phase),
        ('freq', 'data', ts_freq),
        ('real', 'data', ts_hreal),
        ('imag', 'data', ts_himag),
        ('alpha', 'data', ts_alpha),
        ('beta', 'data', ts_beta),
        ('gamma', 'data', ts_gamma),
    ]
    to_save = always_saved + [
        entry for entry in optionally_saved if entry[0] in args.data_to_save]

    for stem, dset_name, array in to_save:
        if args.verbose:
            logger.info(f"saving {stem}")
        filename = os.path.join(args.output_dir, f"{stem}.h5")
        with h5py.File(filename, "w") as f:
            f.create_dataset(dset_name, data=array)