#!/usr/bin/env python3

import os
import glob
import pandas             as pnd
import utils_noroot       as utnr
import matplotlib.pyplot  as plt

from logzero import logger as log

#--------------------------------
class data:
    '''
    Static configuration for this script: where the inputs live, which
    year/trigger/sample combinations are processed, the LaTeX legend label
    of each sample and the output directory for tables and plots.
    '''
    # Root of the analysis area. Evaluated when the class body runs (module
    # import), so a missing CASDIR raises KeyError immediately.
    cas_dir = os.environ['CASDIR']
    # Ntuple version; it is one component of the input path.
    ntp_ver = 'v10.21p2'

    # Grid of (year, trigger, sample) processed by main()
    l_year  = [2011, 2012, 2015, 2016, 2017, 2018]
    l_trig  = ['ETOS', 'GTIS']
    l_samp  = ['bpks', 'bdks', 'bpk1', 'bpk2', 'bpkp', 'bsph']

    # LaTeX legend label for every sample name
    d_sam_ltx = {
        'bpkp' : r'Signal',
        'bpks' : r'$B^+\to K^{*+}e^+e^-$',
        'bdks' : r'$B^0\to K^{*0}e^+e^-$',
        'bsph' : r'$B_s\to \phi e^+e^-$',
        'bpk1' : r'$B^+\to K_{1}e^+e^-$',
        'bpk2' : r'$B^+\to K_{2}e^+e^-$',
    }

    # Output directory for .tex tables and .png plots. It is resolved at
    # import time through utnr.make_dir_path, which presumably creates it if
    # missing (TODO confirm).
    out_dir = utnr.make_dir_path('/home/acampove/Packages/RK/rk_extractor/plots')
#--------------------------------
def df_from_path(path):
    '''
    Read a cut-flow table from JSON and keep only its raw columns.

    Parameters
    -----------------
    path : JSON file (or anything accepted by pandas.read_json) whose table
           contains, among others, `Efficiency` and `Cumulative` columns.

    Returns
    -----------------
    Dataframe with the `Efficiency` and `Cumulative` columns removed. Those
    are ratios of the counts and cannot be summed across files.
    '''
    df_raw = pnd.read_json(path)

    return df_raw.drop(columns=['Efficiency', 'Cumulative'])
#--------------------------------
def merge_df(l_df):
    '''
    Sum cut-flow dataframes and recompute the efficiency columns.

    Parameters
    -----------------
    l_df : iterable of dataframes with numeric `Total` and `Pased` columns,
           indexed by cut name. They are summed with DataFrame.add, which
           aligns rows and columns by label.

    Returns
    -----------------
    New dataframe with the summed counts plus:
      - `Efficiency`: Pased / Total for every cut
      - `Cumulative`: running product of `Efficiency`, in row order
    None of the input dataframes is modified.

    Raises
    -----------------
    ValueError : if `l_df` is empty.
    '''
    df_sum = None
    for df in l_df:
        # Copy the first frame. Otherwise, with a single input, the columns
        # added below would be written into the caller's dataframe.
        df_sum = df.copy() if df_sum is None else df.add(df_sum)

    if df_sum is None:
        raise ValueError('Cannot merge an empty list of dataframes')

    df_sum['Efficiency'] = df_sum.Pased / df_sum.Total
    df_sum['Cumulative'] = df_sum.Efficiency.cumprod()

    return df_sum
#--------------------------------
def get_df(sample, year, trig):
    '''
    Load and merge every efficiency JSON file for one sample/year/trigger.

    Parameters
    -----------------
    sample : sample name, used as a directory in the input path
    year   : data-taking year
    trig   : trigger name

    Returns
    -----------------
    Dataframe produced by merge_df over all files matching
    {data.cas_dir}/tools/apply_selection/rare_backgrounds/{sample}/{data.ntp_ver}/{year}_{trig}/*eff.json

    Raises
    -----------------
    FileNotFoundError : if no file matches that pattern.
    '''
    path_wc = f'{data.cas_dir}/tools/apply_selection/rare_backgrounds/{sample}/{data.ntp_ver}/{year}_{trig}/*eff.json'
    # glob returns files in filesystem-dependent order; sort for a
    # reproducible merge order.
    l_path  = sorted(glob.glob(path_wc))
    if not l_path:
        log.error(f'Cannot find any file in: {path_wc}')
        # A bare `raise` here has no active exception and would only give
        # "RuntimeError: No active exception to reraise".
        raise FileNotFoundError(f'No file matches: {path_wc}')

    l_df = [ df_from_path(path) for path in l_path ]

    return merge_df(l_df)
#--------------------------------
def make_table(df, sample, year, trig):
    '''
    Write the cut-flow table of one sample/year/trigger as a LaTeX file.

    Parameters
    -----------------
    df     : dataframe indexed by cut name, with `Total`, `Pased`,
             `Efficiency` and `Cumulative` columns
    sample : sample name, used in the output file name
    year   : data-taking year, used in the output file name
    trig   : trigger name, used in the output file name

    The table is written to {data.out_dir}/{sample}_{year}_{trig}.tex. The
    input dataframe is not modified: the `Cut` column is added to a copy.
    '''
    d_format = {'Cut'        : '{}',
                'Total'      : '{}',
                'Pased'      : '{}',
                'Efficiency' : '{:.3e}',
                'Cumulative' : '{:.3e}'}

    # Work on a copy. The caller keeps using `df` after this call (main()
    # also passes it to plot_df), so adding a column to it would be a hidden
    # side effect.
    df = df.copy()
    df['Cut'] = df.index
    utnr.df_to_tex(df, f'{data.out_dir}/{sample}_{year}_{trig}.tex', d_format=d_format)
#--------------------------------
def plot_df(d_df, year, trig):
    '''
    Plot the cumulative efficiency of every sample for one year/trigger.

    Parameters
    -----------------
    d_df : dict mapping a sample name (a key of data.d_sam_ltx) to a
           dataframe indexed by cut name with a `Cumulative` column
    year : data-taking year, used in the output file name
    trig : trigger name, used in the output file name

    The x axis is labelled with the cut names of the last dataframe in
    `d_df`, so all samples are assumed to share the same cuts in the same
    order. The plot is saved to {data.out_dir}/{year}_{trig}.png, on a
    logarithmic y axis. The input dataframes are not modified.

    Raises
    -----------------
    ValueError : if `d_df` is empty.
    '''
    if not d_df:
        raise ValueError(f'No dataframes to plot for {year}_{trig}')

    fig, ax = plt.subplots(figsize=(10,6))
    for samp, df in d_df.items():
        ltx = data.d_sam_ltx[samp]
        df.Cumulative.plot(label=ltx, ax=ax)

    # Take the tick labels from the last sample explicitly, instead of relying
    # on the leaked loop variable, and do not add a column to its dataframe.
    # The cuts are drawn at positions 0..n-1.
    l_cut = list(list(d_df.values())[-1].index)
    ax.set_xticks(range(len(l_cut)))
    ax.set_xticklabels(l_cut, rotation=60)

    ax.legend()
    ax.grid()
    # Switch to log scale before tight_layout, so the layout accounts for the
    # final tick labels.
    ax.set_yscale('log')
    fig.tight_layout()

    plot_path = f'{data.out_dir}/{year}_{trig}.png'
    fig.savefig(plot_path)
    # Close only this figure, not unrelated ones.
    plt.close(fig)
#--------------------------------
def main():
    '''
    For every (year, trigger) pair in `data`, load each sample's cut-flow,
    write its LaTeX table and then draw one plot combining all samples.
    '''
    # Year-major order, the same order as two nested loops.
    l_year_trig = [ (year, trig) for year in data.l_year for trig in data.l_trig ]

    for year, trig in l_year_trig:
        d_df = {}
        for samp in data.l_samp:
            df_samp = get_df(samp, year, trig)
            make_table(df_samp, samp, year, trig)
            d_df[samp] = df_samp

        plot_df(d_df, year, trig)
#--------------------------------
if __name__ == '__main__':
    # Produce tables and plots only when run as a script, not when imported.
    main()

