#!/usr/bin/env python3

import os
import re
import glob
import numpy
import pprint
import tarfile
import argparse

import pandas            as pnd
import utils_noroot      as utnr
import matplotlib.pyplot as plt

from logzero     import logger as log
from scipy.stats import sem    as spy_sem
#-----------------------------------
class data:
    '''
    Namespace with the configuration and the accumulated results shared by the functions in this script.
    It is used through its class attributes, it is not meant to be instantiated.
    '''
    # Key read from the `pos` section of each result JSON file
    var        = 'rk'

    # Input and output directories, assigned in get_args from the command line
    res_dir    = None
    out_dir    = None

    # Number of JSON files per job, fixed by the first job checked in check_job
    nrow       = None

    # Variable -> plot label, assigned in main
    d_var_name = None

    # Variable -> list of values (or errors), filled by update_data
    d_var_val  = dict()
    d_var_err  = dict()
#-------------------------------------------------------
def get_var_naming():
    '''
    Builds the mapping between variable identifiers and the labels shown in plots.

    Returns
    ------
    Dictionary {identifier : label}, where identifiers are the names that get_var_name
    extracts from result file names and labels are matplotlib math-text strings.
    '''
    # Every label is a raw string: sequences like \s, \m, \l and \D are not valid
    # escapes and trigger a SyntaxWarning in recent Python versions when not raw.
    d_name = {
            'ssg_mm'          : r'$r_{\sigma}^{\mu\mu}$',
            'ssg_ee'          : r'$r_{\sigma}^{ee}$',
            'ssg'             : r'$r_{\sigma}$',

            'nsg_mm'          : r'$N_{signal}^{\mu\mu}$',
            'nsg_ee'          : r'$N_{signal}^{ee}$',
            'nsg'             : r'$N_{signal}$',

            'npr_ee'          : r'$N_{PRec}^{ee}$',

            'r0_ee'           : r'$r_0^{Brem}$',
            'r1_ee'           : r'$r_1^{Brem}$',
            'r2_ee'           : r'$r_2^{Brem}$',

            'ncb_mm'          : r'$N_{comb}^{\mu\mu}$',
            'ncb_ee'          : r'$N_{comb}^{ee}$',
            'ncb'             : r'$N_{comb}$',

            'mu_cb_mm'        : r'$\mu_{comb}^{\mu\mu}$',
            'mu_cb_ee'        : r'$\mu_{comb}^{ee}$',
            'mu_cb'           : r'$\mu_{comb}$',

            'lm_cb_mm'        : r'$\lambda_{comb}^{\mu\mu}$',
            'lm_cb_ee'        : r'$\lambda_{comb}^{ee}$',
            'lm_cb'           : r'$\lambda_{comb}$',

            'dmu_mm'          : r'$\Delta(\mu)^{\mu\mu}$',
            'dmu_ee'          : r'$\Delta(\mu)^{ee}$',
            'dmu'             : r'$\Delta(\mu)$',

            'ck'              : r'$c_{K}$',
            # Label of the nominal configuration, the plotting functions look for it by this name
            'None'            : 'None',

            'cmb_ee:use_etos' : r'$\text{PDF}_{eTOS}[+\text{comb}]$',
            }

    return d_name
#-----------------------------------
def get_args():
    '''
    Parses the command line, stores the directories in `data` and creates the output directory.

    Raises
    ------
    OSError: If the output directory cannot be created, after logging the path.
    '''
    parser = argparse.ArgumentParser(description='This script makes plots and tables to assess the sources of systematics on RK')
    # Required: without it the search for job outputs would glob the literal path `None/*`
    parser.add_argument('-d','--res_dir', type=str, help='Directory with the output of systematics jobs', required=True)
    parser.add_argument('-o','--out_dir', type=str, help='Directory with plots', default='plots')
    args = parser.parse_args()

    data.res_dir = args.res_dir
    data.out_dir = args.out_dir

    try:
        os.makedirs(data.out_dir, exist_ok=True)
    except OSError:
        # Only filesystem failures are expected here; anything else propagates untouched
        log.error(f'Cannot create: {data.out_dir}')
        raise
#-----------------------------------
def check_job(l_jsn):
    '''
    Checks that a job provides as many JSON files as the first job that was checked.

    The first call stores the number of files in data.nrow, later calls compare against it.

    Parameters
    ------
    l_jsn: List of paths to the JSON files of one job

    Raises
    ------
    RuntimeError: If the number of files differs from the one stored in data.nrow
    '''
    njsn = len(l_jsn)

    if data.nrow is None:
        data.nrow = njsn
        return

    if njsn != data.nrow:
        # Build the message only from names available here, the glob wildcard is local to the caller
        msg = f'Found incompatible number of JSON files, expected {data.nrow}, got {njsn}'
        log.error(msg)
        raise RuntimeError(msg)
#-----------------------------------
def read_data(json_path):
    '''
    Loads a result JSON file and returns the (value, error) pair stored
    under the `pos` entry for the quantity named by data.var.
    '''
    d_pos    = utnr.load_json(json_path)['pos']
    val, err = d_pos[data.var]

    return val, err
#-----------------------------------
def get_var_name(jsn_path):
    '''
    Extracts the variable identifier from the path to a result JSON file.

    Parameters
    ------
    jsn_path: Path to a file, the identifier is the file name without `result_` and `.json`

    Returns
    ------
    The identifier, or the string 'None' when it starts with four digits.
    '''
    file_name = os.path.basename(jsn_path)
    var_name  = file_name.replace('result_', '').replace('.json', '')

    # Raw string, \d is not a valid escape sequence in a normal string literal
    if re.match(r'\d{4}', var_name):
        return 'None'

    return var_name
#-----------------------------------
def check_size(cnt, msg):
    '''
    Checks that a container is not empty.

    Parameters
    ------
    cnt: Any object supporting len()
    msg: Message logged and attached to the exception when the container is empty

    Raises
    ------
    RuntimeError: If the container is empty
    '''
    if len(cnt) == 0:
        log.error(msg)
        # A bare `raise` here has no active exception to re-raise, so the message would be lost
        raise RuntimeError(msg)
#-----------------------------------
def add_val(var, qty, d_dat):
    '''
    Appends qty to the list stored in d_dat under the key var,
    the list is created if the key is not there yet.
    '''
    d_dat.setdefault(var, []).append(qty)
#-----------------------------------
def update_data(dir_path):
    '''
    Reads every JSON file in dir_path and stores the results in data.d_var_val and data.d_var_err.

    Variables with a colon in the name contribute their value, variables without it
    contribute their error and the variable named 'None' contributes both.
    '''
    jsn_wc = f'{dir_path}/*.json'
    l_jsn  = glob.glob(jsn_wc)
    check_size(l_jsn, f'Empty list of JSON files in: {jsn_wc}')

    check_job(l_jsn)

    d_var = {}
    for jsn_path in l_jsn:
        d_var[get_var_name(jsn_path)] = jsn_path

    check_size(d_var, 'Empty variable -> JSON path dictionary')

    for var, json_path in d_var.items():
        val, err   = read_data(json_path)

        is_nominal = var == 'None'
        has_colon  = ':' in var

        if has_colon or is_nominal:
            add_val(var, val, data.d_var_val)

        if not has_colon or is_nominal:
            add_val(var, err, data.d_var_err)

    check_size(data.d_var_err, 'Empty variable -> error dictionary')
    check_size(data.d_var_val, 'Empty variable -> value dictionary')
#-----------------------------------
def plot_error():
    '''
    Makes one histogram per variable in data.d_var_err and saves it as rk_err_<variable>.png
    in data.out_dir. The mean and the mean plus/minus its standard error are drawn as vertical lines.

    Returns
    ------
    DataFrame with columns Variable (plot label), Value (mean) and Error (standard error
    of the mean), sorted by increasing Value. Rows are added through utnr.add_row_to_df.

    Raises
    ------
    KeyError: If a variable has no label in data.d_var_name
    '''
    df = pnd.DataFrame(columns=['Variable', 'Value', 'Error'])

    for var, l_error in data.d_var_err.items():
        name = data.d_var_name[var]
        mu   = numpy.mean(l_error)
        sg   = spy_sem(l_error)

        plt.hist(l_error, range=(0, 0.2), bins=30, label=var)
        plt.axvline(x=mu   , color='red', linestyle='--')
        plt.axvline(x=mu-sg, color='red', linestyle=':')
        plt.axvline(x=mu+sg, color='red', linestyle=':')

        # Raw f-string, \m is not a valid escape sequence; the text of the label does not change
        plt.legend(['Error', rf'$\mu={mu:.3f}$'])
        plt.title(name)
        plt.xlabel(r'$\varepsilon(R_{K})$')
        plt.savefig(f'{data.out_dir}/rk_err_{var}.png')
        plt.close('all')

        df = utnr.add_row_to_df(df, [name, mu, sg])

    df = df.sort_values(by=['Value'], ascending=True)
    df = df.reset_index(drop=True)

    return df
#-----------------------------------
def plot_value():
    '''
    Makes one histogram per variable in data.d_var_val and saves it as rk_val_<variable>.png
    in data.out_dir. The mean and the mean plus/minus its standard error are drawn as vertical lines.

    Returns
    ------
    DataFrame with columns Variable (plot label), Value (mean) and Error (standard error
    of the mean), sorted by increasing Value. Rows are added through utnr.add_row_to_df.

    Raises
    ------
    KeyError: If a variable has no label in data.d_var_name
    '''
    df = pnd.DataFrame(columns=['Variable', 'Value', 'Error'])

    for var, l_value in data.d_var_val.items():
        name = data.d_var_name[var]
        mu   = numpy.mean(l_value)
        sg   = spy_sem(l_value)

        plt.hist(l_value, range=(0, 2), bins=30, label=var)
        plt.axvline(x=mu   , color='red', linestyle='--')
        plt.axvline(x=mu-sg, color='red', linestyle=':')
        plt.axvline(x=mu+sg, color='red', linestyle=':')

        # Raw f-string, \m is not a valid escape sequence; the text of the label does not change
        plt.legend(['$R_K$', rf'$\mu={mu:.3f}$'])
        plt.title(name)
        plt.xlabel(r'$R_{K}$')
        plt.savefig(f'{data.out_dir}/rk_val_{var}.png')
        plt.close('all')

        df = utnr.add_row_to_df(df, [name, mu, sg])

    df = df.sort_values(by=['Value'], ascending=True)
    df = df.reset_index(drop=True)

    return df
#-----------------------------------
def is_good_tarfile(tar_path):
    '''
    Tells whether the results packed in a tarball can be used, unpacking them if needed.

    Parameters
    ------
    tar_path: Path to the tarball, its contents are extracted in the directory holding it

    Returns
    ------
    False if the tarball does not exist or cannot be extracted. True if a `result_jsn`
    directory is already next to the tarball or the extraction succeeded.
    '''
    if not os.path.isfile(tar_path):
        return False

    # A bare file name has an empty dirname, which would turn the check below into `/result_jsn`
    dir_path = os.path.dirname(tar_path) or '.'
    if os.path.isdir(os.path.join(dir_path, 'result_jsn')):
        return True

    # Extraction filters are not available in every Python version, use the safe one when present
    kwargs = {'filter' : 'data'} if hasattr(tarfile, 'data_filter') else {}

    try:
        with tarfile.open(tar_path) as itar:
            itar.extractall(path=dir_path, **kwargs)
    except (tarfile.TarError, OSError) as exc:
        # A corrupt or unreadable tarball is reported as bad instead of stopping the whole script
        log.warning(f'Cannot extract {tar_path}: {exc}')
        return False

    return True
#-----------------------------------
def get_df():
    '''
    Finds the job outputs in data.res_dir, reads them and makes the per-variable plots.

    Job directories are the entries of data.res_dir whose name starts with eight digits
    and which hold a usable `result_jsn.tar.gz`, according to is_good_tarfile.

    Returns
    ------
    Tuple with the dataframes returned by plot_error and plot_value, in that order.

    Raises
    ------
    RuntimeError: If no usable job output is found
    '''
    # Sorted, so that jobs are processed in the same order in every run
    l_obj = sorted(glob.glob(f'{data.res_dir}/*'))
    # Raw string, \d is not a valid escape sequence in a normal string literal
    l_job = [ obj for obj in l_obj if re.match(r'\d{8}', os.path.basename(obj)) ]
    # The directory is built from the job path instead of editing the tarball path,
    # which would go wrong if `.tar.gz` showed up anywhere else in the path
    l_dir = [ f'{job}/result_jsn' for job in l_job if is_good_tarfile(f'{job}/result_jsn.tar.gz') ]

    njob = len(l_dir)
    if njob == 0:
        msg = f'Found {njob} jobs in {data.res_dir}'
        log.error(msg)
        raise RuntimeError(msg)

    log.info(f'Found {njob} job outputs')
    log.debug(f'Found {njob} jobs')

    for dir_path in l_dir:
        update_data(dir_path)

    df_err = plot_error()
    df_val = plot_value()

    return df_err, df_val
#-----------------------------------
def plot_err_df(df):
    '''
    Draws the entries of df as horizontal bars and saves the figure as
    parameter_systematics.png in data.out_dir.

    df: DataFrame with columns Variable, Value and Error. The Value of the first row with
    Variable equal to 'None' is drawn as a vertical line, that row is not drawn as a bar.
    '''
    is_nominal = df.Variable == 'None'
    nom_rk     = df[is_nominal].Value.iloc[0]
    df         = df[~is_nominal]

    df.plot(x='Variable', y='Value', xerr='Error', kind='barh', color='none')
    plt.gca().set_xlim(0, 0.1)
    plt.ylabel('')
    plt.axvline(x=nom_rk, color='red', linestyle=':')
    plt.xlabel(r'$\varepsilon(R_K)$')
    plt.grid()

    # Marker and number for each bar, the row position is used as the vertical coordinate
    for i_row, value in enumerate(df.Value):
        plt.plot(value, i_row, 'o', color='blue')
        plt.text(value - 0.02, i_row, f'{value:.3f}', color="k")

    plt.title('Uncertainty after fixing a single parameter')
    plt.legend(['All floating', 'Fix one'], loc='upper left')

    plot_path = f'{data.out_dir}/parameter_systematics.png'
    plt.savefig(plot_path)
    plt.close('all')
#-----------------------------------
def plot_val_df(df):
    '''
    Draws, as horizontal bars, the difference between each entry of df and the entry with
    Variable equal to 'None'. The figure is saved as model_systematics.png in data.out_dir.

    Parameters
    ------
    df: DataFrame with columns Variable, Value and Error, it is not modified

    Raises
    ------
    IndexError: If df has no row with Variable equal to 'None'
    '''
    df_nom     = df[df.Variable == 'None']
    nom_val_rk = df_nom.Value.iloc[0]
    nom_err_rk = df_nom.Error.iloc[0]

    # assign works on a copy, so the caller's dataframe does not get the extra columns
    df = df.assign(
            del_val =  df.Value - nom_val_rk,
            del_err = (df.Error - nom_err_rk).abs())

    df = df[df.Variable != 'None']

    df.plot(x='Variable', y='del_val', xerr='del_err', kind='barh', color='none')

    plt.gca().set_xlim(-0.03, +0.03)
    plt.ylabel('')
    plt.xlabel(r'$\Delta R_K$')
    plt.grid()

    # Marker and number for each bar, the row position is used as the vertical coordinate
    for i_val, val in enumerate(df.del_val):
        plt.plot(val, i_val, 'o', color='blue')
        plt.text(val - 0.02, i_val, f'{val:.3f}', color="k")

    plt.title('Bias from fits to nominal toys with alternative models')
    plt.legend([])
    plt.tight_layout()
    plt.savefig(f'{data.out_dir}/model_systematics.png')
    plt.close('all')
#-----------------------------------
def main():
    '''
    Runs the full chain: sets the plot labels, reads the job outputs
    and makes the summary plots for the errors and for the values.
    '''
    data.d_var_name    = get_var_naming()

    df_error, df_value = get_df()

    plot_err_df(df_error)
    plot_val_df(df_value)
#-----------------------------------
# Runs only when the file is executed as a script. The command line is parsed first,
# given that main needs the directories that get_args stores in data.
if __name__ == '__main__':
    get_args()
    main()

