#!/usr/bin/env python3

import glob
import os
import re
import tqdm
import math
import numpy
import shutil
import tarfile
import argparse
import mplhep

import pandas            as pnd
import matplotlib.pyplot as plt
import utils_noroot      as utnr
import zutils.utils      as zut

from stats.covariance   import covariance as sta_cov
from plotting.utilities import plot_pull  as put_pull
from logzero            import logger     as log

#-------------------------------------------------------
class data:
    '''
    Container of the state shared by the functions of this script
    '''
    # Settings taken from the command line, assigned in get_args
    job_name = None
    out_path = None
    good_fit = None

    # Datasets and triggers found among the fit variables, assigned in find_dset_trig
    l_dset   = None
    l_trig   = None

    # Pull summary, rows are appended in store_pull
    df_pull  = pnd.DataFrame(columns=['mu', 'sg', 'var', 'trig', 'dset'])
#-------------------------------------------------------
def rename_jsn():
    '''
    Moves every JSON file found in ./result_jsn to the `output` directory
    under data.out_path, which is created if missing
    '''
    target_dir = f'{data.out_path}/output'
    os.makedirs(target_dir, exist_ok=True)

    for source in glob.glob('result_jsn/*.json'):
        file_name = os.path.basename(source)
        os.replace(source, f'{target_dir}/{file_name}')
#-------------------------------------------------------
def untar(tar_path):
    '''
    Extracts all the members of a tarball into the current working directory

    Parameters
    ----------------
    tar_path (str): Path to the tarball

    Errors raised by tarfile while opening or extracting are not caught here
    '''
    # The context manager closes the archive also when the extraction fails half way
    with tarfile.open(tar_path) as tar:
        # NOTE(review): member paths are trusted as they are; if the tarballs can come
        # from an untrusted source, an extraction filter should be used
        tar.extractall()
#-------------------------------------------------------
def make_json():
    '''
    Will take tarballs, untar them, and put all the JSON files in output directory

    Nothing is done if no job name was set or if the output directory already exists.
    Tarballs that cannot be read are skipped with a warning.

    Raises
    ----------------
    RuntimeError: If no sandbox directory or no tarball is found
    '''

    if not data.job_name:
        return

    if os.path.isdir(f'{data.out_path}/output'):
        log.info('JSON directory found, not making it')
        return

    # Sandboxes are the entries whose name starts with nine digits. Only the name is matched,
    # so special characters in the output path cannot change the meaning of the pattern
    l_dirname = [ dirname for dirname in glob.glob(f'{data.out_path}/*') if re.match(r'\d{9}', os.path.basename(dirname)) ]
    if len(l_dirname) == 0:
        message = f'Found no sandboxes in {data.out_path}'
        log.error(message)
        raise RuntimeError(message)

    l_tar_path= [ f'{dirname}/result_jsn.tar.gz' for dirname in l_dirname if os.path.isfile( f'{dirname}/result_jsn.tar.gz')]

    if len(l_tar_path) == 0:
        message = f'Found no tarballs for {data.job_name}'
        log.error(message)
        raise RuntimeError(message)

    log.info('Unpacking JSON files')
    for tar_path in tqdm.tqdm(l_tar_path, ascii=' -'):
        try:
            untar(tar_path)
        except tarfile.ReadError:
            log.warning(f'Read error: {tar_path}')
            continue
        except EOFError:
            log.warning(f'EOFError: {tar_path}')
            continue

        # The tarball was unpacked in the working directory, move the files and clean up
        rename_jsn()
        shutil.rmtree('result_jsn')
#-------------------------------------------------------
def get_data(json_path, kind):
    '''
    Takes path to result_xxxx.json and the key of the section to read from it

    Returns a dictionary where:

    - Every entry of the section that is a two element list [val, err] shows up as
      `<name> value` and `<name> error`, each mapped to a one element list with a float
    - Every float, int or bool entry is kept under its own name, converted to float

    Any other entry is dropped. If the file cannot be loaded a warning is shown and
    an empty dictionary is returned.
    '''
    try:
        d_data = utnr.load_json(json_path)
    except Exception:
        # Exception, rather than a bare except, lets KeyboardInterrupt and SystemExit through
        log.warning(f'Cannot load: {json_path}')
        return {}

    d_data = d_data[kind]

    d_data_pars = {key : val        for key, val in d_data.items() if isinstance(val, list) and len(val) == 2}
    d_data_meta = {key : float(val) for key, val in d_data.items() if isinstance(val, (float, bool, int))}

    d_data_rename = {}
    for name, [val, err] in d_data_pars.items():
        d_data_rename[f'{name} value'] = [float(val)]
        d_data_rename[f'{name} error'] = [float(err)]

    d_data_rename.update(d_data_meta)

    return d_data_rename
#-------------------------------------------------------
def get_df(kind):
    '''
    Makes one dataframe out of each JSON file, for the section `kind`, and returns
    their concatenation with a fresh index
    '''
    l_frame = []
    for json_path in get_json_paths():
        d_data = get_data(json_path, kind)
        l_frame.append(pnd.DataFrame(d_data))

    df_all = pnd.concat(l_frame, axis=0)

    return df_all.reset_index(drop=True)
#-------------------------------------------------------
def get_moments(ser):
    '''
    Returns mean and standard deviation of the values in the series,
    after passing them through utnr.remove_outliers
    '''
    l_clean = utnr.remove_outliers(ser.tolist())

    return numpy.mean(l_clean), numpy.std(l_clean)
#-------------------------------------------------------
def get_json_paths():
    '''
    Returns list of paths to the JSON files with the fit results

    If a job name was set the files are searched in `<out_path>/output`,
    otherwise in `<out_path>/*/results/result_jsn`

    Raises
    ----------------
    RuntimeError: If no file is found
    '''
    if data.job_name is not None:
        json_wc = f'{data.out_path}/output/*.json'
    else:
        json_wc = f'{data.out_path}/*/results/result_jsn/*.json'

    l_json_path = glob.glob(json_wc)
    if not l_json_path:
        message = f'No JSON file found in: {json_wc}'
        log.error(message)
        # A bare raise, without an active exception, would hide the reason of the failure
        raise RuntimeError(message)

    return l_json_path
#-------------------------------------------------------
def store_pull(mu, sd, var):
    '''
    Adds a row with the mean and width of the pull of a variable to data.df_pull

    Parameters
    ----------------
    mu, sd: Mean and width of the pull distribution
    var (str): Either `rk`, stored for all datasets and triggers, or a name
    like `<name>_<dset>_<trig>_<rest>` from which name, dataset and trigger are taken

    Raises
    ----------------
    RuntimeError: If the variable name does not follow the naming above
    '''
    if var == 'rk':
        data.df_pull = utnr.add_row_to_df(data.df_pull, [mu, sd, 'rk', 'all', 'all'])
        return

    mtch = re.match(r'(.*)_(r1|r2p1|2017|2018|all)_(TOS|TIS)_(.*)', var)
    if not mtch:
        message = f'Cannot extract information from: {var}'
        log.error(message)
        raise RuntimeError(message)

    name, dset, trig, _ = mtch.groups()

    data.df_pull = utnr.add_row_to_df(data.df_pull, [mu, sd, name, trig, dset])
#-------------------------------------------------------
def plot(df_pos=None, df_pre=None, var=None):
    '''
    Makes all the plots associated to one variable: pull, fitted values,
    fitted errors and the fit quality summary

    Parameters
    ----------------
    df_pos: Dataframe with `<var> value` and `<var> error` columns, used as fitted value and error
    df_pre: Dataframe with the same columns, used as the reference value and the constraint width
    var (str): Name of variable
    '''
    os.makedirs(f'{data.out_path}/plots', exist_ok = True)

    col_val = f'{var} value'
    col_err = f'{var} error'

    sr_fit  = df_pos[col_val]
    sr_unc  = df_pos[col_err]
    sr_gen  = df_pre[col_val]
    sr_wid  = df_pre[col_err]

    plot_pull((sr_fit - sr_gen) / sr_unc, var)

    # The entry with index 0 is taken as the reference for the whole sample
    plot_vals(sr_fit, sr_gen[0], sr_wid[0], var)
    plot_errs(sr_unc, sr_gen[0], sr_wid[0], var)
    plot_qlty(df_pos)
#-------------------------------------------------------
def plot_pull(sr_pul, var):
    '''
    Fits the pull distribution through zut.fit_pull, stores the fitted mean and
    width with store_pull and saves the plot to `<out_path>/plots/pulls/<var>.png`
    '''
    pull_dir = f'{data.out_path}/plots/pulls'

    mu, sd   = zut.fit_pull(sr_pul.values, fit_sig=2, plot=True)
    store_pull(mu, sd, var)

    os.makedirs(pull_dir, exist_ok=True)
    plot_path = f'{pull_dir}/{var}.png'

    log.debug(f'Saving to: {plot_path}')
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def get_range(var, error=False):
    '''
    Returns the histogram range for a variable, or None if no range is predefined

    Only the values of `rk` (error=False) have a predefined range
    '''
    if error or var != 'rk':
        return None

    return (0.7, 1.2)
#-------------------------------------------------------
def plot_vals(sr_val, gen, cns, var):
    '''
    Plots the distribution of fitted values, with a line at the generated value
    and, if the constraint width is not below 1e-6, lines at generated +/- width

    Parameters
    ----------------
    sr_val: Series with fitted values
    gen: Generated value
    cns: Width of constraint
    var (str): Name of variable, used for the title and the name of the file
    '''
    val_dir   = f'{data.out_path}/plots/values'
    os.makedirs(val_dir, exist_ok=True)
    plot_path = f'{val_dir}/{var}.png'

    sr_val.plot.hist(bins=50, range=get_range(var, error=False), histtype='step')
    plt.axvline(x=gen, color='red', linestyle='--')

    l_label = ['Fitted', f'Generated={gen:.3f}']
    if not cns < 1e-6:
        for edge in [gen + cns, gen - cns]:
            plt.axvline(x=edge, color='red', linestyle=':')

        l_label += ['+const', '-const']

    plt.legend(l_label)

    plt.title(var)
    log.debug(f'Saving to: {plot_path}')
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def plot_errs(sr_val, gen, cns, var):
    '''
    Plots the distribution of fitted errors, together with reference lines

    Parameters
    ----------------
    sr_val: Series with the fitted errors
    gen: Generated value, its square root is drawn for variables whose name starts with `n`
    cns: Width of constraint, drawn if above 1e-6
    var (str): Name of variable, for `rk` the mean of the errors is drawn
    '''
    os.makedirs(f'{data.out_path}/plots/errors', exist_ok=True)
    plot_path = f'{data.out_path}/plots/errors/{var}.png'

    sr_val.plot.hist(bins=50, range=get_range(var, error=True), histtype='step', label='Error')
    var = var.replace('_', ' ')
    plt.title(f'$\\varepsilon{{{var}}}$')
    log.debug(f'Saving to: {plot_path}')

    # Each line carries its own label and the legend is made once. Calling plt.legend with
    # a list after each line overwrote the earlier legend and could mislabel the lines.
    # The line styles differ, so that lines drawn together can be told apart
    has_line = False

    # The square root is only defined for non negative generated values
    if var.startswith('n') and gen >= 0:
        plt.axvline(x=math.sqrt(gen), color='red', linestyle='--', label=r'$\sqrt{Generated}$')
        has_line = True

    if cns > 1e-6:
        plt.axvline(x=cns, color='red', linestyle=':', label='Constraint width')
        has_line = True

    if var == 'rk':
        err_exp = sr_val.mean()
        plt.axvline(x=err_exp, color='red', linestyle='-.', label=rf'$\mu={err_exp:.3f}$')
        has_line = True

    # As before, no legend is shown when no line was drawn
    if has_line:
        plt.legend()

    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def get_latex(l_var):
    '''
    Translates variable names into the LaTeX strings of get_var_naming

    A name that is a key of the naming dictionary is translated directly. Otherwise the
    first key that the name starts with is used and a trigger superscript is appended.

    Parameters
    ----------------
    l_var (list): Names of variables

    Returns
    ----------------
    List of LaTeX strings, in the order of the input

    Raises
    ----------------
    RuntimeError: If no key matches a variable
    '''
    d_var_lat = get_var_naming()

    l_lat = []
    for var in l_var:
        if var in d_var_lat:
            l_lat.append(d_var_lat[var])
            continue

        # First key, in insertion order, that is a prefix of the variable name
        lat = next((lat for key, lat in d_var_lat.items() if var.startswith(key)), None)
        if lat is None:
            message = f'Cannot find any key -> name pairing for: {var}'
            log.error(message)
            raise RuntimeError(message)

        # NOTE(review): anything without _TIS_ in the name is labeled as TOS, also names
        # without any trigger in them -- confirm this is intended
        suf= '${}^{,TIS}$' if '_TIS_' in var else '${}^{,TOS}$'

        l_lat.append(f'{lat}{suf}')

    return l_lat
#-------------------------------------------------------
def plot_cov(cov, l_var):
    '''
    Saves plots of the covariance matrix and of the correlation matrix derived
    from it, to `<out_path>/plots`. The axes are labeled with the LaTeX names of `l_var`
    '''
    plot_dir = f'{data.out_path}/plots'
    os.makedirs(plot_dir, exist_ok=True)

    l_lat = get_latex(l_var)

    def _draw(name, mat):
        utnr.plot_matrix(f'{plot_dir}/{name}.png', l_lat, l_lat, mat, upper=False, title='', form=None, fsize=[25, 20])

    _draw('covariance', cov)
    _draw('correlation', utnr.correlation_from_covariance(cov))
#-------------------------------------------------------
def freq_one(df, quantity):
    '''
    Counts the entries of a column that are equal, and not equal, to one

    Parameters
    ----------------
    df: Dataframe
    quantity (str): Name of column

    Returns
    ----------------
    Tuple with number of entries equal to one and number of remaining entries
    '''
    sr_qnt = df[quantity]
    ntot   = len(sr_qnt)

    # The mask has as many entries as the column, so its True entries have to be
    # summed; taking its length always gave the total
    none   = int((sr_qnt == 1).sum())

    return none, ntot - none
#-------------------------------------------------------
def plot_qlty(df):
    '''
    Saves to `<out_path>/plots/quality/summary.png` the frequencies returned by freq_one
    for the columns `converged`, `status` and `valid`. The first frequency is
    shown as `Good` and the second as `Bad`
    '''
    os.makedirs(f'{data.out_path}/plots/quality', exist_ok=True)
    plot_path = f'{data.out_path}/plots/quality/summary.png'

    # NOTE(review): the same convention is used for the three columns, while
    # filter_df selects status == 0 -- confirm the meaning of status
    l_freq = [ freq_one(df, flag) for flag in ['converged', 'status', 'valid'] ]
    l_good = [ good for good, _   in l_freq ]
    l_bad  = [ bad  for _   , bad in l_freq ]

    xval   = [1.0, 2.0, 3.0]
    xerr   = [0.5, 0.5, 0.5]
    for yval, label in [(l_good, 'Good'), (l_bad, 'Bad')]:
        plt.errorbar(xval, yval, xerr=xerr, label=label, marker='o', linestyle='None')

    plt.title('Fit quality')
    log.debug(f'Saving to: {plot_path}')
    plt.grid()
    plt.legend()
    plt.xticks(xval, ['Converged', 'Status', 'Valid'])
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def get_var_naming():
    '''
    Returns dictionary mapping the name, or the beginning of the name, of a
    fit variable to its LaTeX representation
    '''
    return {
            'r0_ee_ee'  : r'$r_0^{ee}$',
            'r1_ee_ee'  : r'$r_1^{ee}$',
            'r2_ee_ee'  : r'$r_2^{ee}$',
            # Muon channel
            'ssg_mm_mm' : r'$r_{\sigma}^{\mu\mu}$',
            'nsg_mm'    : r'$N_{signal}^{\mu\mu}$',
            'ncb_mm'    : r'$N_{comb}^{\mu\mu}$',
            'mu_cb_mm'  : r'$\mu_{comb}^{\mu\mu}$',
            'lm_cb_mm'  : r'$\lambda_{comb}^{\mu\mu}$',
            'dmu_mm_mm' : r'$\Delta\mu^{\mu\mu}$',
            # Electron channel
            'ssg_ee_ee' : r'$r_{\sigma}^{ee}$',
            'nsg_ee'    : r'$N_{signal}^{ee}$',
            'ncb_ee'    : r'$N_{comb}^{ee}$',
            'mu_cb_ee'  : r'$\mu_{comb}^{ee}$',
            'lm_cb_ee'  : r'$\lambda_{comb}^{ee}$',
            'dmu_ee_ee' : r'$\Delta\mu^{ee}$',
            # Remaining yields and parameters
            'npr_ee'    : r'$N_{PRec}^{ee}$',
            'nrbd_ee'   : r'$N_{\text{rare }B_d}^{ee}$',
            'nrbp_ee'   : r'$N_{\text{rare }B_u}^{ee}$',
            'nrbs_ee'   : r'$N_{\text{rare }B_s}^{ee}$',
            'ck'        : r'$c_{K}$',
            'rk'        : r'$R_{K}$',
            }
#-------------------------------------------------------
def plot_pull_summary():
    '''
    For each combination of dataset and trigger, plots the mean and width of the pulls
    stored in data.df_pull, saves the plot to `<out_path>/plots/summary_pulls` and prints the
    dataframe that was plotted.

    Rows stored for `all` datasets or triggers enter every plot. Pulls whose mean or
    width is above 2, in absolute value, are replaced by NaN.
    '''
    pull_dir = f'{data.out_path}/plots/summary_pulls'
    os.makedirs(pull_dir, exist_ok=True)
    for dset in data.l_dset:
        for trig in data.l_trig:
            df      = data.df_pull
            in_dset = (df.dset == dset) | (df.dset == 'all')
            in_trig = (df.trig == trig) | (df.trig == 'all')

            # Work on a copy, assigning through .loc on a slice of data.df_pull
            # triggers SettingWithCopyWarning
            df      = df[in_dset & in_trig].copy()

            if len(df) == 0:
                continue

            bad_row = (df.mu.abs() > 2) | (df.sg.abs() > 2)
            df.loc[bad_row, 'mu' ] = numpy.nan
            df.loc[bad_row, 'sg' ] = numpy.nan

            df=df.drop('trig', axis=1)
            df=df.drop('dset', axis=1)
            df=df.replace(get_var_naming())
            df=df.sort_values(by=['var'])
            df=df.reset_index(drop=True)

            put_pull(df, var='var', val='mu', err='sg')

            plot_path = f'{pull_dir}/{dset}_{trig}.png'
            log.debug(f'Saving to: {plot_path}')
            plt.title(f'{dset}, {trig}')
            plt.savefig(plot_path)
            plt.close('all')

            print(df)
#-------------------------------------------------------
def filter_df(df):
    '''
    Returns the dataframe as it was passed

    NOTE: The function returns in its first line, the selection of good fits
    below is never run and data.good_fit has therefore no effect here
    '''
    return df

    # Unreachable from here on
    if data.good_fit:
        log.warning('Filtering dataframe')

        is_good = (df.valid == 1) & (df.converged == 1) & (df.status == 0)

        return df[is_good].reset_index(drop=True)

    log.info('Not filtering dataframe')

    return df
#-------------------------------------------------------
def get_args():
    '''
    Parses the command line arguments and stores the settings in `data`

    The output path is `output_<job_name>` if a job name is passed, otherwise the job path

    Raises
    ----------------
    RuntimeError: If neither job name nor job path are passed
    '''
    parser = argparse.ArgumentParser(description='Will make plots from the results of toy fits')
    parser.add_argument('-n','--job_name', type=str, help='Name of job, for grid jobs')
    parser.add_argument('-p','--job_path', type=str, help='Path to job output, for IHEP tests')
    parser.add_argument('-g','--good_fit', help='Will plot only good fits', action='store_true')
    args = parser.parse_args()

    if args.job_name is None and args.job_path is None:
        message = 'Neither job name or job path passed'
        log.error(message)
        # A bare raise, without an active exception, would hide the reason of the failure
        raise RuntimeError(message)

    data.job_name = args.job_name
    data.out_path = f'output_{data.job_name}' if args.job_name else args.job_path
    data.good_fit = args.good_fit
#-------------------------------------------------------
def get_fit_pars(json_path, kind=None):
    '''
    Returns the values of the parameters stored in a JSON file

    Parameters
    ----------------
    json_path (str): Path to JSON file
    kind (str): Key of the section the values are read from

    Returns
    ----------------
    List with the first element of each parameter's entry in the section `kind`, ordered as
    the names in the `par` entry of the `pos` section. None, after a warning, if the file cannot be loaded
    '''
    try:
        d_data = utnr.load_json(json_path)
    except Exception:
        # Exception, rather than a bare except, lets KeyboardInterrupt and SystemExit through
        log.warning(f'Cannot load: {json_path}')
        return None

    l_name = d_data['pos']['par']
    d_kind = d_data[kind]

    return [ d_kind[name][0] for name in l_name ]
#-------------------------------------------------------
def get_par_names(json_path):
    '''
    Returns the `par` entry of the `pos` section of a JSON file, i.e. the names
    of the parameters, or None, after a warning, if the file cannot be loaded
    '''
    try:
        d_data = utnr.load_json(json_path)
    except Exception:
        # Exception, rather than a bare except, lets KeyboardInterrupt and SystemExit through
        log.warning(f'Cannot load: {json_path}')
        return None

    return d_data['pos']['par']
#-------------------------------------------------------
def build_covariance():
    '''
    Returns the names of the fit parameters and their covariance matrix

    The result is cached in `<out_path>/covariance.json`. If the file exists it is loaded,
    otherwise the matrix is calculated from the JSON files with the fit results and cached.

    Returns
    ----------------
    Tuple with list of parameter names and matrix as a numpy array

    Raises
    ----------------
    RuntimeError: If none of the JSON files can be loaded
    '''
    data_name = f'{data.out_path}/covariance.json'
    if os.path.isfile(data_name):
        log.info(f'Loading: {data_name}')
        d_dat = utnr.load_json(data_name)
        l_par = d_dat['par']
        mat   = numpy.array(d_dat['cov'])

        return l_par, mat

    log.info(f'Not found {data_name}, remaking it')
    l_json_path = get_json_paths()
    l_all_par   = [ get_fit_pars(json_path, kind='pos') for json_path in l_json_path ]

    # Files that could not be loaded give None, which cannot go into the array below
    l_fit_par   = [ l_val for l_val in l_all_par if l_val is not None ]
    if not l_fit_par:
        message = f'Cannot load any of the JSON files needed for {data_name}'
        log.error(message)
        raise RuntimeError(message)

    # Generated values and names are taken from the first file that could be loaded
    json_path = next(path for path, l_val in zip(l_json_path, l_all_par) if l_val is not None)
    l_gen_par = get_fit_pars(json_path, kind='pre')
    l_par_nam = get_par_names(json_path)

    obj       = sta_cov(numpy.array(l_fit_par).T, numpy.array(l_gen_par))
    mat       = obj.get_cov()
    d_cov     = {'cov' : mat.tolist(), 'par' : l_par_nam}

    utnr.dump_json(d_cov, data_name)

    return l_par_nam, mat
#-------------------------------------------------------
def build_df(kind):
    '''
    Returns the dataframe for the section `kind`, after passing it through filter_df

    The dataframe is read from `<out_path>/<kind>.json`. If the file is
    missing, it is first made from the JSON files with the fit results
    '''
    cache_path = f'{data.out_path}/{kind}.json'

    if not os.path.isfile(cache_path):
        make_json()
        get_df(kind).to_json(cache_path, indent=4)

    # The cache is read back also when it was just written
    return filter_df(pnd.read_json(cache_path))
#-------------------------------------------------------
def find_dset_trig(l_var):
    '''
    Finds the datasets and triggers present in the names of the variables and stores
    them, sorted, in data.l_dset and data.l_trig

    Names with neither _TOS_ nor _TIS_ in them are skipped

    Raises
    ----------------
    RuntimeError: If a name has a trigger but dataset and trigger cannot be extracted
    '''
    s_dset = set()
    s_trig = set()
    for var in l_var:
        if ('_TOS_' not in var) and ('_TIS_' not in var):
            continue

        mtch = re.match(r'.*_(r1|r2p1|2017|2018|all)_(TOS|TIS).*', var)
        if not mtch:
            message = f'Could not extract trigger and dataset from: {var}'
            log.error(message)
            raise RuntimeError(message)

        dset, trig = mtch.groups()
        s_dset.add(dset)
        s_trig.add(trig)

    # Sorted, so that the order does not change from run to run
    data.l_dset = sorted(s_dset)
    data.l_trig = sorted(s_trig)
#-------------------------------------------------------
def main():
    '''
    Builds the dataframes, makes the plots for every variable, the summary of pulls
    and the covariance and correlation matrices
    '''
    # 20 is the INFO level of the logging module
    log.setLevel(20)
    plt.style.use(mplhep.style.LHCb2)

    df_pos = build_df('pos')
    df_pre = build_df('pre')

    l_pos  = df_pos.columns.tolist()
    l_pre  = df_pre.columns.tolist()

    find_dset_trig(l_pos)

    # plot() needs the value and the error of a variable in both dataframes. Any other
    # column in common is not a variable and would make plot() fail with KeyError
    s_col  = set(l_pos).intersection(l_pre)
    suffix = ' value'
    l_name = [ col[:-len(suffix)] for col in s_col if col.endswith(suffix) ]
    l_var  = sorted(name for name in l_name if f'{name} error' in s_col)

    for var in l_var:
        plot(df_pos=df_pos, df_pre=df_pre, var=var)

    plot_pull_summary()

    l_par, cov = build_covariance()
    plot_cov(cov, l_par)
#-------------------------------------------------------
# Entry point of the script: get_args stores the settings in `data`, so it has to run before main
if __name__ == '__main__':
    get_args()
    main()

