#!/usr/bin/env python

"""
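Load models fitted with scrinet_fit and evaluate them against waveform data:
compute the match of every waveform in --wf-dir against the surrogate, save
the matches and write diagnostic plots below --output-dir.

Examples: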

scrinet_evaluate_model -v --basis-model amp-phase --amp-basis rb/amp/amp_eim_basis.npy  --amp-model-dir ts/amp/fits --phase-basis rb/phase/phase_eim_basis.npy  --phase-model-dir ts/phase/fits --wf-dir train_wf_data


scrinet_evaluate_model -v --basis-model real-imag --real-basis rb/real/real_eim_basis.npy  --real-model-dir ts/real/fits --imag-basis rb/imag/imag_eim_basis.npy  --imag-model-dir ts/imag/fits --wf-dir train_wf_data
"""

from scrinet.results import results
from scrinet.workflow.pipe_utils import init_logger, load_data, load_model, wave_sur_many, match, real_imag_wave_sur_many
import tensorflow as tf
import h5py
import pathlib
import time
import tqdm
import argparse
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams.update({'font.size': 16})


def find_nearest_idx(a, a0):
    """
    Return the index of the entry of array 'a' that lies closest to 'a0'.

    The absolute distance |a - a0| is minimised; when several entries are
    equally close the first one wins (argmin semantics).
    https://stackoverflow.com/a/10465997/12840171
    """
    distance = np.absolute(a - a0)
    return distance.argmin()


def plot_waveform(times, h1, h2, filename, zoom_xlim=(-500, 100)):
    """
    Plot a reference waveform against its surrogate and save the figure.

    Two panels sharing the y axis are drawn: the left one shows the full
    time range, the right one shows the same curves restricted to
    ``zoom_xlim``.

    Args:
        times: x values shared by both curves.
        h1: reference curve, drawn solid and labelled 'data'.
        h2: surrogate curve, drawn dashed and labelled 'sur'.
        filename: path the figure is written to.
        zoom_xlim: (xmin, xmax) window of the right-hand panel.
    """
    fig, (ax_full, ax_zoom) = plt.subplots(
        1, 2, sharey=True, figsize=(14, 4))

    for ax in (ax_full, ax_zoom):
        ax.plot(times, h1, label='data')
        ax.plot(times, h2, label='sur', ls='--')

    ax_zoom.set_xlim(*zoom_xlim)

    ax_full.legend()
    fig.savefig(filename)
    # close this specific figure rather than whichever one is "current"
    plt.close(fig)


def plot_corner(wf_ts_coords, matches, max_match, outname):
    """
    Save a multi-dimensional (corner style) plot of the matches.

    Only the cases with ``matches <= max_match`` are shown. Columns 0, 1
    and 2 of ``wf_ts_coords`` are used as 'q', 'chi1z' and 'chi2z'
    respectively and the points are coloured by their match.

    Args:
        wf_ts_coords: 2D array with (at least) three columns.
        matches: 1D array with one match per row of ``wf_ts_coords``.
        max_match: only cases with a match less than or equal to this
            value are plotted.
        outname: path the figure is written to.
    """
    parameters = ['q', 'chi1z', 'chi2z']
    mask = matches <= max_match
    selected = wf_ts_coords[mask]

    samples = np.recarray(len(selected), dtype=[
                          (p, float) for p in parameters])
    for column, name in enumerate(parameters):
        samples[name] = selected[:, column]

    zvals = matches[mask]

    labels = {'q': r'q', 'chi1z': r'$\chi_{1z}$', 'chi2z': r'$\chi_{2z}$'}

    # Pad the axis range by 1% of the extreme value on each side. Using
    # abs() keeps the padding pointing outwards for negative values too
    # (0.99 * min would move a negative lower limit inside the data range).
    mins = {}
    maxs = {}
    for name in parameters:
        lo = np.min(samples[name])
        hi = np.max(samples[name])
        mins[name] = lo - 0.01 * abs(lo)
        maxs[name] = hi + 0.01 * abs(hi)

    fig, _ = results.create_multidim_plot(
        list(parameters),
        samples,
        zvals=zvals,
        show_colorbar=True,
        labels=labels,
        mins=mins,
        maxs=maxs,
        cbar_label='match',
        cb_scale=14)
    fig.suptitle(f'matches <= {max_match:.4f}')
    fig.tight_layout()
    fig.savefig(outname)
    # release the figure once it is on disk
    plt.close(fig)


def get_strain_from_ts_single_case(data_name, wf_dir, index):
    """Load one waveform from disk and return it as a complex series.

    The two components of the chosen decomposition are read with
    ``load_data`` for the window ``[index, index + 1)`` and combined as
    ``amp * exp(-1j * phase)`` or ``real - 1j * imag``.

    Args:
        data_name (str): either "amp-phase" or "real-imag"
        wf_dir (str): directory of waveform data, passed to ``load_data``
        index (int): index of the waveform to load

    Returns:
        the complex waveform of the requested case (first entry of the
        loaded window).

    Raises:
        ValueError: if ``data_name`` is not one of the supported values.
    """
    # explicit check instead of `assert`, which is skipped under `python -O`
    if data_name not in ("amp-phase", "real-imag"):
        raise ValueError(
            f"data_name must be 'amp-phase' or 'real-imag', got {data_name!r}")

    window = dict(dir_name=wf_dir, start_index=index, end_index=index + 1)

    if data_name == "amp-phase":
        _, wf_ts_amp, _ = load_data(data_to_model="amp", **window)
        _, wf_ts_phase, _ = load_data(data_to_model="phase", **window)
        return wf_ts_amp[0] * np.exp(-1.j * wf_ts_phase[0])

    _, wf_ts_real, _ = load_data(data_to_model="real", **window)
    _, wf_ts_imag, _ = load_data(data_to_model="imag", **window)
    return wf_ts_real[0] - 1.j * wf_ts_imag[0]


def get_strain_from_model_single_case(data_name, wf_ts_coords, amp_model, amp_basis, phase_model, phase_basis):
    """Evaluate the surrogate for the given coordinates and return the first waveform.

    Args:
        data_name (str): either "amp-phase" or "real-imag"
        wf_ts_coords: coordinates handed to the surrogate evaluation
            routine; only the first resulting waveform is returned.
        amp_model: model of the first component (amplitude for
            "amp-phase", real part for "real-imag").
        amp_basis: basis of the first component.
        phase_model: model of the second component (phase for
            "amp-phase", imaginary part for "real-imag").
        phase_basis: basis of the second component.

    Returns:
        the first entry of the first array returned by ``wave_sur_many`` /
        ``real_imag_wave_sur_many`` (named ``sur_hp`` in this file).

    Raises:
        ValueError: if ``data_name`` is not one of the supported values.
    """
    if data_name == 'amp-phase':
        sur_hp, _, _, _ = wave_sur_many(
            wf_ts_coords, amp_model, amp_basis, phase_model, phase_basis)
    elif data_name == 'real-imag':
        # use the models passed in by the caller rather than relying on
        # module-level variables that only exist when run as a script
        sur_hp, _ = real_imag_wave_sur_many(
            wf_ts_coords, amp_model, amp_basis, phase_model, phase_basis)
    else:
        raise ValueError(
            f"data_name must be 'amp-phase' or 'real-imag', got {data_name!r}")

    return sur_hp[0]


if __name__ == "__main__":
    # Script entry point:
    #   1. parse and validate the command line
    #   2. load the two component models of the chosen decomposition
    #   3. loop over the waveform data in chunks, evaluate the surrogate
    #      and compute the match of every waveform
    #   4. save the matches and write diagnostic plots
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--output-dir", type=str, default='evaluate',
                        help="directory to save data")
    parser.add_argument("--wf-dir", type=str, required=True,
                        help="directory of waveform data")
    parser.add_argument("--num-per-chunk", type=int, default=100,
                        help="number of waveforms to load at a time")

    parser.add_argument("--basis-model", type=str,
                        required=True,
                        help="data type of basis",
                        choices=['amp-phase', 'real-imag'])

    parser.add_argument("--amp-basis", type=str,
                        help="path to amp basis")
    parser.add_argument("--amp-model-dir", type=str,
                        help="dir of amp model weights and scalers")
    parser.add_argument("--phase-basis", type=str,
                        help="path to phase basis")
    parser.add_argument("--phase-model-dir", type=str,
                        help="dir of phase model weights and scalers")

    parser.add_argument("--real-basis", type=str,
                        help="path to real basis")
    parser.add_argument("--real-model-dir", type=str,
                        help="dir of real model weights and scalers")
    parser.add_argument("--imag-basis", type=str,
                        help="path to imag basis")
    parser.add_argument("--imag-model-dir", type=str,
                        help="dir of imag model weights and scalers")

    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    args = parser.parse_args()

    # the two components that make up the chosen decomposition
    if args.basis_model == 'amp-phase':
        components = ('amp', 'phase')
    else:
        components = ('real', 'imag')

    # fail early with a readable message instead of a TypeError from
    # os.path.join(None, ...) further down
    required = [f"{comp}_{kind}" for comp in components
                for kind in ("basis", "model_dir")]
    missing = ["--" + name.replace("_", "-")
               for name in required if getattr(args, name) is None]
    if missing:
        parser.error(
            f"--basis-model {args.basis_model} requires: {', '.join(missing)}")
    if args.num_per_chunk < 1:
        parser.error("--num-per-chunk must be a positive integer")

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")

        logger.info(f"basis-model: {args.basis_model}")

        for comp in components:
            comp_basis = getattr(args, f"{comp}_basis")
            comp_model_dir = getattr(args, f"{comp}_model_dir")
            logger.info(f"{comp} basis: {comp_basis}")
            logger.info(f"{comp} model dir: {comp_model_dir}")

    sub_dir = pathlib.PurePath(args.wf_dir).name
    output_data_dir = os.path.join(args.output_dir, sub_dir)

    if args.verbose:
        logger.info("making outdir tree")
        logger.info(f"making dir: {output_data_dir}")
    os.makedirs(f"{output_data_dir}", exist_ok=True)

    # specify number of threads to use for tensorflow
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)
    os.environ['OMP_NUM_THREADS'] = str(1)
    if args.verbose:
        logger.info(
            f"tf using {tf.config.threading.get_inter_op_parallelism_threads()} inter_op_parallelism_threads thread(s)")
        logger.info(
            f"tf using {tf.config.threading.get_intra_op_parallelism_threads()} intra_op_parallelism_threads thread(s)")
        logger.info("OMP_NUM_THREADS: 1")

    # Load both component models with the same file lookup rules:
    # weights: best.h5, falling back to model.h5
    # scalers: X_scalers.npy / Y_scalers.npy (or y_scalers.npy); an empty
    #          string is passed to load_model when the file does not exist
    loaded = {}
    for comp in components:
        if args.verbose:
            logger.info(f"loading {comp} model")

        model_dir = getattr(args, f"{comp}_model_dir")

        model_file = os.path.join(model_dir, "best.h5")
        if not os.path.exists(model_file):
            model_file = os.path.join(model_dir, "model.h5")

        X_scalers_file = os.path.join(model_dir, "X_scalers.npy")
        if not os.path.exists(X_scalers_file):
            X_scalers_file = ""

        Y_scalers_file = ""
        for candidate in ("Y_scalers.npy", "y_scalers.npy"):
            candidate_path = os.path.join(model_dir, candidate)
            if os.path.exists(candidate_path):
                Y_scalers_file = candidate_path
                break

        loaded[comp] = load_model(
            basis_file=getattr(args, f"{comp}_basis"),
            nn_weights_file=model_file,
            X_scalers_file=X_scalers_file,
            Y_scalers_file=Y_scalers_file)

    # keep the per-component names bound at module level; other code in
    # this file refers to them by these names
    if args.basis_model == 'amp-phase':
        amp_model, amp_basis = loaded['amp']
        phase_model, phase_basis = loaded['phase']
    else:
        real_model, real_basis = loaded['real']
        imag_model, imag_basis = loaded['imag']

    if args.verbose:
        logger.info(f"loading wf data from {args.wf_dir}")

    all_matches = []
    all_coords = []
    # load coords to get number of times to loop
    with h5py.File(os.path.join(args.wf_dir, 'coords.h5'), 'r') as f:
        num = f['data'].shape[0]
    if num == 0:
        raise ValueError(f"no waveforms found in {args.wf_dir}")

    num_per_chunk = args.num_per_chunk
    if args.verbose:
        logger.info(f"num_per_chunk: {args.num_per_chunk}")
    if args.num_per_chunk >= num:
        num_per_chunk = num  # just in case we have fewer than 100 cases
        if args.verbose:
            logger.info(
                f"args.num_per_chunk ({args.num_per_chunk}) >= num ({num})")
            logger.info("setting num_per_chunk to num")
    num_chunks = num // num_per_chunk
    splits = np.array_split(np.arange(num), num_chunks)
    # Half-open [start, end) index windows, one per chunk. The end is one
    # past the last index of the chunk so that consecutive windows tile
    # 0..num without gaps and the final waveform is included.
    # NOTE(review): assumes load_data treats end_index as exclusive, as the
    # single-case loader in this file does (end_index = index + 1).
    start_end = [(int(chunk[0]), int(chunk[-1]) + 1) for chunk in splits]

    for chunk_idx, (start_index, end_index) in enumerate(start_end):
        if args.verbose:
            logger.info(f"{chunk_idx} / {num_chunks - 1}")
            logger.info(f"start_index: {start_index}, end_index: {end_index}")

        if args.verbose:
            logger.info("loading data")

        window = dict(dir_name=args.wf_dir,
                      start_index=start_index, end_index=end_index)

        if args.basis_model == 'amp-phase':
            wf_x, wf_ts_amp, wf_ts_coords = load_data(
                data_to_model="amp", **window)
            _, wf_ts_phase, _ = load_data(
                data_to_model="phase", **window)

            # h = amp * exp(-i * phase), evaluated for the whole chunk
            wf_ts_h22 = (np.asarray(wf_ts_amp) *
                         np.exp(-1.j * np.asarray(wf_ts_phase))
                         ).astype(np.complex128)
        elif args.basis_model == 'real-imag':
            wf_x, wf_ts_real, wf_ts_coords = load_data(
                data_to_model="real", **window)
            _, wf_ts_imag, _ = load_data(
                data_to_model="imag", **window)

            # h = real - i * imag, evaluated for the whole chunk
            wf_ts_h22 = (np.asarray(wf_ts_real) -
                         1.j * np.asarray(wf_ts_imag)
                         ).astype(np.complex128)

        total_number_wfs = len(wf_ts_coords)

        if args.verbose:
            logger.info(f"total number of waveforms: {total_number_wfs}")
            logger.info("generating surrogate data")

        t1 = time.time()
        if args.basis_model == 'amp-phase':
            sur_hp, sur_hc, _, _ = wave_sur_many(
                wf_ts_coords, amp_model, amp_basis, phase_model, phase_basis)
        elif args.basis_model == 'real-imag':
            sur_hp, sur_hc = real_imag_wave_sur_many(
                wf_ts_coords, real_model, real_basis, imag_model, imag_basis)
        t2 = time.time()
        dt_sur = t2-t1

        if args.verbose:
            logger.info(f"time taken (surrogate) = {dt_sur:.5f} s")
            logger.info(
                f"time per waveform (surrogate) = {dt_sur/total_number_wfs:.5f} s")

        if args.verbose:
            logger.info("computing matches")
        chunk_matches = np.zeros(total_number_wfs)
        for wf_idx in tqdm.tqdm(range(total_number_wfs)):
            # only the real part of the data is compared with the surrogate
            vs_hp = np.real(wf_ts_h22[wf_idx])
            chunk_matches[wf_idx] = np.max(
                np.abs(match(vs_hp, sur_hp[wf_idx], wf_x)))

        all_matches.append(chunk_matches)
        all_coords.append(wf_ts_coords)

    # join the per-chunk results directly; chunks can differ in length so
    # they must not be packed into a single ndarray first
    matches = np.concatenate(all_matches)
    wf_ts_coords = np.vstack(all_coords)

    # per-parameter views over ALL cases (same length as `matches`)
    qs = wf_ts_coords[:, 0]
    if wf_ts_coords.shape[1] == 2:
        chis = wf_ts_coords[:, 1]

    best_idx = np.argmax(matches)
    worst_idx = np.argmin(matches)
    median_idx = find_nearest_idx(matches, np.median(matches))

    if args.verbose:
        logger.info(f"worst match: {np.min(matches):.7f}")
        logger.info(f"best match : {np.max(matches):.7f}")
        logger.info(f"median match : {np.median(matches):.7f}")
        logger.info(f"st. dev match : {np.std(matches):.7f}")

    if args.verbose:
        best_coords = wf_ts_coords[best_idx]
        logger.info(f"best coords: {best_coords}")

    result_file = os.path.join(output_data_dir, "matches.h5")
    if args.verbose:
        logger.info(f"saving matches: {result_file}")
    with h5py.File(result_file, "w") as f:
        f.create_dataset("coords", data=wf_ts_coords)
        f.create_dataset("data", data=matches)

    if args.verbose:
        logger.info("plotting...")

    # plot some waveforms to see what they look like
    waveform_cases = (
        ("plotting best match waveform", best_idx, "best_match.png"),
        ("plotting worst  match waveform", worst_idx, "worst_match.png"),
        ("plotting median match waveform", median_idx, "median_match.png"),
    )
    for message, case_idx, png_name in waveform_cases:
        if args.verbose:
            logger.info(message)

        case_ts_h22 = get_strain_from_ts_single_case(
            args.basis_model, args.wf_dir, int(case_idx))

        coords = wf_ts_coords[case_idx].reshape(1, -1)
        if args.basis_model == 'amp-phase':
            case_sur_h22 = get_strain_from_model_single_case(
                args.basis_model, coords, amp_model, amp_basis, phase_model, phase_basis)
        elif args.basis_model == 'real-imag':
            case_sur_h22 = get_strain_from_model_single_case(
                args.basis_model, coords, real_model, real_basis, imag_model, imag_basis)

        filename = os.path.join(output_data_dir, png_name)
        plot_waveform(wf_x, np.real(case_ts_h22), case_sur_h22, filename)

    num_params = wf_ts_coords.shape[1]

    # 1D mismatch scatter plots: always against the case number, and
    # against each parameter for 1D / 2D parameter spaces
    scatter_specs = [
        (range(len(matches)), "case", "1d-matches-vs-model.png")]
    if num_params in (1, 2):
        scatter_specs.append((qs, "q", "1d-q-matches-vs-model.png"))
    if num_params == 2:
        scatter_specs.append(
            (chis, r"$\chi$", "1d-chi-matches-vs-model.png"))

    for xvals, xlabel, png_name in scatter_specs:
        plt.figure()
        plt.scatter(xvals, 1-matches)
        plt.yscale('log')
        plt.ylim(1e-9)
        plt.xlabel(xlabel)
        plt.ylabel("mismatch")
        plt.tight_layout()
        plt.savefig(os.path.join(output_data_dir, png_name))
        plt.close()

    if num_params == 3:
        if args.verbose:
            logger.info("plotting corner1.png")
        # corner1 shows only the worst 1% of the matches
        max_match = np.percentile(matches, 1)
        plot_corner(wf_ts_coords, matches, max_match=max_match,
                    outname=os.path.join(output_data_dir, "corner1.png"))
        if args.verbose:
            logger.info("plotting corner2.png")
        plot_corner(wf_ts_coords, matches, max_match=1.,
                    outname=os.path.join(output_data_dir, "corner2.png"))

    if num_params == 2:
        # matches over the 2D parameter space, worst case marked in black
        plt.figure()
        plt.scatter(qs, chis, c=matches)
        plt.colorbar()
        plt.scatter(qs[worst_idx], chis[worst_idx], marker='o', s=100, c='k')
        plt.xlabel("q")
        plt.ylabel(r"$\chi$")
        plt.tight_layout()
        plt.savefig(os.path.join(output_data_dir, "2d-matches-vs-model.png"))
        plt.close()

    if args.verbose:
        logger.info("finished!")
