#!/usr/bin/env python

"""

Compute the representation error (match) of a given reduced basis with
respect to a set of training waveforms.

Example usage:

scrinet_representation_error -v --basis-model amp-phase --amp-basis rb/amp/amp_eim_basis.npy --amp-alpha rb/amp/amp_eim_indices.npy --phase-basis rb/phase/phase_eim_basis.npy --phase-alpha rb/phase/phase_eim_indices.npy --wf-dir train_wf_data

scrinet_representation_error -v --basis-model real-imag --real-basis rb/real/real_eim_basis.npy --real-alpha rb/real/real_eim_indices.npy --imag-basis rb/imag/imag_eim_basis.npy --imag-alpha rb/imag/imag_eim_indices.npy --wf-dir train_wf_data

"""

from scrinet.results import results
from scrinet.workflow.pipe_utils import init_logger, load_data, match
import tensorflow as tf
import h5py
import pathlib
import time
import tqdm
import argparse
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
# Use a larger default font size for every figure produced by this script.
matplotlib.rcParams.update({'font.size': 16})


def find_nearest_idx(a, a0):
    """
    Return the index of the entry of array ``a`` that is closest to ``a0``.

    Closeness is measured by absolute difference; when several entries are
    equally close the first one wins (``argmin`` semantics).
    Adapted from https://stackoverflow.com/a/10465997/12840171
    """
    distances = np.abs(a - a0)
    nearest = np.argmin(distances)
    return nearest


def plot_corner(wf_ts_coords, matches, max_match, outname):
    parameters = ['q', 'chi1z', 'chi2z']
    mask = matches <= max_match
    samples = np.recarray(len(wf_ts_coords[mask]), dtype=[
                          (p, float) for p in parameters])
    samples['q'] = wf_ts_coords[mask, 0]
    samples['chi1z'] = wf_ts_coords[mask, 1]
    samples['chi2z'] = wf_ts_coords[mask, 2]

    zvals = matches[mask]

    pars_to_show = ['q', 'chi1z', 'chi2z']
    labels = {'q': r'q', 'chi1z': r'$\chi_{1z}$', 'chi2z': r'$\chi_{1z}$'}
    mins = {'q': 0.99*np.min(samples['q']), 'chi1z': 0.99 * np.min(
        samples['chi1z']), 'chi2z': 0.99 * np.min(samples['chi2z'])}
    maxs = {'q': 1.01*np.max(samples['q']), 'chi1z': 1.01 * np.max(
        samples['chi1z']), 'chi2z': 1.01 * np.max(samples['chi2z'])}

    fig, axes_dict = results.create_multidim_plot(
        pars_to_show,
        samples,
        zvals=zvals,
        show_colorbar=True,
        labels=labels,
        mins=mins,
        maxs=maxs,
        cbar_label='match',
        cb_scale=14)
    fig.suptitle(f'matches <= {max_match:.4f}')
    fig.tight_layout()
    fig.savefig(outname)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--output-dir", type=str, default='rep_error',
                        help="directory to save data")
    parser.add_argument("--wf-dir", type=str, required=True,
                        help="directory of waveform data")
    parser.add_argument("--num-per-chunk", type=int, default=100,
                        help="number of waveforms to load at a time")

    parser.add_argument("--basis-model", type=str,
                        required=True,
                        help="data type of basis",
                        choices=['amp-phase', 'real-imag'])

    parser.add_argument("--amp-basis", type=str,
                        help="path to amp basis")
    parser.add_argument("--amp-alpha", type=str,
                        help="path to amp coefficients")

    parser.add_argument("--phase-basis", type=str,
                        help="path to phase basis")
    parser.add_argument("--phase-alpha", type=str,
                        help="path to phase coefficients")

    parser.add_argument("--real-basis", type=str,
                        help="path to real basis")
    parser.add_argument("--real-alpha", type=str,
                        help="path to real coefficients")
    parser.add_argument("--imag-basis", type=str,
                        help="path to imag basis")
    parser.add_argument("--imag-alpha", type=str,
                        help="path to imag coefficients")

    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    args = parser.parse_args()

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")

        logger.info(f"basis-model: {args.basis_model}")

        if args.basis_model == 'amp-phase':
            logger.info(f"amp basis: {args.amp_basis}")
            logger.info(f"amp alpha: {args.amp_alpha}")
            logger.info(f"phase basis: {args.phase_basis}")
            logger.info(f"phase alpha: {args.phase_alpha}")
        elif args.basis_model == 'real-imag':
            logger.info(f"real basis: {args.real_basis}")
            logger.info(f"real alpha: {args.real_alpha}")
            logger.info(f"imag basis: {args.imag_basis}")
            logger.info(f"imag alpha: {args.imag_alpha}")

    sub_dir = pathlib.PurePath(args.wf_dir).name
    output_data_dir = os.path.join(args.output_dir, sub_dir)

    if args.verbose:
        logger.info("making outdir tree")
        logger.info(f"making dir: {output_data_dir}")
    os.makedirs(f"{output_data_dir}", exist_ok=True)

    # specify number of threads to use for tensorflow
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)
    os.environ['OMP_NUM_THREADS'] = str(1)
    if args.verbose:
        logger.info(
            f"tf using {tf.config.threading.get_inter_op_parallelism_threads()} inter_op_parallelism_threads thread(s)")
        logger.info(
            f"tf using {tf.config.threading.get_intra_op_parallelism_threads()} intra_op_parallelism_threads thread(s)")
        logger.info("OMP_NUM_THREADS: 1")

    if args.basis_model == 'amp-phase':
        if args.verbose:
            logger.info(f"loading amp basis")
        amp_basis = np.load(args.amp_basis)
        amp_alpha = np.load(args.amp_alpha)

        if args.verbose:
            logger.info(f"loading phase basis")
        phase_basis = np.load(args.phase_basis)
        phase_alpha = np.load(args.phase_alpha)

    elif args.basis_model == 'real-imag':
        if args.verbose:
            logger.info(f"loading real basis")
        real_basis = np.load(args.real_basis)
        real_alpha = np.load(args.real_alpha)

        if args.verbose:
            logger.info(f"loading imag basis")
        imag_basis = np.load(args.imag_basis)
        imag_alpha = np.load(args.imag_alpha)

    if args.verbose:
        logger.info(f"loading wf data from {args.wf_dir}")

    all_matches = []
    all_coords = []
    # load coords to get number of times to loop
    with h5py.File(os.path.join(args.wf_dir, 'coords.h5'), 'r') as f:
        num = f['data'].shape[0]

    num_per_chunk = args.num_per_chunk
    if args.verbose:
        logger.info(f"num_per_chunk: {args.num_per_chunk}")
    if args.num_per_chunk >= num:
        num_per_chunk = num  # just in case we have fewer than 100 cases
        if args.verbose:
            logger.info(
                f"args.num_per_chunk ({args.num_per_chunk}) >= num ({num})")
            logger.info("setting num_per_chunk to num")
    num_chunks = int(num / num_per_chunk)
    splits = np.array_split(range(num), num_chunks)
    start_end = []  # contains indicies of the start and end of the chunk
    # ensure we don't miss any points due to indexing and missing one each time
    for i in range(len(splits)):
        if i == 0:
            start_end.append([splits[i][0], splits[i][-1]])
        else:
            start_end.append([splits[i-1][-1], splits[i][-1]])

    for i in range(len(start_end)):
        start_index = start_end[i][0]
        end_index = start_end[i][1]
        if args.verbose:
            logger.info(f"{i} / {num_chunks - 1}")
            logger.info(f"start_index: {start_index}, end_index: {end_index}")

        if args.verbose:
            logger.info("loading data")

        if args.basis_model == 'amp-phase':
            wf_x, wf_ts_amp, wf_ts_coords = load_data(
                data_to_model="amp", dir_name=args.wf_dir, start_index=start_index, end_index=end_index)
            _, wf_ts_phase, _ = load_data(
                data_to_model="phase", dir_name=args.wf_dir, start_index=start_index, end_index=end_index)

            wf_ts_h22 = np.zeros(shape=wf_ts_amp.shape, dtype=np.complex128)
            for i in range(wf_ts_h22.shape[0]):
                vs_amp = wf_ts_amp[i]
                vs_phase = wf_ts_phase[i]
                wf_ts_h22[i] = vs_amp * np.exp(-1.j*vs_phase)
        elif args.basis_model == 'real-imag':
            wf_x, wf_ts_real, wf_ts_coords = load_data(
                data_to_model="real", dir_name=args.wf_dir, start_index=start_index, end_index=end_index)
            _, wf_ts_imag, _ = load_data(
                data_to_model="imag", dir_name=args.wf_dir, start_index=start_index, end_index=end_index)

            wf_ts_h22 = np.zeros(shape=wf_ts_real.shape, dtype=np.complex128)
            for i in range(wf_ts_h22.shape[0]):
                vs_hreal = wf_ts_real[i]
                vs_himag = wf_ts_imag[i]
                wf_ts_h22[i] = vs_hreal - 1.j * vs_himag

        qs = wf_ts_coords[:, 0]
        if wf_ts_coords.shape[1] == 2:
            chis = wf_ts_coords[:, 1]
        # chi1xs = wf_ts_coords[:,1]
        # chi1zs = wf_ts_coords[:,2]
        total_number_wfs = len(qs)

        if args.verbose:
            logger.info(f"total number of waveforms: {total_number_wfs}")
            logger.info("generating reduced basis waveform data")

        t1 = time.time()
        if args.basis_model == 'amp-phase':
            rb_amp = np.dot(amp_alpha[start_index:end_index], amp_basis)
            rb_phase = np.dot(phase_alpha[start_index:end_index], phase_basis)
            rb_hp = np.real(rb_amp * np.exp(-1.j * rb_phase))
        elif args.basis_model == 'real-imag':
            rb_hp = np.dot(real_alpha[start_index:end_index], real_basis)
            # rb_imag = np.dot(imag_alpha[start_index:end_index], imag_basis)

        t2 = time.time()
        dt_sur = t2-t1

        if args.verbose:
            logger.info(f"time taken (rb) = {dt_sur:.5f} s")
            logger.info(
                f"time per waveform (rb) = {dt_sur/total_number_wfs:.5f} s")

        if args.verbose:
            logger.info("computing matches")
        matches = np.zeros(len(qs))
        for i in tqdm.tqdm(range(len(qs))):
            vs_h = wf_ts_h22[i]
            vs_hp = np.real(vs_h)
            # vs_hc = np.imag(vs_h)

            maxmatch = np.max(np.abs(match(vs_hp, rb_hp[i], wf_x)))

            matches[i] = maxmatch

        del wf_ts_h22
        del rb_hp

        all_matches.append(matches)
        all_coords.append(wf_ts_coords)

    matches = np.concatenate((np.array(all_matches)))
    wf_ts_coords = np.row_stack((np.array(all_coords)))

    worst_idx = np.argmin(matches)
    worst_match = matches[worst_idx]

    if args.verbose:
        logger.info(f"worst match: {worst_match:.16f}")
        logger.info(f"best match : {np.max(matches):.16f}")
        logger.info(f"median match : {np.median(matches):.16f}")
        logger.info(f"st. dev match : {np.std(matches):.16f}")

    if args.verbose:
        best_idx = np.argmax(matches)
        best_coords = wf_ts_coords[best_idx]
        logger.info(f"best coords: {best_coords}")

    result_file = os.path.join(output_data_dir, "matches.h5")
    if args.verbose:
        logger.info(f"saving matches: {result_file}")
    with h5py.File(result_file, "w") as f:
        f.create_dataset("coords", data=wf_ts_coords)
        f.create_dataset("data", data=matches)

    if args.verbose:
        logger.info("plotting...")

    mismatches = 1-matches
    ylim = [1e-17, 1]

    plt.figure()
    plt.scatter(range(len(matches)), mismatches)
    plt.yscale('log')
    plt.ylim(*ylim)
    plt.xlabel("case")
    plt.ylabel("mismatch")
    plt.tight_layout()
    plt.savefig(os.path.join(output_data_dir, "1d-matches-vs-model.png"))
    plt.close()

    if wf_ts_coords.shape[1] == 3:
        if args.verbose:
            logger.info("plotting corner1.png")
        max_match = np.percentile(matches, 1)
        plot_corner(wf_ts_coords, matches, max_match=max_match,
                    outname=os.path.join(output_data_dir, "corner1.png"))
        if args.verbose:
            logger.info("plotting corner2.png")
        plot_corner(wf_ts_coords, matches, max_match=1.,
                    outname=os.path.join(output_data_dir, "corner2.png"))

    if wf_ts_coords.shape[1] == 1:

        qs = wf_ts_coords[:, 0]

        plt.figure()
        plt.scatter(qs, mismatches)
        plt.yscale('log')
        plt.ylim(*ylim)
        plt.xlabel("q")
        plt.ylabel("mismatch")
        plt.tight_layout()
        plt.savefig(os.path.join(output_data_dir, "1d-q-matches-vs-model.png"))
        plt.close()

    elif wf_ts_coords.shape[1] == 2:
        qs = wf_ts_coords[:, 0]
        chis = wf_ts_coords[:, 1]

        plt.figure()
        plt.scatter(qs, mismatches)
        plt.yscale('log')
        plt.ylim(*ylim)
        plt.xlabel("q")
        plt.ylabel("mismatch")
        plt.tight_layout()
        plt.savefig(os.path.join(output_data_dir, "1d-q-matches-vs-model.png"))
        plt.close()

        plt.figure()
        plt.scatter(chis, mismatches)
        plt.yscale('log')
        plt.ylim(*ylim)
        plt.xlabel(r"$\chi$")
        plt.ylabel("mismatch")
        plt.tight_layout()
        plt.savefig(os.path.join(output_data_dir,
                                 "1d-chi-matches-vs-model.png"))
        plt.close()

        plt.figure()
        plt.scatter(qs, chis, c=matches)
        plt.colorbar()
        plt.scatter(qs[worst_idx], chis[worst_idx], marker='o', s=100, c='k')
        plt.xlabel("q")
        plt.ylabel(r"$\chi$")
        plt.tight_layout()
        plt.savefig(os.path.join(output_data_dir, "2d-matches-vs-model.png"))
        plt.close()

    if args.verbose:
        logger.info("finished!")
