#!/usr/bin/env python

"""
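Load models fitted with scrinet_fit and evaluate them on a directory of
waveform data: the surrogate waveforms are generated, their match against the
data is computed and saved (matches.h5), and diagnostic plots are written.

Examples: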

scrinet_evaluate_model -v --basis-model amp-phase --amp-basis rb/amp/amp_eim_basis.npy  --amp-model-dir ts/amp/fits --phase-basis rb/phase/phase_eim_basis.npy  --phase-model-dir ts/phase/fits --wf-dir train_wf_data


scrinet_evaluate_model -v --basis-model real-imag --real-basis rb/real/real_eim_basis.npy  --real-model-dir ts/real/fits --imag-basis rb/imag/imag_eim_basis.npy  --imag-model-dir ts/imag/fits --wf-dir train_wf_data
"""

from scrinet.results import results
from scrinet.workflow.pipe_utils import init_logger, load_data, load_model, wave_sur_many, match, real_imag_wave_sur_many
import tensorflow as tf
import h5py
import pathlib
import time
import tqdm
import argparse
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams.update({'font.size': 16})


def find_nearest_idx(a, a0):
    """
    Return the index of the element of 'a' that is closest to 'a0'.

    Parameters
    ----------
    a : array_like
        Values to search. Plain sequences (list, tuple) are accepted as well
        as numpy arrays.
    a0 : scalar
        Target value.

    Returns
    -------
    idx : numpy integer
        Index (into the flattened 'a') of the entry minimising |a - a0|.
        If several entries are equally close the first one is returned.

    Raises
    ------
    ValueError
        If 'a' is empty.

    https://stackoverflow.com/a/10465997/12840171
    """
    # asarray lets lists/tuples be passed in; ndarrays are used as they are
    distances = np.abs(np.asarray(a) - a0)
    return distances.argmin()


def plot_waveform(times, h1, h2, filename):
    """
    Plot a data waveform against a surrogate waveform and save the figure.

    Two panels sharing the y-axis are drawn with the same curves: the left
    panel shows the full range of 'times', the right panel is restricted to
    the fixed x-range [-500, 100].

    Parameters
    ----------
    times : array_like
        x values used for both curves.
    h1 : array_like
        Curve drawn with a solid line and labelled 'data'.
    h2 : array_like
        Curve drawn with a dashed line and labelled 'sur'.
    filename : str
        Path the figure is saved to.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(14, 4))
    try:
        for ax in (ax1, ax2):
            ax.plot(times, h1, label='data')
            ax.plot(times, h2, label='sur', ls='--')

        # fixed zoom window, in the units of 'times'
        ax2.set_xlim(-500, 100)

        # the curves are identical in both panels, one legend is enough
        ax1.legend()
        fig.savefig(filename)
    finally:
        # close this figure explicitly (not whatever happens to be the
        # "current" one) and also when plotting or saving fails
        plt.close(fig)


def plot_corner(wf_ts_coords, matches, max_match, outname):
    """
    Save a multi-dimensional plot of the cases with match <= 'max_match'.

    Parameters
    ----------
    wf_ts_coords : numpy.ndarray
        2D array with one row per case; columns 0, 1 and 2 are used as
        'q', 'chi1z' and 'chi2z' respectively.
    matches : numpy.ndarray
        One match value per row of 'wf_ts_coords'.
    max_match : float
        Only cases with matches <= max_match are plotted.
    outname : str
        Path the figure is saved to.
    """
    parameters = ['q', 'chi1z', 'chi2z']
    mask = matches <= max_match
    selected = wf_ts_coords[mask]

    samples = np.recarray(len(selected), dtype=[
                          (p, float) for p in parameters])
    for column, p in enumerate(parameters):
        samples[p] = selected[:, column]

    zvals = matches[mask]

    labels = {'q': r'q', 'chi1z': r'$\chi_{1z}$', 'chi2z': r'$\chi_{2z}$'}

    # Pad the axis range by 1% of the magnitude of each extreme value.
    # Subtracting/adding the padding (instead of scaling by 0.99/1.01) keeps
    # the extreme points inside the axes when the values are negative.
    mins = {}
    maxs = {}
    for p in parameters:
        lowest = np.min(samples[p])
        highest = np.max(samples[p])
        mins[p] = lowest - 0.01 * abs(lowest)
        maxs[p] = highest + 0.01 * abs(highest)

    fig, _ = results.create_multidim_plot(
        parameters,
        samples,
        zvals=zvals,
        show_colorbar=True,
        labels=labels,
        mins=mins,
        maxs=maxs,
        cbar_label='match',
        cb_scale=14)
    fig.suptitle(f'matches <= {max_match:.4f}')
    fig.tight_layout()
    fig.savefig(outname)
    # NOTE(review): assumes 'fig' is a pyplot-managed matplotlib Figure -
    # confirm against results.create_multidim_plot
    plt.close(fig)


def _load_component(basis_file, model_dir):
    """Call load_model for the fit stored in 'model_dir' and return its result."""
    return load_model(
        basis_file=basis_file,
        nn_weights_file=os.path.join(model_dir, "best.h5"),
        X_scalers_file=os.path.join(model_dir, "X_scalers.npy"),
        Y_scalers_file=os.path.join(model_dir, "Y_scalers.npy"))


def _plot_mismatch_scatter(x, matches, xlabel, filename):
    """Scatter the mismatch (1 - matches) against 'x' on a log y-axis and save it."""
    plt.figure()
    plt.scatter(x, 1-matches)
    plt.yscale('log')
    # lower y limit
    plt.ylim(1e-9)
    plt.xlabel(xlabel)
    plt.ylabel("mismatch")
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()


if __name__ == "__main__":
    # Script entry point: load the fitted models, generate surrogate
    # waveforms for every case in --wf-dir, compute the match against the
    # data, save the matches and write diagnostic plots.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--output-dir", type=str, default='evaluate',
                        help="directory to save data")
    parser.add_argument("--wf-dir", type=str, required=True,
                        help="directory of waveform data")

    parser.add_argument("--basis-model", type=str,
                        required=True,
                        help="data type of basis",
                        choices=['amp-phase', 'real-imag'])

    parser.add_argument("--amp-basis", type=str,
                        help="path to amp basis")
    parser.add_argument("--amp-model-dir", type=str,
                        help="dir of amp model weights and scalers")
    parser.add_argument("--phase-basis", type=str,
                        help="path to phase basis")
    parser.add_argument("--phase-model-dir", type=str,
                        help="dir of phase model weights and scalers")

    parser.add_argument("--real-basis", type=str,
                        help="path to real basis")
    parser.add_argument("--real-model-dir", type=str,
                        help="dir of real model weights and scalers")
    parser.add_argument("--imag-basis", type=str,
                        help="path to imag basis")
    parser.add_argument("--imag-model-dir", type=str,
                        help="dir of imag model weights and scalers")

    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    args = parser.parse_args()

    # The basis / model-dir options are optional for argparse because only one
    # family is needed per run. Check the family selected by --basis-model
    # here, so that a forgotten option gives a usage error instead of a
    # TypeError from os.path.join(None, ...) further down.
    required_options = {
        'amp-phase': ('amp_basis', 'amp_model_dir',
                      'phase_basis', 'phase_model_dir'),
        'real-imag': ('real_basis', 'real_model_dir',
                      'imag_basis', 'imag_model_dir'),
    }
    missing_options = ["--" + name.replace('_', '-')
                       for name in required_options[args.basis_model]
                       if getattr(args, name) is None]
    if missing_options:
        parser.error(
            f"--basis-model {args.basis_model} requires "
            f"{', '.join(missing_options)}")

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")

        logger.info(f"basis-model: {args.basis_model}")

        if args.basis_model == 'amp-phase':
            logger.info(f"amp basis: {args.amp_basis}")
            logger.info(f"amp model dir: {args.amp_model_dir}")
            logger.info(f"phase basis: {args.phase_basis}")
            logger.info(f"phase model dir: {args.phase_model_dir}")
        elif args.basis_model == 'real-imag':
            logger.info(f"real basis: {args.real_basis}")
            logger.info(f"real model dir: {args.real_model_dir}")
            logger.info(f"imag basis: {args.imag_basis}")
            logger.info(f"imag model dir: {args.imag_model_dir}")

    # results go to <output-dir>/<last component of wf-dir>
    sub_dir = pathlib.PurePath(args.wf_dir).name
    output_data_dir = os.path.join(args.output_dir, sub_dir)

    if args.verbose:
        logger.info("making outdir tree")
        logger.info(f"making dir: {output_data_dir}")
    os.makedirs(output_data_dir, exist_ok=True)

    # specify number of threads to use for tensorflow
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)
    # NOTE(review): tensorflow and numpy are already imported at this point,
    # so setting OMP_NUM_THREADS here may be too late to affect them - confirm
    os.environ['OMP_NUM_THREADS'] = str(1)
    if args.verbose:
        logger.info(
            f"tf using {tf.config.threading.get_inter_op_parallelism_threads()} inter_op_parallelism_threads thread(s)")
        logger.info(
            f"tf using {tf.config.threading.get_intra_op_parallelism_threads()} intra_op_parallelism_threads thread(s)")
        logger.info("OMP_NUM_THREADS: 1")

    # ---- load the two fitted models of the selected family ----
    if args.basis_model == 'amp-phase':
        if args.verbose:
            logger.info("loading amp model")
        amp_model, amp_basis = _load_component(
            args.amp_basis, args.amp_model_dir)

        if args.verbose:
            logger.info("loading phase model")
        phase_model, phase_basis = _load_component(
            args.phase_basis, args.phase_model_dir)
    elif args.basis_model == 'real-imag':
        if args.verbose:
            logger.info("loading real model")
        real_model, real_basis = _load_component(
            args.real_basis, args.real_model_dir)

        if args.verbose:
            logger.info("loading imag model")
        imag_model, imag_basis = _load_component(
            args.imag_basis, args.imag_model_dir)

    # ---- load the waveform data and assemble the complex waveform ----
    if args.verbose:
        logger.info(f"loading wf data from {args.wf_dir}")
    if args.basis_model == 'amp-phase':
        wf_x, wf_ts_amp, wf_ts_coords = load_data(
            data_to_model="amp", dir_name=args.wf_dir)
        _, wf_ts_phase, _ = load_data(
            data_to_model="phase", dir_name=args.wf_dir)

        # h = amp * exp(-i * phase), one row per case
        wf_ts_h22 = np.zeros(shape=wf_ts_amp.shape, dtype=np.complex128)
        for i in range(wf_ts_h22.shape[0]):
            wf_ts_h22[i] = wf_ts_amp[i] * np.exp(-1.j*wf_ts_phase[i])
    elif args.basis_model == 'real-imag':
        wf_x, wf_ts_real, wf_ts_coords = load_data(
            data_to_model="real", dir_name=args.wf_dir)
        _, wf_ts_imag, _ = load_data(
            data_to_model="imag", dir_name=args.wf_dir)

        # h = real - i * imag, one row per case
        wf_ts_h22 = np.zeros(shape=wf_ts_real.shape, dtype=np.complex128)
        for i in range(wf_ts_h22.shape[0]):
            wf_ts_h22[i] = wf_ts_real[i] - 1.j * wf_ts_imag[i]

    n_params = wf_ts_coords.shape[1]
    qs = wf_ts_coords[:, 0]
    if n_params == 2:
        # only used by the 2-parameter plots at the end
        chis = wf_ts_coords[:, 1]
    total_number_wfs = len(qs)

    if args.verbose:
        logger.info(f"total number of waveforms: {total_number_wfs}")
        logger.info("generating surrogate data")

    # ---- evaluate the surrogate for every case and time it ----
    # perf_counter is a monotonic, high-resolution clock meant for intervals
    t1 = time.perf_counter()
    if args.basis_model == 'amp-phase':
        sur_hp, sur_hc, _, _ = wave_sur_many(
            wf_ts_coords, amp_model, amp_basis, phase_model, phase_basis)
    elif args.basis_model == 'real-imag':
        sur_hp, sur_hc = real_imag_wave_sur_many(
            wf_ts_coords, real_model, real_basis, imag_model, imag_basis)
    t2 = time.perf_counter()
    dt_sur = t2-t1

    if args.verbose:
        logger.info(f"time taken (surrogate) = {dt_sur:.5f} s")
        logger.info(
            f"time per waveform (surrogate) = {dt_sur/total_number_wfs:.5f} s")

    # ---- match between data and surrogate ----
    if args.verbose:
        logger.info("computing matches")
    matches = np.zeros(total_number_wfs)
    for i in tqdm.tqdm(range(total_number_wfs)):
        # only the real part of the data is compared with sur_hp;
        # sur_hc is not used
        vs_hp = np.real(wf_ts_h22[i])
        # keep the largest absolute value returned by match()
        matches[i] = np.max(np.abs(match(vs_hp, sur_hp[i], wf_x)))

    worst_idx = np.argmin(matches)
    best_idx = np.argmax(matches)

    if args.verbose:
        logger.info(f"worst match: {np.min(matches):.7f}")
        logger.info(f"best match : {np.max(matches):.7f}")
        logger.info(f"median match : {np.median(matches):.7f}")
        logger.info(f"st. dev match : {np.std(matches):.7f}")

    if args.verbose:
        best_coords = wf_ts_coords[best_idx]
        logger.info(f"best coords: {best_coords}")

    result_file = os.path.join(output_data_dir, "matches.h5")
    if args.verbose:
        logger.info(f"saving matches: {result_file}")
    with h5py.File(result_file, "w") as f:
        f.create_dataset("coords", data=wf_ts_coords)
        f.create_dataset("data", data=matches)

    if args.verbose:
        logger.info("plotting...")

    # plot some waveforms to see what they look like
    if args.verbose:
        logger.info("plotting best match waveform")
    filename = os.path.join(output_data_dir, "best_match.png")
    plot_waveform(wf_x, np.real(
        wf_ts_h22[best_idx]), sur_hp[best_idx], filename)

    if args.verbose:
        logger.info("plotting worst  match waveform")
    filename = os.path.join(output_data_dir, "worst_match.png")
    plot_waveform(wf_x, np.real(
        wf_ts_h22[worst_idx]), sur_hp[worst_idx], filename)

    if args.verbose:
        logger.info("plotting median match waveform")
    filename = os.path.join(output_data_dir, "median_match.png")
    # the case whose match is closest to the median of all matches
    median_idx = find_nearest_idx(matches, np.median(matches))
    plot_waveform(wf_x, np.real(
        wf_ts_h22[median_idx]), sur_hp[median_idx], filename)

    # mismatch against the case index
    _plot_mismatch_scatter(
        range(len(matches)), matches, "case",
        os.path.join(output_data_dir, "1d-matches-vs-model.png"))

    # ---- plots that depend on the number of parameters ----
    if n_params == 3:
        if args.verbose:
            logger.info("plotting corner1.png")
        # corner1: only the cases at or below the 1st percentile of the matches
        max_match = np.percentile(matches, 1)
        plot_corner(wf_ts_coords, matches, max_match=max_match,
                    outname=os.path.join(output_data_dir, "corner1.png"))
        if args.verbose:
            logger.info("plotting corner2.png")
        # corner2: every case with match <= 1
        plot_corner(wf_ts_coords, matches, max_match=1.,
                    outname=os.path.join(output_data_dir, "corner2.png"))

    if n_params in (1, 2):
        _plot_mismatch_scatter(
            qs, matches, "q",
            os.path.join(output_data_dir, "1d-q-matches-vs-model.png"))

    if n_params == 2:
        _plot_mismatch_scatter(
            chis, matches, r"$\chi$",
            os.path.join(output_data_dir, "1d-chi-matches-vs-model.png"))

        # match over the (q, chi) plane, worst case marked in black
        plt.figure()
        plt.scatter(qs, chis, c=matches)
        plt.colorbar()
        plt.scatter(qs[worst_idx], chis[worst_idx], marker='o', s=100, c='k')
        plt.xlabel("q")
        plt.ylabel(r"$\chi$")
        plt.tight_layout()
        plt.savefig(os.path.join(output_data_dir, "2d-matches-vs-model.png"))
        plt.close()

    if args.verbose:
        logger.info("finished!")
