#!/usr/bin/env python3

import argparse
import pandas as pnd
import ROOT
import os

from logzero    import logger as log
from rk.wgt_mgr import wgt_mgr

from importlib.resources import files

#-----------------------------------
class data:
    '''
    Module-wide settings shared by the functions of this script.

    l_year and l_trig start out holding every allowed value and are replaced by
    get_args with the user selection; vers and nevt stay None until get_args runs.
    '''
    # Allowed years, also the default selection
    l_year = '2011 2012 2015 2016 2017 2018'.split()
    # Allowed triggers, also the default selection; used as the tree name when reading
    l_trig = 'MTOS ETOS GTIS'.split()

    vers = None # Version string of the output directory
    nevt = None # Maximum number of entries to process; values <= 0 process everything
#-----------------------------------
def get_args():
    '''
    Parse the command line and store the selected options on the `data` class.

    Sets data.l_year, data.l_trig, data.vers and data.nevt; returns nothing.
    '''
    parser = argparse.ArgumentParser(description='Will produce JSON files corresponding to DataFrames with weights and BDT scores for signal after full selection, but BDT')
    parser.add_argument('-y', '--year', nargs='+', help='Years'   , choices=data.l_year, default=data.l_year)
    parser.add_argument('-t', '--trig', nargs='+', help='Triggers', choices=data.l_trig, default=data.l_trig)
    parser.add_argument('-v', '--vers', type=str , help='Version of output', required=True)
    parser.add_argument('-n', '--nevt', type=int , help='Maximum number of entries to process', default=-1)
    parsed = parser.parse_args()

    # Map each parsed option onto the attribute of `data` it configures
    d_attr = {
            'l_year' : parsed.year,
            'l_trig' : parsed.trig,
            'vers'   : parsed.vers,
            'nevt'   : parsed.nevt,
            }

    for attr_name, value in d_attr.items():
        setattr(data, attr_name, value)
#-----------------------------------
def get_rdf(trig, year):
    '''
    Build the ROOT dataframe holding the selected signal candidates for one trigger and year.

    Parameters
    ----------
    trig: Trigger name; also used as the name of the tree read from the files
    year: Year string, used to build the input directory

    Returns
    -------
    ROOT dataframe, restricted to the first data.nevt entries when data.nevt > 0,
    with the attributes year, trigger, treename and filepath attached

    Raises
    ------
    KeyError    : If the CASDIR environment variable is not set
    RuntimeError: If the dataframe contains no entries
    '''
    log.info(f'Getting ROOT dataframe for: {trig}/{year}')
    cas_dir = os.environ['CASDIR']
    file_wc = f'{cas_dir}/tools/apply_selection/rare_backgrounds/sign/v10.21p2/{year}_{trig}/*.root'
    log.debug(f'Picking up files from: {file_wc}:{trig}')
    rdf = ROOT.RDataFrame(trig, file_wc)
    if data.nevt > 0:
        log.warning(f'Processing only {data.nevt} entries')
        rdf = rdf.Range(data.nevt)

    # Metadata attached to the dataframe object; presumably consumed downstream
    # (e.g. by the wgt_mgr reader in get_weights) - TODO confirm
    rdf.year    = year
    rdf.trigger = trig
    rdf.treename= trig
    rdf.filepath= file_wc

    nentries = rdf.Count().GetValue()
    if nentries == 0:
        log.error(f'Found {nentries} entries')
        # A bare `raise` with no active exception only produces
        # "RuntimeError: No active exception to reraise"; raise a descriptive error instead.
        raise RuntimeError(f'No entries found in {file_wc}:{trig}')

    log.debug(f'Found {nentries} entries')

    return rdf
#-----------------------------------
def get_weights(rdf, trig, year):
    '''
    Compute the nominal weights for the candidates in a dataframe.

    Parameters
    ----------
    rdf : ROOT dataframe, as built by get_rdf (with year/trigger/treename/filepath attached)
    trig: Trigger name, one of MTOS, ETOS, GTIS
    year: Year string, used to locate the validation directory

    Returns
    -------
    The 'nom' entry of the weights dictionary returned by the wgt_mgr 'sel' reader
    '''
    # MTOS is the muon trigger; ETOS and GTIS select electron candidates. The
    # previous mapping was inverted, which contradicted the q2 systematic below:
    # that systematic is only requested for the non-MTOS (electron) triggers.
    is_muon = trig == 'MTOS'

    d_set = {
            'val_dir' : f'./signal_table/{year}_{trig}',
            'channel' : 'muon' if is_muon else 'electron',
            'replica' : 0,
            'bts_sys' : 'nom',
            'bts_ver' : 200,
            'pid_sys' : 'nom',
            'trk_sys' : 'nom',
            'gen_sys' : 'nom',
            'lzr_sys' : 'nom',
            'hlt_sys' : 'nom',
            'rec_sys' : 'nom',
            'bdt_sys' : 'nom',
            'dcm_sys' : '000',
            }

    if not is_muon:
        d_set['qsq_sys'] = 'nom'

    obj   = wgt_mgr(d_set)
    rsl   = obj.get_reader('sel', rdf)
    d_wgt = rsl.get_weights()

    return d_wgt['nom']
#-----------------------------------
def get_df(trig, year):
    '''
    Build a pandas dataframe with weights and BDT scores for one trigger and year.

    Returns
    -------
    pandas DataFrame with columns, in order: wgt, BDT_cmb, BDT_prc
    '''
    log.info(f'Getting Pandas dataframe for: {trig}/{year}')
    rdf    = get_rdf(trig, year)

    l_bdt  = ['BDT_cmb', 'BDT_prc']
    # BDT branches are read before the weights are computed, as in the original ordering
    d_bdt  = rdf.AsNumpy(l_bdt)

    d_cols = {'wgt' : get_weights(rdf, trig, year)}
    d_cols.update({name : d_bdt[name] for name in l_bdt})

    return pnd.DataFrame(d_cols)
#-----------------------------------
def save_df(df, trig, year):
    '''
    Write a dataframe as indented JSON to sig_wgt/<vers>/<trig>_<year>.json
    inside the extractor_data package, creating the directory when missing.
    '''
    log.info(f'Saving Pandas dataframe for: {trig}/{year}')

    out_dir   = files('extractor_data').joinpath(f'sig_wgt/{data.vers}')
    os.makedirs(out_dir, exist_ok=True)

    file_name = f'{trig}_{year}.json'
    out_path  = f'{out_dir}/{file_name}'

    log.info(f'Saving to: {out_path}')
    df.to_json(out_path, indent=4)
#-----------------------------------
def main():
    '''
    Parse the command line, then build and save one dataframe per trigger and year.
    '''
    get_args()

    # Trigger is the outer loop, year the inner one
    l_pair = [(trig, year) for trig in data.l_trig for year in data.l_year]
    for trig, year in l_pair:
        save_df(get_df(trig, year), trig, year)
#-----------------------------------
# Only run the export when executed as a script, not when imported
if '__main__' == __name__:
    main()

