#!/usr/bin/env python

"""
Build a 2D surrogate TD model using neural networks that can be trained on GPUs.

Reads the data generated by scrinet_gen_ts_data from
<ts-dir>/<data-to-model>/{train,val}/{X.npy,y.npy}.

./scrinet_fit --data-to-model phase -v --scaleX --scaleY --epochs 1000 --batch-size 1000 --input-units 64 --hidden-units 64 64 64 64 64

# example with tensorboard

./scrinet_fit --data-to-model amp -v --scaleX --scaleY --epochs 400  --batch-size 1000 --tfboard --tfboard-histogram-freq 1

# example with --tfboard-tag

./scrinet_fit --data-to-model amp -v --scaleX --scaleY --epochs 1000  --batch-size 1000 --tfboard --tfboard-histogram-freq 1 --tfboard-tag 64units  --input-units 64 --hidden-units 64 64 64 64 64

"""

import matplotlib
import matplotlib.pyplot as plt

import os
import sys
import argparse
import numpy as np
import h5py
import time

from scrinet.fits.nn import RegressionANN

from scrinet.workflow.pipe_utils import init_logger, load_data

import tensorflow as tf

def load_fit_data(path):
    """Load a (features, targets) pair saved as ``X.npy`` and ``y.npy``.

    Parameters
    ----------
    path : str
        Directory containing the two ``.npy`` files.

    Returns
    -------
    tuple of numpy.ndarray
        ``(X, y)`` exactly as stored on disk; ``X.npy`` is read first.
    """
    features, targets = (
        np.load(os.path.join(path, f"{stem}.npy")) for stem in ("X", "y")
    )
    return features, targets

def plot_history(history, outname=None):
    """Plot the training (and, if recorded, validation) loss per epoch.

    The y-axis uses a log scale.

    Parameters
    ----------
    history : object
        Object with a ``history`` mapping from metric name to per-epoch
        values. ``'loss'`` must be present; ``'val_loss'`` is drawn only when
        it exists. Presumably the return value of ``RegressionANN.fit``
        (Keras-style History) -- confirm against that class.
    outname : str, optional
        If truthy, the figure is saved to this path; otherwise it is shown
        interactively via ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    metrics = history.history
    fig, ax = plt.subplots()
    ax.plot(metrics['loss'], label='training loss')
    if 'val_loss' in metrics:
        ax.plot(metrics['val_loss'], label='validation loss')
    ax.set_yscale('log')
    ax.set_ylabel('loss')
    ax.set_xlabel('epochs')
    ax.legend()
    fig.tight_layout()
    if outname:
        fig.savefig(outname)
    else:
        plt.show()
    # Close the figure created above (not whatever happens to be "current")
    # so repeated calls do not leak open figures.
    plt.close(fig)

def plot_fit(X, y, fit, outdir):
    """Save one scatter plot per target column comparing data and fit.

    For each column ``i`` of ``y`` a file ``<outdir>/<i>.png`` is written.
    It shows the true values ``y[:, i]`` (red) and the predictions
    ``fit.predict(X)[:, i]`` (black crosses) against the sample index.

    Parameters
    ----------
    X : numpy.ndarray
        Inputs passed to ``fit.predict``; ``X.shape[0]`` sets the number of
        points along the horizontal (sample index) axis.
    y : numpy.ndarray
        2D array of targets, one column per basis index.
    fit : object
        Model exposing ``predict(X)``. Its output is assumed to be 2D with
        at least as many columns as ``y`` -- confirm against RegressionANN.
    outdir : str
        Directory the PNG files are written to. It is not created here and
        must already exist.
    """
    yhat = fit.predict(X)
    sample_index = np.arange(X.shape[0])

    # loop over basis index
    for i in range(y.shape[1]):
        outname = os.path.join(outdir, f"{i}.png")

        fig = plt.figure()
        plt.scatter(sample_index, y[:, i], label='data', s=50, c='r', lw=2)
        plt.scatter(sample_index, yhat[:, i], label='fit', marker='x', s=100, lw=2, c='k')
        plt.title(f"basis number = {i}")
        plt.legend()
        plt.tight_layout()
        plt.savefig(outname)
        # Close exactly the figure created for this basis index.
        plt.close(fig)



def build_parser():
    """Build the command-line parser for this fitting script.

    Returns
    -------
    argparse.ArgumentParser
        Parser whose ``--help`` output includes each option's default.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Required: all input/output paths are built from this value, so omitting
    # it would otherwise crash later inside os.path.join(..., None, ...).
    parser.add_argument("--data-to-model", type=str, required=True,
                        help="What to model",
                        choices=['time', 'amp', 'phase', 'freq', 'real', 'imag', 'alpha', 'beta', 'gamma'])

    parser.add_argument("--ts-dir", type=str, default='ts',
                        help="root directory of training set data")

    parser.add_argument("--output-dir", type=str, default='ts',
                        help="root directory to save data")

    parser.add_argument("--epochs", type=int, default=1000,
                        help="number of NN training epochs")
    parser.add_argument("--batch-size", type=int, default=None,
                        help="NN batch size")

    parser.add_argument("--activation", type=str, default='relu',
                        help="activation function for all layers")
    parser.add_argument("--kernel-initializer", type=str, default='he_uniform',
                        help="kernel initializer for all layers")
    parser.add_argument("--input-units", type=int, default=128,
                        help="number of units in input layer")
    parser.add_argument("--hidden-units", type=int, default=[128, 128, 128, 128, 128],
                        nargs='+', help="number of units in each hidden layer")
    parser.add_argument("--learning-rate", type=float, default=0.001,
                        help="learning rate for optimiser")
    parser.add_argument("--lr-schedule", help="use learning rate schedule",
                        action="store_true")

    parser.add_argument("--use-alr", help="activate adaptive learning rate (alr)",
                        action="store_true")
    parser.add_argument("--alr-monitor", type=str, default='val_loss',
                        help="quantity to monitor")
    parser.add_argument("--alr-factor", type=float, default=0.2,
                        help="factor by which the learning rate will be reduced")
    parser.add_argument("--alr-patience", type=int, default=50,
                        help="number of epochs with no improvement after which learning rate will be reduced.")
    parser.add_argument("--alr-verbose", type=int, default=0,
                        help="int. 0: quiet, 1: update messages.")
    parser.add_argument("--alr-min-lr", type=float, default=1e-4,
                        help="lower bound on the learning rate.")

    parser.add_argument("--use-es", help="activate early stopping (es)",
                        action="store_true")
    parser.add_argument("--es-monitor", type=str, default='val_loss',
                        help="quantity to monitor")
    parser.add_argument("--es-patience", type=int, default=100,
                        help="Number of epochs with no improvement after which training will be stopped.")
    parser.add_argument("--es-verbose", type=int, default=0,
                        help="int. 0: quiet, 1: update messages.")

    parser.add_argument("--tfboard", help="activate TensorBoard logging",
                        action="store_true")
    parser.add_argument("--tfboard-log-dir", type=str, default='tfboard-log',
                        help="log directory for TensorBoard")
    parser.add_argument("--tfboard-tag", type=str,
                        help="""identifier tag for TensorBoard sub directory
                        If None then will use time stamp
                        """)

    parser.add_argument("--tfboard-histogram-freq", type=int, default=0,
                        help="histogram frequency argument for TensorBoard")

    parser.add_argument("--scaleX", help="scale X data",
                        action="store_true")
    parser.add_argument("--scaleY", help="scale y data",
                        action="store_true")

    parser.add_argument("--plot-fits", help="plot fits of training and val data",
                        action="store_true")

    parser.add_argument("--nn-verbose", type=int, default=1,
                        help="verbosity of NN")
    parser.add_argument("-v", "--verbose", help="increase output verbosity",
                        action="store_true")

    return parser


def main(argv=None):
    """Train the network for one quantity, reload the best weights and plot.

    Steps:
    1. Parse arguments and create the output directory tree.
    2. Pin TensorFlow to a single thread.
    3. Load train/val data from ``<ts-dir>/<data-to-model>/{train,val}``.
    4. Fit ``RegressionANN`` and reload ``<output-dir>/<data-to-model>/fits/best.h5``.
    5. Write the loss plot and, with ``--plot-fits``, per-basis fit plots.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    logger = init_logger()

    if args.verbose:
        logger.info(f"current file: {__file__}")
        logger.info("verbosity turned on")
        logger.info(f"data directory: {args.ts_dir}")

    output_data_dir = os.path.join(args.output_dir, args.data_to_model, "fits")
    output_fits_train_dir = os.path.join(output_data_dir, "plots", "train")
    output_fits_val_dir = os.path.join(output_data_dir, "plots", "val")

    if args.verbose:
        logger.info("making outdir tree")
        logger.info(f"making dir: {output_data_dir}")
        logger.info(f"making dir: {output_fits_train_dir}")
        logger.info(f"making dir: {output_fits_val_dir}")
    for directory in (output_data_dir, output_fits_train_dir, output_fits_val_dir):
        os.makedirs(directory, exist_ok=True)

    # Restrict TensorFlow to one thread for both inter- and intra-op work.
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)
    # NOTE(review): OpenMP usually reads OMP_NUM_THREADS when its runtime is
    # first initialised, which may already have happened during the
    # tensorflow/numpy imports at the top of the file. Setting it here may
    # therefore have no effect; consider exporting it before launch.
    os.environ['OMP_NUM_THREADS'] = str(1)
    if args.verbose:
        logger.info(f"tf using {tf.config.threading.get_inter_op_parallelism_threads()} inter_op_parallelism_threads thread(s)")
        logger.info(f"tf using {tf.config.threading.get_intra_op_parallelism_threads()} intra_op_parallelism_threads thread(s)")
        logger.info("OMP_NUM_THREADS: 1")

    path = os.path.join(args.ts_dir, args.data_to_model, "train")
    if args.verbose:
        logger.info(f"loading training data: {path}")
    X, y = load_fit_data(path)
    if args.verbose:
        logger.info(f"training data shape: X.shape = {X.shape}")
        logger.info(f"training data shape: y.shape = {y.shape}")

    path = os.path.join(args.ts_dir, args.data_to_model, "val")
    if args.verbose:
        logger.info(f"loading validation data: {path}")
    Xval, yval = load_fit_data(path)
    if args.verbose:
        logger.info(f"validation data shape: X.shape = {Xval.shape}")
        logger.info(f"validation data shape: y.shape = {yval.shape}")

    if args.verbose:
        logger.info(f"activation function: {args.activation}")
        logger.info(f"input units: {args.input_units}")
        logger.info(f"hidden units: {args.hidden_units}")
        logger.info(f"batch size: {args.batch_size}")

    if args.tfboard:
        tfboard_log_dir = os.path.join(args.tfboard_log_dir, args.data_to_model)
        if args.verbose:
            logger.info("using TensorBoard")
            logger.info(f"TensorBoard log dir: {tfboard_log_dir}")
            logger.info(f"TensorBoard histogram freq: {args.tfboard_histogram_freq}")
            logger.info(f"TensorBoard tag: {args.tfboard_tag}")
    else:
        tfboard_log_dir = None

    if args.verbose:
        if args.lr_schedule:
            logger.info("using learning rate scheduler")
        elif args.use_alr:
            logger.info("using adaptive learning rate (alr)")
            logger.info(f"alr_monitor: {args.alr_monitor}")
            logger.info(f"alr_factor: {args.alr_factor}")
            logger.info(f"alr_patience: {args.alr_patience}")
            logger.info(f"alr_verbose: {args.alr_verbose}")
            logger.info(f"alr_min_lr: {args.alr_min_lr}")
        else:
            logger.info(f"learning rate: {args.learning_rate}")

        logger.info(f"kernel initializer: {args.kernel_initializer}")

        if args.use_es:
            logger.info("using early stopping (es)")
            logger.info(f"es_monitor: {args.es_monitor}")
            logger.info(f"es_patience: {args.es_patience}")
            logger.info(f"es_verbose: {args.es_verbose}")

    # If early stopping is less patient than the LR reduction, training can
    # stop before the learning rate ever gets reduced.
    if args.use_alr and args.use_es:
        if args.es_patience < args.alr_patience:
            logger.warning("es_patience < alr_patience training might stop too early")

    if args.verbose:
        logger.info("fitting")
    t1 = time.time()

    fit = RegressionANN()
    history = fit.fit(
        X,
        y,
        input_dim=X.shape[1],
        noutput=y.shape[1],
        epochs=args.epochs,
        validation_data=(Xval, yval),
        outdir=output_data_dir,
        scaleX=args.scaleX,
        scaleY=args.scaleY,
        verbose=args.nn_verbose,
        batch_size=args.batch_size,
        activation=args.activation,
        input_units=args.input_units,
        units=args.hidden_units,
        learning_rate=args.learning_rate,
        lr_schedule=args.lr_schedule,
        use_alr=args.use_alr,
        alr_monitor=args.alr_monitor,
        alr_factor=args.alr_factor,
        alr_patience=args.alr_patience,
        alr_verbose=args.alr_verbose,
        alr_min_lr=args.alr_min_lr,
        use_es=args.use_es,
        es_monitor=args.es_monitor,
        es_patience=args.es_patience,
        es_verbose=args.es_verbose,
        kernel_initializer=args.kernel_initializer,
        use_tfboard=args.tfboard,
        tfboard_logdir=tfboard_log_dir,
        tfboard_histogram_freq=args.tfboard_histogram_freq,
        tfboard_tag=args.tfboard_tag
        )

    # Stop the clock before reloading weights so the reported time covers
    # only the fit itself.
    t2 = time.time()
    dt = t2 - t1
    if args.verbose:
        logger.info(f"fitting finished took: {dt:.2f} s")

    # The fit is expected to have written its best network to best.h5 in
    # output_data_dir (it was passed outdir=output_data_dir above).
    best_network_file = os.path.join(output_data_dir, "best.h5")
    if args.verbose:
        logger.info(f"loading best network: {best_network_file}")
    fit.load_model(best_network_file)

    if args.tfboard and args.verbose:
        logger.info(f"TensorBoard log dir: {tfboard_log_dir}")

    loss_plot_name = os.path.join(output_data_dir, "loss.png")
    if args.verbose:
        logger.info(f"plotting loss: {loss_plot_name}")
    plot_history(history, outname=loss_plot_name)

    if args.plot_fits:
        if args.verbose:
            logger.info(f"plotting training fits: {output_fits_train_dir}")
        plot_fit(X, y, fit, outdir=output_fits_train_dir)

        if args.verbose:
            logger.info(f"plotting validation fits: {output_fits_val_dir}")
        plot_fit(Xval, yval, fit, outdir=output_fits_val_dir)


if __name__ == "__main__":
    main()