#!/usr/bin/env python

"""
Generates waveform data that will be fed into scrinet_build_rb.

Generates 3D data for precessing single-spin systems.

Parameters:

q, chi1 (spin-1 magnitude) and theta1 (spin-1 polar angle); chi1 and theta1
get transformed into chi1x and chi1z.

example:

# generate seed data
scrinet_gen_wf_data_3d_prec_single_spin --grid regular --npts 5 -v --n-cores 4 --output-dir seed_wf_data

# used to build the basis and in fitting coeffs
scrinet_gen_wf_data_3d_prec_single_spin --grid random --npts 100 -v --n-cores 4 --output-dir train_wf_data

# used to validate the fits of coeffs
scrinet_gen_wf_data_3d_prec_single_spin --grid random --npts 100 -v --n-cores 4 --output-dir validation_wf_data

# final test set, not seen before
scrinet_gen_wf_data_3d_prec_single_spin --grid random --npts 100 -v --n-cores 4 --output-dir test_wf_data

"""
import numpy as np
import time
import sys
import os
import lalsimulation as lalsim
import argparse
import h5py
import pandas

from scrinet.workflow.pipe_utils import init_logger
from scrinet.workflow.generators import gen_3d_prec_single_spin_data, polar_to_cart, cart_to_polar

def build_parser():
    """Build the command-line parser for this script.

    Returns
    -------
    argparse.ArgumentParser
        Parser for all options; defaults are shown in ``--help``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--data-to-save", type=str, nargs='+',
                        default=['time', 'amp', 'phase'],
                        help="list of data to save. Times and coords are "
                             "always saved.",
                        choices=['time', 'amp', 'phase', 'freq', 'real', 'imag'])

    parser.add_argument("--output-dir", type=str, default='seed_wf_data',
                        help="directory to save data")

    parser.add_argument("--q-min", type=float, default=1,
                        help="minimum mass-ratio")
    parser.add_argument("--q-max", type=float, default=8,
                        help="maximum mass-ratio")
    parser.add_argument("--chi1-min", type=float, default=0,
                        help="minimum chi1 magnitude")
    parser.add_argument("--chi1-max", type=float, default=0.99,
                        help="maximum chi1 magnitude")
    parser.add_argument("--theta1-min", type=float, default=0,
                        help="minimum theta1 polar angle")
    parser.add_argument("--theta1-max", type=float, default=np.pi,
                        help="maximum theta1 polar angle")
    parser.add_argument("--total-mass", type=float, default=60,
                        help="total mass in Msun")
    parser.add_argument("--sample-rate", type=float, default=1024.,
                        help="deltaT will be 1./sample_rate")
    parser.add_argument("--approx", type=str, default="SEOBNRv4P",
                        help="Select waveform model")

    parser.add_argument("--f-min", type=float, default=10,
                        help="Start frequency in Hz for waveform generation")

    parser.add_argument("--t-min", type=float, default=-10000,
                        help="Time in M before peak to use")
    parser.add_argument("--t-max", type=float, default=100,
                        help="Time in M after peak to use")
    parser.add_argument("--t-npts", type=int, default=5000,
                        help="Number of points to use for time grid")

    parser.add_argument("--random-seed", type=int, default=None,
                        help="Set random numpy seed. Used in --grid random")
    parser.add_argument("--grid", type=str, required=True,
                        choices=['regular', 'random', 'file'],
                        help="choose how to generate data")
    parser.add_argument("--npts", type=int, default=10,
                        help="""
                        if --grid regular then --npts is the number of points
                        in each dimension.
                        if --grid random then --npts is the total number of
                        points.""")
    parser.add_argument("--coord-file", type=str,
                        help="ASCII file of three columns: mass-ratio, chi1 "
                             "magnitude and theta1 polar angle. "
                             "If provided then will generate data at these "
                             "points. Required if --grid file.")

    parser.add_argument("--n-cores", type=int, default=1,
                        help="number of cores to use in wf generation")
    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")
    return parser


def load_coord_file(coord_file):
    """Load sample points from a whitespace-delimited ASCII file.

    Parameters
    ----------
    coord_file : str
        Path to a file with three columns: mass-ratio, chi1 magnitude and
        theta1 polar angle (one sample point per row).

    Returns
    -------
    tuple of numpy.ndarray
        ``(qq, chi1s, theta1s)``, each one-dimensional with one entry per row.

    Raises
    ------
    ValueError
        If the file does not contain exactly three columns.
    """
    # ndmin=2 keeps a single-row file two-dimensional, so every column comes
    # back as a length-1 array instead of a scalar (a scalar would break the
    # later ``len(qq)`` and the per-point waveform generation).
    coords = np.loadtxt(coord_file, ndmin=2)
    if coords.shape[1] != 3:
        raise ValueError(
            f"{coord_file}: expected 3 columns (q, chi1, theta1), "
            f"found {coords.shape[1]}")
    qq, chi1s, theta1s = coords.T
    return qq, chi1s, theta1s


def _make_regular_grid(args, logger):
    """Return ``(qq, chi1s, theta1s)`` on a regular grid without duplicates.

    An ``npts**3`` grid in (q, chi1, theta1) is converted to Cartesian spin
    components (azimuth fixed at 0). Rows that coincide after rounding to 5
    decimals are dropped; for example, every theta1 value at chi1=0 maps to
    the same point. The remaining rows are converted back to polar
    coordinates. Returned values are therefore rounded to 5 decimals.
    """
    if args.verbose:
        logger.info("Generating data on regular grid.")
    q_1d = np.linspace(args.q_min, args.q_max, args.npts)
    chi1_1d = np.linspace(args.chi1_min, args.chi1_max, args.npts)
    theta1_1d = np.linspace(args.theta1_min, args.theta1_max, args.npts)
    qq_tmp, chi1s_tmp, theta1s_tmp = np.meshgrid(q_1d, chi1_1d, theta1_1d)
    qq_tmp = qq_tmp.ravel()
    chi1s_tmp = chi1s_tmp.ravel()
    theta1s_tmp = theta1s_tmp.ravel()

    # TODO: build the de-duplicated spherical grid directly instead of
    # round-tripping through Cartesian coordinates.
    if args.verbose:
        logger.info("THIS IS A VERY ROUND ABOUT WAY OF MAKING A REGULAR SPHERICAL GIRD WITHOUT DUPLICATES AT chi=0. PLEASE FIX ME IN FUTURE")
        logger.info("remove duplicate entries")
        logger.info("convert to cartesian")
    phi1_array = np.zeros(len(chi1s_tmp))
    chi1x_array, chi1y_array, chi1z_array = polar_to_cart(chi1s_tmp, theta1s_tmp, phi1_array)
    params = np.column_stack((qq_tmp, chi1x_array, chi1y_array, chi1z_array))
    df = pandas.DataFrame(data=params, columns=['q', 'chi1x', 'chi1y', 'chi1z'])
    qq, chi1x, chi1y, chi1z = df.round(5).drop_duplicates().to_numpy().T
    if args.verbose:
        logger.info("now convert back to spherical polar")
    # The azimuth is not needed downstream; it was fixed to 0 above.
    chi1s, theta1s, _ = cart_to_polar(chi1x, chi1y, chi1z)
    return qq, chi1s, theta1s


def make_grid(args, logger):
    """Return the ``(qq, chi1s, theta1s)`` sample points chosen by ``--grid``.

    ``file`` reads ``--coord-file``; ``regular`` builds a de-duplicated
    regular grid; ``random`` draws ``--npts`` points uniformly in q, chi1 and
    theta1 (note: uniform in theta1, not uniform on the sphere).
    """
    if args.grid == 'file':
        if args.verbose:
            logger.info(f"found coord file: {args.coord_file}")
        return load_coord_file(args.coord_file)
    if args.grid == 'regular':
        return _make_regular_grid(args, logger)
    if args.grid == 'random':
        if args.verbose:
            logger.info("Generating data randomly with uniform distribution.")
        # Draw order (q, chi1, theta1) is kept fixed so --random-seed
        # reproduces the same points.
        qq = np.random.uniform(args.q_min, args.q_max, args.npts)
        chi1s = np.random.uniform(args.chi1_min, args.chi1_max, args.npts)
        theta1s = np.random.uniform(args.theta1_min, args.theta1_max, args.npts)
        return qq, chi1s, theta1s
    raise ValueError(f"unknown grid option: {args.grid}")


def _write_h5(filename, dataset_name, data):
    """Write ``data`` as a single named dataset into a new HDF5 file."""
    with h5py.File(filename, "w") as f:
        f.create_dataset(dataset_name, data=data)


def save_outputs(output_dir, data_to_save, ts_coords, ts_x, series, logger,
                 verbose=False):
    """Write coords, times and the selected waveform series to HDF5 files.

    Parameters
    ----------
    output_dir : str
        Directory the ``*.h5`` files are written into (must exist).
    data_to_save : list of str
        Names from ``series`` to write; other names are skipped.
    ts_coords, ts_x
        Coordinates and time samples; always written, to ``coords.h5``
        (dataset ``"data"``) and ``times.h5`` (dataset ``"times"``).
    series : dict
        Maps an output name (e.g. ``'amp'``) to its data; each selected entry
        is written to ``<name>.h5`` under dataset ``"data"``.
    logger
        Logger used for progress messages when ``verbose`` is true.
    verbose : bool
        Whether to log progress.
    """
    if verbose:
        logger.info(f"Saving data to {output_dir}")

    if verbose:
        logger.info("saving coords")
    _write_h5(os.path.join(output_dir, 'coords.h5'), "data", ts_coords)

    if verbose:
        logger.info("saving times")
    # NOTE: the times file uses dataset name "times", unlike the others.
    _write_h5(os.path.join(output_dir, 'times.h5'), "times", ts_x)

    for name, values in series.items():
        if name not in data_to_save:
            continue
        if verbose:
            logger.info(f"saving {name}")
        _write_h5(os.path.join(output_dir, f"{name}.h5"), "data", values)


def main():
    """Parse arguments, build the parameter grid, then generate and save waveforms."""
    parser = build_parser()
    args = parser.parse_args()

    # Validate before touching the filesystem. parser.error exits with a
    # usage message; unlike ``assert`` it is not stripped by ``python -O``.
    if args.grid == 'file' and args.coord_file is None:
        parser.error("--grid file specified but not --coord-file")

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")
        logger.info(f"using {args.n_cores} cores")
        logger.info(f"using grid option: {args.grid}")

    if args.random_seed is not None:
        np.random.seed(args.random_seed)
        if args.verbose:
            logger.info(f"setting random seed: {args.random_seed}")

    if args.verbose:
        logger.info(f"Making output directory: {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    qq, chi1s, theta1s = make_grid(args, logger)

    if args.verbose:
        logger.info(f"total number of waveforms to generate: {len(qq)}")

    deltaT = 1. / args.sample_rate
    lal_approx = lalsim.SimInspiralGetApproximantFromString(args.approx)

    if args.verbose:
        logger.info(f"data to save: {args.data_to_save}")
        logger.info("Generating waveforms")

    t1 = time.time()
    ts_x, ts_amp, ts_phase, ts_freq, ts_hreal, ts_himag, ts_coords = gen_3d_prec_single_spin_data(
        q_array=qq,
        chi1_array=chi1s,
        theta1_array=theta1s,
        M=args.total_mass,
        n_cores=args.n_cores,
        deltaT=deltaT,
        f_min=args.f_min,
        approximant=lal_approx,
        t_min=args.t_min,
        t_max=args.t_max,
        npts=args.t_npts
    )
    dur = time.time() - t1
    if args.verbose:
        logger.info(f"WF generation took: {dur:.3f} s")
        logger.info("Finished generating waveforms")

    # Insertion order fixes the order in which files are written and logged.
    series = {
        'amp': ts_amp,
        'phase': ts_phase,
        'freq': ts_freq,
        'real': ts_hreal,
        'imag': ts_himag,
    }
    save_outputs(args.output_dir, args.data_to_save, ts_coords, ts_x, series,
                 logger, verbose=args.verbose)


if __name__ == "__main__":
    main()
