#!/usr/bin/env python3

import utils_noroot      as utnr
import matplotlib.pyplot as plt
import pandas            as pnd
import pprint
import argparse
import math
import re
import os 

from np_reader import np_reader as np_rdr
from logzero   import logger    as log

#-----------------------------------
class data:
    """
    Module-wide configuration and run state shared by the functions of this script.
    """
    # Fit version, filled from the command line by get_args
    version = None
    # Base directory; results of a given version live under <fit_dir>/<version>
    fit_dir = '/home/angelc/Packages/RK/rk_extractor/tests/real'
    # Output directory <fit_dir>/<version>/result, assigned in main. Declared here
    # so that every attribute used by the script is visible in one place.
    out_dir = None

    # Value and error of the 'rk' fit parameter, filled by get_df
    rk_val = None
    rk_err = None

    # Versions forwarded to np_reader as its yld, sta and sys arguments
    yld_ver = 'v24'
    eff_sta = 'v63'
    eff_sys = 'v65'
#-----------------------------------
def get_args():
    """
    Parse the command line and store the requested fit version in ``data.version``.

    The ``-v/--version`` option is mandatory; argparse exits with an error if it is missing.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-v', '--version', type=str, required=True, help='Version of fit')

    data.version = parser.parse_args().version
#-----------------------------------
def get_par_info(name):
    """
    Split a fit-parameter name into its components.

    Two naming schemes are recognised:

    - ``<name>_<ee|mm>_<dataset>_<TIS|TOS>_<suffix>``
    - ``ck_<dataset>_<TIS|TOS>``, which is not tied to a single channel

    Parameters
    ----------
    name : str
        Parameter name as stored in the fit result.

    Returns
    -------
    list
        ``[name, channel, dataset, trigger]``. For ``ck_`` parameters the name
        is ``'ck'`` and the channel is ``'both'``.

    Raises
    ------
    ValueError
        If ``name`` does not follow the expected scheme.
    """
    is_ck = name.startswith('ck_')
    rgx   = r'ck_([a-z0-9]+)_(TIS|TOS)' if is_ck else r'([a-z]+)_(ee|mm)_([a-z0-9]+)_(TIS|TOS)_[a-z]+'

    mtch = re.match(rgx, name)
    if not mtch:
        # A bare `raise` outside an except block only yields
        # "RuntimeError: No active exception to reraise"; report the real problem.
        raise ValueError(f'Cannot match {name} with {rgx}')

    if is_ck:
        dset, trig = mtch.groups()
        return ['ck', 'both', dset, trig]

    par, chan, dset, trig = mtch.groups()
    return [par, chan, dset, trig]
#-----------------------------------
def get_df():
    """
    Load the pickled fit result of ``data.version`` and tabulate its parameters.

    The ``rk`` parameter is not tabulated. Its value and Hesse error are stored
    in ``data.rk_val`` / ``data.rk_err`` instead. Parameters for which
    ``get_par_info`` reports channel ``'both'`` are added twice: once with
    channel ``ee`` and once with ``mm``.

    Returns
    -------
    pandas.DataFrame
        Columns: Name, Channel, Dataset, Trigger, Value, Error.
    """
    res = utnr.load_pickle(f'{data.fit_dir}/{data.version}/result.pkl')

    l_col = ['Name', 'Channel', 'Dataset', 'Trigger', 'Value', 'Error']
    df    = pnd.DataFrame(columns=l_col)

    for par_name, d_par in res.params.items():
        par_val = d_par['value']
        par_err = d_par['hesse']['error']

        if par_name == 'rk':
            data.rk_val, data.rk_err = par_val, par_err
            continue

        row    = get_par_info(par_name) + [par_val, par_err]
        l_chan = ['ee', 'mm'] if row[1] == 'both' else [row[1]]
        for chan in l_chan:
            row[1] = chan
            utnr.add_row_to_df(df, row)

    return df
#-----------------------------------
def save_table(name, df, formatter):
    """
    Write ``df`` as a LaTeX table to ``<data.out_dir>/tables/<name>``.

    Parameters
    ----------
    name : str
        File name of the table, e.g. ``'yld_tos.tex'``.
    df : pandas.DataFrame
        Table content; forwarded with ``hide_index=True`` (presumably the index
        is not written -- see utnr.df_to_tex).
    formatter : dict
        Column name -> format string, forwarded as ``d_format``.
    """
    table_dir = f'{data.out_dir}/tables'
    # main only creates data.out_dir; create the tables sub-directory here,
    # the same way plot_pre_pos_trig creates its plots directory.
    os.makedirs(table_dir, exist_ok=True)

    table_path = f'{table_dir}/{name}'
    utnr.df_to_tex(df, table_path, hide_index=True, d_format=formatter, caption=None)
#-----------------------------------
def make_yield_tables(df):
    """
    Write one LaTeX table of yields per trigger (TOS, then TIS).

    Yield parameters are the rows whose Name starts with ``'n'``. Each table is
    sorted by Channel, drops the Trigger column and shows Value/Error as integers.

    Parameters
    ----------
    df : pandas.DataFrame
        Table returned by ``get_df``.
    """
    df_yld    = df[df.Name.str.startswith('n')]
    formatter = {'Value' : '{:.0f}', 'Error' : '{:.0f}'}

    for trig, table_name in [('TOS', 'yld_tos.tex'), ('TIS', 'yld_tis.tex')]:
        df_trig = df_yld[df_yld.Trigger == trig]
        df_trig = df_trig.sort_values('Channel').drop('Trigger', axis=1)
        save_table(table_name, df_trig, formatter)
#-----------------------------------
def get_ck_prefit():
    """
    Compute prefit values and uncertainties of the ck parameters from np_reader inputs.

    For every key returned by ``get_rjpsi``:

        ck = (eff_ee / eff_mm) / rjpsi

    The uncertainty of each entry is the square root of the corresponding
    diagonal element of ``cov_sys + cov_sta``.

    Returns
    -------
    tuple of dict
        ``(d_ck_val, d_ck_err)``. The error dict is keyed by the labels in
        ``l_dset`` (e.g. ``'r1_TOS'``). The value dict uses the keys of ``get_rjpsi``.
    """
    rdr       = np_rdr(sys=data.eff_sys, sta=data.eff_sta, yld=data.yld_ver)
    rdr.cache = True
    d_rjpsi   = rdr.get_rjpsi()
    d_eff     = rdr.get_eff()
    cov_sys   = rdr.get_cov(kind='sys')
    cov_sta   = rdr.get_cov(kind='sta')
    # Both covariance components enter the prefit error.
    # NOTE(review): assumes get_cov returns numpy arrays, so `+` is element-wise.
    cov_tot   = cov_sys + cov_sta

    # NOTE(review): this order must match the rows/columns of the covariance
    # matrices returned by np_reader -- confirm against np_reader.
    l_dset   = ['r1_TOS', 'r1_TIS', 'r2p1_TOS', 'r2p1_TIS', '2017_TOS', '2017_TIS', '2018_TOS', '2018_TIS']
    d_ck_err = {dset : math.sqrt(cov_tot[i][i]) for i, dset in enumerate(l_dset)}

    d_ck_val = {}
    for dset, rjpsi in d_rjpsi.items():
        eff_mm, eff_ee = d_eff[dset]
        d_ck_val[dset] = (eff_ee / eff_mm) / rjpsi

    return d_ck_val, d_ck_err
#-----------------------------------
def add_ck_prefit(df):
    """
    Attach prefit ck values/errors to the postfit ck rows and split them by trigger.

    Parameters
    ----------
    df : pandas.DataFrame
        Table returned by ``get_df``. It is not modified.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_tos, df_tis)``, indexed by ``'<Dataset>_<Trigger>'`` labels, with
        columns Name, Trigger, Value, Error, Value prefit, Error prefit.
    """
    # .copy(): adding columns to a boolean-indexed slice triggers
    # SettingWithCopyWarning; work on an independent frame instead.
    df          = df[df.Name.str.startswith('ck')].copy()
    df['label'] = df.Dataset + '_' + df.Trigger
    df          = df.set_index('label')

    d_ck_val, d_ck_err = get_ck_prefit()

    # Explicit Series so the prefit numbers are aligned with the index by label.
    df['Value prefit'] = pnd.Series(d_ck_val)
    df['Error prefit'] = pnd.Series(d_ck_err)
    df = df.drop(columns=['Channel', 'Dataset'])

    # get_df stores each ck parameter once per channel (ee and mm); keep one row
    # per label. Deduplicating on the index, not on the row content, cannot merge
    # two different datasets that happen to have identical numbers.
    df = df[~df.index.duplicated(keep='first')]

    df_tos = df[df.Trigger == 'TOS']
    df_tis = df[df.Trigger == 'TIS']

    return df_tos, df_tis
#-----------------------------------
def analyze(df):
    """
    Produce all outputs from the fit-result table.

    The yield tables are written first, then the ck prefit/postfit plots.
    """
    for step in (make_yield_tables, plot_pre_pos):
        step(df)
#-----------------------------------
def plot_pre_pos(df):
    """
    Plot prefit versus postfit ck values, one figure per trigger (TOS, then TIS).
    """
    d_trig_df = dict(zip(('TOS', 'TIS'), add_ck_prefit(df)))
    for trig, df_trig in d_trig_df.items():
        plot_pre_pos_trig(df_trig, trig)
#-----------------------------------
def plot_pre_pos_trig(df, trig):
    """
    Overlay postfit and prefit ck values, with error bars, for a single trigger.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``add_ck_prefit`` for one trigger. Must have the columns
        Value, Error, Value prefit and Error prefit.
    trig : str
        Trigger name, used in the output file name.

    The figure is saved as
    ``<fit_dir>/<version>/result/plots/ck_pre_pos_<trig>.png``, then all figures are closed.
    """
    l_series = [
            ('Value'       , 'Error'       , 'Postfit'),
            ('Value prefit', 'Error prefit', 'Prefit' ),
            ]

    ax = None
    for col_val, col_err, label in l_series:
        # Reusing `ax` draws both series on the same axes.
        ax = df.plot(y=col_val, yerr=col_err, capsize=4, label=label, marker='o', linestyle='none', ax=ax, rot=20)

    plt_dir = f'{data.fit_dir}/{data.version}/result/plots'
    os.makedirs(plt_dir, exist_ok=True)

    plt.xlabel('')
    plt.ylabel('$c_{K}^{rt}$')
    plt.grid()
    plt.savefig(f'{plt_dir}/ck_pre_pos_{trig}.png')
    plt.close('all')
#-----------------------------------
def main():
    """
    Entry point of the script.

    Steps:
    - parse the arguments;
    - create ``<fit_dir>/<version>/result``;
    - build the fit-result table and write the tables and plots;
    - log the fitted RK value.
    """
    get_args()

    data.out_dir = f'{data.fit_dir}/{data.version}/result'
    os.makedirs(data.out_dir, exist_ok=True)

    df = get_df()
    analyze(df)

    # data.rk_val keeps its initial None when the fit result has no 'rk'
    # parameter. Formatting None with ':.3f' would raise TypeError.
    if data.rk_val is None:
        log.warning('No rk parameter found in fit result')
        return

    log.info(f'RK={data.rk_val:.3f}+/-{data.rk_err:.3f}')
#-----------------------------------
if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()

