#!/usr/bin/env python

"""
uses the output from scrinet_build_rb

scrinet_build_rb finds a reduced greedy basis.
In this script we read in that basis,
compute the training-set projection coefficients,
and save them to a file so that they can be read
in by scrinet_fit, which will build fits to them.

example

# training data to fit
./scrinet_gen_ts_data --train-or-val train --data-to-model amp -v --basis-method eim --wf-dir train_wf_data --basis-dir rb

# validation data to test fit
./scrinet_gen_ts_data --train-or-val val --data-to-model amp -v --basis-method eim --wf-dir validation_wf_data --basis-dir rb
"""

import os
import sys
import argparse
import numpy as np
import h5py
from scrinet.greedy import greedyrb

from scrinet.workflow.pipe_utils import init_logger, load_data
# , load_greedy_points


def compute_projection_coefficients(
        ts,
        basis_method,
        eim_indices=None,
        grb=None,
        grb_basis=None):
    """
    Compute projection coefficients for an input training set.

    Parameters
    ----------
    ts : array_like, shape (n_waveforms, n_samples)
        Training set of waveform data.
    basis_method : str
        Either 'rb' or 'eim'.
        If 'eim' then `eim_indices` is required.
        If 'rb' then `grb` and `grb_basis` are required.
    eim_indices : array_like of int, optional
        Empirical-interpolation node indices (columns of `ts`).
    grb : GreedyReducedBasis, optional
        Instance providing `compute_projection_coefficients_array`.
    grb_basis : ndarray, optional
        Reduced basis to project `ts` onto.

    Returns
    -------
    alpha : ndarray
        Projection coefficients.

    Raises
    ------
    ValueError
        If `basis_method` is not 'rb' or 'eim'.
    """
    if basis_method == 'eim':
        assert eim_indices is not None, "eim_indices is None"
        # select the EIM node columns of ts
        # (equivalent to np.asarray(ts)[:, eim_indices])
        alpha = np.transpose(ts)[eim_indices].T
    elif basis_method == 'rb':
        assert grb is not None, "grb is None"
        assert grb_basis is not None, "grb_basis is None"
        alpha = grb.compute_projection_coefficients_array(grb_basis, ts)
    else:
        # previously an unrecognised method fell through to an
        # UnboundLocalError on `alpha`; fail with a clear message instead
        raise ValueError(f"basis method: {basis_method} not valid")

    return alpha


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # required=True: this value is used unconditionally to build paths
    # below; previously omitting it crashed later with a TypeError in
    # os.path.join rather than a clean argparse error.
    parser.add_argument("--data-to-model", type=str, required=True,
                        help="What to model. Name of h5 file too with out ext.",
                        choices=['time', 'amp', 'phase', 'freq', 'real', 'imag', 'alpha', 'beta', 'gamma'])
    parser.add_argument("--train-or-val", type=str, required=True,
                        help="sub-directory to save data",
                        choices=['train', 'val'])
    parser.add_argument("--output-dir", type=str, default='ts',
                        help="directory to save data")
    parser.add_argument("--basis-dir", type=str, default='rb',
                        help="root directory of basis data")
    parser.add_argument("--wf-dir", type=str, required=True,
                        help="directory of waveform data")
    parser.add_argument("--basis-method", type=str, required=True,
                        choices=['rb', 'eim'],
                        help="basis method to compute project coefficients")
    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    args = parser.parse_args()

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")
        logger.info(f"basis method: {args.basis_method}")
        logger.info(f"basis dir: {args.basis_dir}")

    # force single-threaded OpenMP so numerical libraries don't oversubscribe
    os.environ['OMP_NUM_THREADS'] = str(1)
    if args.verbose:
        logger.info("OMP_NUM_THREADS: 1")

    # output layout: <output-dir>/<data-to-model>/<train|val>
    output_data_dir = os.path.join(
        args.output_dir, args.data_to_model, args.train_or_val)

    if args.verbose:
        logger.info("making outdir tree")
        logger.info(f"making dir: {output_data_dir}")
    os.makedirs(f"{output_data_dir}", exist_ok=True)

    if args.verbose:
        logger.info(f"data to model: {args.data_to_model}")

    basis_sub_dir = os.path.join(args.basis_dir, args.data_to_model)
    if args.verbose:
        logger.info(f"basis sub directory: {basis_sub_dir}")

    if args.verbose:
        logger.info(f"loading wf data from {args.wf_dir}")
    # wf_x: sample grid; wf_ts: waveform training set;
    # wf_ts_coords: parameter-space coordinates of each waveform
    wf_x, wf_ts, wf_ts_coords = load_data(
        data_to_model=args.data_to_model, dir_name=args.wf_dir)

    # exactly one of the two branches below populates the inputs that
    # compute_projection_coefficients needs for the chosen basis method
    grb = None
    grb_basis = None
    eim_indices = None
    if args.basis_method == 'rb':
        if args.verbose:
            logger.info("inside rb")
        if args.verbose:
            logger.info("setting up integration rule")
        # Riemann rule over the full sample grid of the waveform data
        int_range = [wf_x[0], wf_x[-1]]
        int_num = len(wf_x)
        integration = greedyrb.Riemann(int_range, num=int_num)
        if args.verbose:
            logger.info("Instantiating GreedyReducedBasis instance")
        grb = greedyrb.GreedyReducedBasis(integration=integration)
        grb_basis_fname = f"{args.data_to_model}_greedy_basis.npy"
        grb_basis = np.load(os.path.join(basis_sub_dir, grb_basis_fname))
    elif args.basis_method == 'eim':
        if args.verbose:
            logger.info("inside eim")
        eim_ind_fname = f"{args.data_to_model}_eim_indices.npy"
        eim_indices = np.load(os.path.join(basis_sub_dir, eim_ind_fname))
    else:
        raise ValueError(f"basis method: {args.basis_method} not valid")

    if args.verbose:
        logger.info("computing projection coefficients")
    alpha_ts = compute_projection_coefficients(
        wf_ts, args.basis_method, eim_indices=eim_indices, grb=grb, grb_basis=grb_basis)

    if args.verbose:
        logger.info("logging mass-ratio")

    # the first coordinate (the mass-ratio, per the log message above) is
    # fitted in log-space; copy first so the loaded coordinates are not mutated
    X = wf_ts_coords.copy()
    X[:, 0] = np.log(X[:, 0])
    y = alpha_ts

    if args.verbose:
        logger.info("saving X data")
    np.save(os.path.join(output_data_dir, "X.npy"), X)
    if args.verbose:
        logger.info("saving y data")
    np.save(os.path.join(output_data_dir, "y.npy"), y)