#!/usr/bin/env python

"""
Read the data generated by scrinet_gen_wf_data and build a greedy
reduced basis and an empirical interpolation (EIM) basis from it.

example

./scrinet_build_rb --data-to-model amp -v --greedy-tol 1e-6

to use the rompy reduced basis code add the '--use-rompy' flag

./scrinet_build_rb --data-to-model amp -v --greedy-tol 1e-6 --use-rompy

"""
import rompy as rp
from scrinet.workflow.pipe_utils import init_logger, load_data
from scrinet.greedy import greedyrb
import h5py
import numpy as np
import argparse
import sys
import os
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
matplotlib.rcParams.update({'font.size': 16})


def plot_greedy_error(greedy_errors, output_dir, output_name):
    """Plot the greedy errors on a log-scaled y axis and save the figure.

    The errors are drawn against their position in ``greedy_errors``
    (0, 1, 2, ...). The image is written to ``output_dir/output_name`` and
    the figure is closed afterwards.
    """
    fig, ax = plt.subplots()

    point_index = range(len(greedy_errors))
    ax.plot(point_index, greedy_errors)

    ax.set_yscale('log')
    ax.set_xlabel("greedy point")
    ax.set_ylabel("greedy error")

    fig.tight_layout()
    target = os.path.join(output_dir, output_name)
    fig.savefig(target)

    # release the figure so it does not stay registered with pyplot
    plt.close(fig)


def _no_log(*_args, **_kwargs):
    """Discard a log message (used when verbose output is switched off)."""


def _parse_args(argv=None):
    """Parse the command-line options of this script.

    ``argv`` is the list of arguments to parse; ``None`` (the default) makes
    argparse read ``sys.argv[1:]``. Returns the parsed ``argparse.Namespace``.
    Invalid or missing options make argparse print a usage message and exit.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Required: the value becomes a component of the output path and the
    # prefix of the saved file names, so it cannot be left unset.
    parser.add_argument("--data-to-model", type=str, required=True,
                        help="What to model. Name of h5 file too with out ext.",
                        choices=['amp', 'phase', 'freq', 'real', 'imag', 'alpha', 'beta', 'gamma'])

    parser.add_argument("--output-dir", type=str, default='rb',
                        help="directory to save data")

    parser.add_argument("--seed-dir", type=str, default='seed_wf_data',
                        help="directory of seed waveform data")
    parser.add_argument("--train-dir", type=str, default='train_wf_data',
                        help="directory of train waveform data")

    parser.add_argument("--greedy-tol", type=float, default=1e-6,
                        help="greedy basis error tolerance")

    parser.add_argument("--use-rompy", help="use rompy reduced basis code",
                        action="store_true")
    parser.add_argument("--max-num-basis", type=int, default=None,
                        help="maximum number of basis vectors. Only compatible with --use-rompy")
    parser.add_argument("--rel", action="store_true",
                        help="use relative errors. Only compatible with --use-rompy")

    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    return parser.parse_args(argv)


def _build_rompy_basis(args, log):
    """Build the reduced basis from the training data with the rompy backend.

    ``log`` is a callable taking one message string. Returns the tuple
    ``(rb, greedy_errors, nbasis, greedy_points, basis)`` where ``rb`` is the
    ``rp.ReducedBasis`` instance the other values were read from.
    """
    log(f"loading training data from {args.train_dir}")
    ts_x, ts, ts_coords = load_data(
        data_to_model=args.data_to_model, dir_name=args.train_dir)

    integration = rp.Integration(
        [ts_x[0], ts_x[-1]], len(ts_x), rule="trapezoidal")
    rb = rp.ReducedBasis(integration)

    # without an explicit limit, allow as many basis vectors as ts has rows
    max_num_basis = args.max_num_basis
    if max_num_basis is None:
        max_num_basis = ts.shape[0]
    log(f"max-num-basis: {max_num_basis}")

    log("running rb.make")
    rb.make(ts, 0, args.greedy_tol,
            verbose=args.verbose, num=max_num_basis,
            rel=args.rel)
    log("rb.make complete")

    return rb, rb.errors, rb.size, ts_coords[rb.indices], rb.basis


def _build_greedy_basis(args, log):
    """Build the reduced basis with the scrinet greedy backend.

    A seed basis is built from the seed data first; the training data is then
    swept until the greedy tolerance is reached. ``log`` is a callable taking
    one message string. Returns the tuple
    ``(grb, greedy_errors, nbasis, greedy_points, basis)`` where ``grb`` is
    the ``greedyrb.GreedyReducedBasis`` instance the other values were read
    from.

    Raises ``AssertionError`` (from ``np.testing``) if the seed and training
    x grids differ.
    """
    log(f"loading seed data from {args.seed_dir}")
    seed_x, seed_ts, seed_ts_coords = load_data(
        data_to_model=args.data_to_model, dir_name=args.seed_dir)

    log("Making integration rule")
    integration = greedyrb.Riemann([seed_x[0], seed_x[-1]], num=len(seed_x))

    log("Instantiating GreedyReducesBasis instance")
    grb = greedyrb.GreedyReducedBasis(integration=integration)

    log("building seed basis")
    grb.build_seed_basis(ts=seed_ts, ts_coords=seed_ts_coords)

    # now that we have the seed basis we need to build the greedy reduced basis
    log(f"loading training data from {args.train_dir}")
    ts_x, ts, ts_coords = load_data(
        data_to_model=args.data_to_model, dir_name=args.train_dir)

    # the seed and training data must be sampled on the same x grid
    log("checking x grids are the same")
    np.testing.assert_array_equal(seed_x, ts_x)

    log(f"greedy tol: {args.greedy_tol}")
    log("running greedy sweep")
    # using the training set we add points until we reach the greedy_tol
    grb.greedy_sweep(ts, ts_coords, verbose=args.verbose,
                     greedy_tol=args.greedy_tol)
    log("greedy sweep completed")

    return grb, grb.greedy_errors, grb.nbasis, grb.greedy_points, grb.basis


def _build_eim(args, model):
    """Return ``(eim_basis, eim_indices)`` for the reduced basis in ``model``.

    ``model`` is the object returned first by the backend builder selected by
    ``args.use_rompy``.
    """
    if args.use_rompy:
        eim = rp.EmpiricalInterpolant(model.basis, verbose=args.verbose)
        return eim.B, eim.indices

    model.setup_eim()
    return model.eim.B, model.eim.indices


def _save_h5(filename, data):
    """Write ``data`` to a new h5 file as the dataset named "data"."""
    with h5py.File(filename, "w") as f:
        f.create_dataset("data", data=data)


def main(argv=None):
    """Build and save the greedy reduced basis and the EIM basis.

    ``argv`` is the list of command-line arguments; ``None`` (the default)
    uses ``sys.argv[1:]``. All results are written below
    ``<output-dir>/<data-to-model>``: the greedy errors (h5 and png), the
    greedy points (h5), the greedy basis, the EIM basis and the EIM indices
    (npy).
    """
    args = _parse_args(argv)

    logger = init_logger()
    # progress messages are only emitted when -v/--verbose was given
    log = logger.info if args.verbose else _no_log

    log(f"current file: {__file__}")
    log("verbosity turned on")

    # NOTE(review): this is set after numpy/rompy have already been imported
    # at the top of the file, which may be too late to limit the threads of an
    # already initialised OpenMP runtime -- confirm.
    os.environ['OMP_NUM_THREADS'] = str(1)
    log("OMP_NUM_THREADS: 1")

    if args.use_rompy:
        log("using rompy backend")

    output_data_dir = os.path.join(args.output_dir, args.data_to_model)

    log(f"making dir: {output_data_dir}")
    os.makedirs(output_data_dir, exist_ok=True)

    log(f"output root dir: {args.output_dir}")
    log(f"output data subdir dir: {output_data_dir}")
    log(f"data to model: {args.data_to_model}")

    if args.use_rompy:
        build_basis = _build_rompy_basis
    else:
        build_basis = _build_greedy_basis
    model, greedy_errors, nbasis, greedy_points, basis = build_basis(args, log)

    filename = os.path.join(output_data_dir, 'greedy_errors.h5')
    log(f"Saving greedy errors: {filename}")
    _save_h5(filename, greedy_errors)

    log("Plotting greedy errors")
    plot_greedy_error(greedy_errors, output_data_dir, "greedy_errors.png")

    log(f"Number of greedy points: {nbasis}")

    filename = os.path.join(output_data_dir, 'greedy_points.h5')
    log(f"saving greedy points: {filename}")
    _save_h5(filename, greedy_points)

    filename = args.data_to_model + "_greedy_basis.npy"
    log(f"saving greedy basis: {filename}")
    np.save(os.path.join(output_data_dir, filename), basis)

    # the EIM is built only after the greedy results have been written
    log("building EIM basis")
    eim_basis, eim_indices = _build_eim(args, model)

    filename = args.data_to_model + "_eim_basis.npy"
    log(f"saving eim basis: {filename}")
    np.save(os.path.join(output_data_dir, filename), eim_basis)

    filename = args.data_to_model + "_eim_indices.npy"
    log(f"saving eim indices: {filename}")
    np.save(os.path.join(output_data_dir, filename), eim_indices)

    log("finished")


if __name__ == "__main__":
    main()
