#!/usr/bin/env python3

import os
import re
import ROOT
import zfit
import pprint
import argparse
import utils_noroot   as utnr
import zutils.utils   as zut
import read_selection as rs

from importlib.resources import files
from rk_model            import rk_model
from normalizer          import normalizer
from logzero             import logger    as log
from np_reader           import np_reader as np_rdr
from rk                  import utilities as rkut

#-----------------------------
class data:
    """Namespace holding the run configuration of this script.

    Only class attributes are used; the visible code never instantiates it.
    ``get_args`` fills ``version``, ``dset``, ``trig`` and ``d_bdt_wp``;
    ``prepare_output`` fills ``out_dir``.
    """
    # Allowed values for the --dataset and --trigger command-line options
    l_dset = ['r1', 'r2p1', '2017', '2018', 'all']
    l_trig = ['MTOS', 'ETOS', 'GTIS']

    # Selected dataset and trigger
    dset = trig = None

    # Output version tag, output directory and BDT working points
    version  = None
    out_dir  = None
    d_bdt_wp = None
#-----------------------------
def prepare_output():
    """Create the output directory for this version and store it in ``data.out_dir``.

    The directory is ``sb_fits/<data.version>`` inside the ``extractor_data``
    package resources. It is created if missing; an existing directory is
    reused.
    """
    sub_dir = f'sb_fits/{data.version}'
    target  = files('extractor_data').joinpath(sub_dir)

    os.makedirs(target, exist_ok=True)

    # Only record the directory once it is known to exist
    data.out_dir = target
#-----------------------------
def get_args():
    """Parse the command line and store the options in the ``data`` namespace.

    Sets ``data.version``, ``data.dset``, ``data.trig`` and ``data.d_bdt_wp``.

    ``--trigger`` and ``--dataset`` are required. Every downstream step needs
    them: the selection lookup in ``make_bdt_wp``, the output file name, and
    the model key in ``main``. Without them the script could only fail later
    with an obscure error.
    """
    parser = argparse.ArgumentParser(description='Used to calculate normalizations for combinatorial and PRec from data sidebands')
    parser.add_argument('-c', '--cmb'     , type=float, help='Combinatorial BDT WP, if not specified, will use nominal')
    parser.add_argument('-p', '--prc'     , type=float, help='PRec BDT WP, if not specified, will use nominal')

    parser.add_argument('-v', '--version' , type=str, help='Version of output, used to name directories', required=True)
    parser.add_argument('-t', '--trigger' , type=str, help='Trigger', choices=data.l_trig, required=True)
    parser.add_argument('-d', '--dataset' , type=str, help='Dataset', choices=data.l_dset, required=True)
    args = parser.parse_args()

    data.version  = args.version
    data.dset     = args.dataset
    data.trig     = args.trigger

    # make_bdt_wp reads data.dset / data.trig, so it must run after they are set
    data.d_bdt_wp = make_bdt_wp(args)
#-----------------------------
def make_bdt_wp(args):
    """Return the BDT working points to use, keyed by ``'BDT_cmb'`` and ``'BDT_prc'``.

    The nominal thresholds are taken from the ``'bdt'`` selection of
    ``read_selection`` for ``data.trig`` and the dataset year. For the ``'all'``
    dataset the 2018 selection is used. Each nominal value is replaced by the
    command-line value (``args.cmb`` / ``args.prc``) when one was given.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments providing ``cmb`` and ``prc`` (float or None).

    Raises
    ------
    ValueError
        If the selection string does not have the expected
        ``BDT_cmb > X && BDT_prc > Y`` form.
    """
    if data.dset == 'all':
        # No dedicated WP for the combined dataset; fall back to 2018
        log.warning('Using 2018 WP for "all" dataset')
        year = '2018'
    else:
        year = data.dset

    bdt_cut = rs.get('bdt', data.trig, q2bin='none', year=year)
    # Raw string: avoids the invalid "\." escape warning, same regex pattern
    regex   = r'BDT_cmb > ([0-9,\.]+) && BDT_prc > ([0-9,\.]+)'

    mtch = re.match(regex, bdt_cut)
    if not mtch:
        log.error(f'Cannot extract WP from: {bdt_cut}')
        # A bare ``raise`` here had no active exception to re-raise
        raise ValueError(f'Cannot extract WP from: {bdt_cut}')

    [bdt_cmb, bdt_prc]  = [float(wp) for wp in mtch.groups()]

    d_bdt_wp            = {}
    d_bdt_wp['BDT_cmb'] = bdt_cmb if args.cmb is None else args.cmb
    d_bdt_wp['BDT_prc'] = bdt_prc if args.prc is None else args.prc

    return d_bdt_wp
#-----------------------------
def get_model(dset):
    """Build the sideband-fit model and its constraints for one dataset key.

    Efficiencies and averaged J/psi-mode yields (TIS excluded) are read with
    ``np_reader``. The yields are converted to rare-mode yields and passed to
    ``rk_model`` configured with kind ``'sb_fits'`` and the BDT working points
    in ``data.d_bdt_wp``.

    Parameters
    ----------
    dset : str
        Model key such as ``'2018_TOS'``, passed as the single entry of ``l_dset``.

    Returns
    -------
    tuple
        ``(d_mod, d_val, d_var)``: the dictionary returned by
        ``rk_model.get_model()`` and the two objects returned by
        ``rk_model.get_cons()``.
    """
    reader       = np_rdr(sys='v65', sta='v63', yld='v24')
    reader.cache = True

    eff_dict    = reader.get_eff()
    rare_yields = rkut.reso_to_rare(
        rkut.average_byields(reader.get_byields(), l_exclude=['TIS']),
        kind='jpsi',
    )

    builder = rk_model(preffix='sb_fits', d_eff=eff_dict, d_nent=rare_yields, l_dset=[dset])
    # Configuration must be in place before get_model()/get_cons() are called
    builder.read_yields = False
    builder.kind        = 'sb_fits'
    builder.bdt_wp      = data.d_bdt_wp

    models       = builder.get_model()
    d_val, d_var = builder.get_cons()

    return models, d_val, d_var
#-----------------------------
def main():
    """Fit the data sidebands for the selected trigger/dataset and save the parameters.

    The output is ``<data.out_dir>/<dset>_<trig>.json``. If it already exists,
    nothing is done. After saving, all zfit parameters are deleted via
    ``delete_all_pars``.
    """
    pars_path = f'{data.out_dir}/{data.dset}_{data.trig}.json'
    if os.path.isfile(pars_path):
        log.info(f'Parameters already found for {data.trig}-{data.dset}, skipping')
        return

    # GTIS uses the TIS model; MTOS and ETOS both use the TOS model
    if data.trig == 'GTIS':
        trg = 'TIS'
    else:
        trg = 'TOS'

    key                 = f'{data.dset}_{trg}'
    d_mod, d_val, d_var = get_model(key)
    mod_mm, mod_ee      = d_mod[key]

    # MTOS takes the first (mm) model, the other triggers the second (ee) one
    is_muon = data.trig == 'MTOS'
    model   = mod_mm if is_muon else mod_ee

    norm         = normalizer(dset=data.dset, trig=data.trig, model=model, d_val=d_val, d_var=d_var)
    norm.bdt_wp  = data.d_bdt_wp
    norm.out_dir = data.out_dir
    fit_result   = norm.get_fit_result()

    utnr.dump_json(zut.res_to_dict(fit_result, frozen=True), pars_path)
    log.info(f'Saving to: {pars_path}')

    delete_all_pars()
#-----------------------------
def delete_all_pars():
    """Remove every entry from zfit's registry of existing parameters.

    This uses the private ``zfit.Parameter._existing_params`` mapping.
    NOTE(review): presumably done so parameters with the same names can be
    created again later in the same process; confirm against the zfit version
    in use.
    """
    registry = zfit.Parameter._existing_params

    # Snapshot the keys first: the mapping is mutated while we walk it
    for name in tuple(registry):
        del registry[name]
#-----------------------------
if __name__ == '__main__':
    # Order matters: get_args sets data.version (and dset/trig/d_bdt_wp),
    # prepare_output uses data.version to set data.out_dir, and main reads both.
    get_args()
    prepare_output()
    main()

