#!/usr/bin/env python3

import os
import glob
import zfit
import pandas            as pnd
import argparse
import zutils.utils      as zut
import utils_noroot      as utnr
import matplotlib.pyplot as plt

from logzero import logger as log

#----------------------------------------
class data:
    '''
    Namespace for state shared between the functions of this script.
    '''
    # Directory with the input JSON files of the version being processed, set in main
    jsn_dir = None
    # Directory where the plots of the version being processed are saved, set in main
    out_dir = None
    # Results of the pull fits, one row is appended for each version and variable
    df_fit  = pnd.DataFrame(
            columns=[
                'ver',
                'var',
                'mean_v',
                'mean_e',
                'width_v',
                'width_e',
                ])
#----------------------------------------
def get_args():
    '''
    Parse the command line options of the script.

    Returns
    -------
    Namespace with the attributes `jsn_dir`, `plt_dir` and `version`,
    the latter holding a list with one or more entries. All options are required.
    '''
    parser = argparse.ArgumentParser(description='Used to make plots from JSON files from model checking scripts')

    # Both directory options are plain required strings, register them together
    l_dir_opt = [
            ('-j', '--jsn_dir', 'Directory where JSON files go'),
            ('-p', '--plt_dir', 'Directory where plots will go'),
            ]
    for short_flag, long_flag, help_text in l_dir_opt:
        parser.add_argument(short_flag, long_flag, type=str, help=help_text, required=True)

    parser.add_argument('-v', '--version', nargs='+', help='Versions, needed to pickup and dump', required=True)

    return parser.parse_args()
#----------------------------------------
def load_data(kind):
    '''
    Read all the JSON files of a given kind from `data.jsn_dir` into a single dataframe.

    Parameters
    ----------
    kind : str
        Prefix of the file names, files matching `{kind}_*.json` are picked up.

    Returns
    -------
    Dataframe with the contents of all the matching files concatenated
    and the index rebuilt to run from zero.

    Raises
    ------
    FileNotFoundError
        If no file matches the wildcard.
    '''
    jsn_wc = f'{data.jsn_dir}/{kind}_*.json'

    # glob returns paths in arbitrary order. Sorting makes the row order reproducible,
    # so frames of different kinds, which are combined row by row when computing pulls,
    # stay aligned. This assumes the files of each kind share the same suffixes -- TODO confirm
    l_json = sorted(glob.glob(jsn_wc))
    if not l_json:
        message = f'Found no JSON file in {jsn_wc}'
        log.error(message)
        # A bare raise has no active exception here and would only produce an unrelated RuntimeError
        raise FileNotFoundError(message)

    l_df   = [ pnd.read_json(json_path) for json_path in l_json ]
    df     = pnd.concat(l_df)
    df     = df.reset_index(drop=True)

    return df
#----------------------------------------
def convert_version():
    '''
    Replace the version labels in `data.df_fit` by the first run of digits found in each of them, cast to integers.
    '''
    df_fit        = data.df_fit
    # Only the first group of digits in each label is kept
    digits        = df_fit.ver.str.extract(r'(\d+)')
    df_fit['ver'] = digits.astype(int)
    data.df_fit   = df_fit
#----------------------------------------
def main():
    '''
    Entry point of the script.

    If the file with the pull fit results already exists in the plots directory, it is
    loaded and only the summary plots are made. Otherwise, for each requested version,
    the JSON files are loaded, the value and pull plots are made and the fit results
    are saved to that file before making the summary plots.
    '''
    zfit.settings.changed_warnings.all = False
    args = get_args()

    plots_dir = f'{args.plt_dir}/plots'
    pull_json = f'{plots_dir}/pull_fits.json'

    if os.path.isfile(pull_json):
        # Fit results from a former run are reused, the versions passed are not processed
        data.df_fit = pnd.read_json(pull_json)
    else:
        for version in args.version:
            data.jsn_dir = f'{args.jsn_dir}/{version}'
            data.out_dir = f'{plots_dir}/{version}'

            os.makedirs(data.out_dir, exist_ok=True)

            df_val, df_err, df_ini = (load_data(kind) for kind in ('val', 'err', 'ini'))

            plot_vals(df_ini, df_val)
            plot_pulls(version, df_ini, df_val, df_err)

        # Saved before the conversion, so the file holds the version labels as given
        data.df_fit.to_json(pull_json, indent=4)

    convert_version()
    plot_fits(args, ['nsg', 'nbk|npr'])
#---------------------------------
def pad_vers(df, var, s_ver, l_val):
    '''
    Add a placeholder row for each version in `s_ver` missing from `df` and sort by version.

    Parameters
    ----------
    df    : Dataframe with a `ver` column, it is not modified
    var   : Name of the variable, only the part before the first `|` is used for the added rows
    s_ver : Versions that the returned dataframe has to contain
    l_val : Four values, mean, error on the mean, width and error on the width, used in the added rows

    Returns
    -------
    Copy of the dataframe with the missing versions added, sorted by the `ver` column.
    '''
    short_var = var.split('|')[0]
    df_pad    = df.reset_index(drop=True)

    mu_v, mu_e, sg_v, sg_e = l_val
    for ver in s_ver:
        is_missing = not df_pad.ver.isin([ver]).any()
        if is_missing:
            log.info(f'Padding version {ver} for var {short_var}')
            df_pad = utnr.add_row_to_df(df_pad, [ver, short_var, mu_v, mu_e, sg_v, sg_e])

    return df_pad.sort_values(by='ver')
#---------------------------------
def _plot_summary(l_var, s_ver, quantity, l_pad, y_range, y_label, expected, plot_path):
    '''
    Plot, as a function of the version, one fitted quantity for each variable and save the figure.

    Parameters
    ----------
    l_var     : Patterns used to select rows of `data.df_fit` through their `var` column
    s_ver     : Versions that each curve has to cover, missing ones are padded
    quantity  : Either `mean` or `width`, the columns `{quantity}_v` and `{quantity}_e` are plotted
    l_pad     : Values passed to `pad_vers` for the padded rows
    y_range   : Tuple with the lower and upper limit of the y axis
    y_label   : Label of the y axis
    expected  : Position of the horizontal reference line
    plot_path : Path of the output image
    '''
    ax=None
    for var in l_var:
        name = var.replace('|', '_')
        # The pattern is treated by str.contains as a regular expression, so `a|b` picks either
        df = data.df_fit[data.df_fit['var'].str.contains(var)]
        df = pad_vers(df, var, s_ver, l_pad)
        df['ver'] = df['ver'].apply(lambda val : f'v{val}')
        ax = df.plot(x='ver', y=f'{quantity}_v', yerr=f'{quantity}_e', marker='o', label=name, ax=ax)

    plt.ylim(*y_range)
    plt.xlabel('')
    plt.ylabel(y_label)
    plt.axhline(expected, color='red', linestyle=':')
    # NOTE(review): labels are passed without handles, so the mapping to the plotted
    # objects relies on the order in which matplotlib picks them up -- confirm
    plt.legend(['Expected', 'Signal', 'Background'])
    plt.savefig(plot_path)
    plt.close('all')
#---------------------------------
def plot_fits(args, l_var):
    '''
    Make the summary plots of the fitted pull mean and width as a function of the version.

    The fit results are read from `data.df_fit` and the plots are saved as `mu.png`
    and `sg.png` in the directory `{args.plt_dir}/plots/summary`, which is made if missing.

    Parameters
    ----------
    args  : Parsed command line arguments, only `plt_dir` is used
    l_var : Patterns used to select the variables, one curve is drawn for each of them
    '''
    plt_dir = f'{args.plt_dir}/plots/summary'
    os.makedirs(plt_dir, exist_ok=True)

    s_ver = set(data.df_fit['ver'].values)

    # Raw strings are needed for the labels, `\m` and `\s` are invalid escape sequences otherwise
    _plot_summary(l_var, s_ver, 'mean' , [0, 0, 0, 0], (-2, +2), r'$\mu$'   , 0, f'{plt_dir}/mu.png')
    _plot_summary(l_var, s_ver, 'width', [0, 0, 1, 0], ( 0, +2), r'$\sigma$', 1, f'{plt_dir}/sg.png')
#---------------------------------
def plot_pulls(version, df_ini, df_val, df_err):
    '''
    For each column of `df_val`, fit and plot the pull distribution and store the fit result.

    The pull is the fitted value minus the initial one, divided by the error. The plot is
    saved to `data.out_dir` and a row with the fitted mean and width is added to `data.df_fit`.

    Parameters
    ----------
    version : Label stored in the `ver` column of the added rows
    df_ini  : Dataframe with the initial values
    df_val  : Dataframe with the fitted values, its columns define the variables processed
    df_err  : Dataframe with the errors
    '''
    for nam in df_val.columns:
        residual = df_val[nam] - df_ini[nam]
        pull     = residual / df_err[nam]

        fit_mean, fit_width = zut.fit_pull(pull.values, fit_sig=2, plot=True)
        mu_v, mu_e          = fit_mean
        sg_v, sg_e          = fit_width

        pull_path = f'{data.out_dir}/pul_{nam}.png'
        log.info(f'Saving to: {pull_path}')
        plt.savefig(pull_path)
        plt.close('all')

        row         = [version, nam, mu_v, mu_e, sg_v, sg_e]
        data.df_fit = utnr.add_row_to_df(data.df_fit, row)
#---------------------------------
def plot_vals(df_ini, df_val):
    '''
    For each column of `df_val`, save a histogram of its values with a vertical line at the initial value.

    Parameters
    ----------
    df_ini : Dataframe with the initial values, only the row with index 0 is used
    df_val : Dataframe with the fitted values, its columns define the variables plotted
    '''
    for nam in df_val.columns:
        val_path = f'{data.out_dir}/val_{nam}.png'
        ini      = df_ini.loc[0, nam]

        df_val[nam].hist(bins=30)
        plt.axvline(x=ini, color='r')

        log.info(f'Saving to: {val_path}')
        plt.savefig(val_path)
        plt.close('all')
#----------------------------------------
# Run the plotting only when the file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()

