#!/usr/bin/env python3

import warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

import os
import re
import ROOT
import zfit
import pprint
import argparse
import utils_noroot   as utnr
import zutils.utils   as zut
import read_selection as rs

from importlib.resources import files
from rk_model            import rk_model
from normalizer          import normalizer
from np_reader           import np_reader as np_rdr
from rk                  import utilities as rkut
from model_manager       import manager
from log_store           import log_store
from builder             import builder   as cmb_bld

log = log_store.add_logger(name='rk_extractor:cmb_prec_nom')
#-----------------------------
class data:
    '''
    Namespace holding the allowed command-line values and the run
    configuration that get_args() and prepare_output() fill in.
    '''
    # --- Allowed values for the command-line options ---
    l_dset    = ['r1', 'r2p1', '2017', '2018', 'all']
    l_trig    = ['MTOS', 'ETOS', 'GTIS']
    # get_args() overwrites the next two with the user's selection
    # (BDT bins converted to int).
    l_bdt_bin = [str(ibin) for ibin in range(1, 6)]
    l_shr_par = ['mu', 'lm']

    # --- Run configuration, filled in at runtime ---
    dset    = None
    trig    = None

    version = None
    out_dir = None
    cmb_rep = None
#-----------------------------
def prepare_output():
    '''
    Make sure the sideband-fit output directory for `data.version` exists
    inside the `extractor_data` package resources, and record it in
    `data.out_dir`.
    '''
    pkg_root = files('extractor_data')
    sb_dir   = pkg_root.joinpath(f'sb_fits/{data.version}')

    # exist_ok: re-running with the same version must not fail
    os.makedirs(sb_dir, exist_ok=True)
    data.out_dir = sb_dir
#-----------------------------
def get_args(argv=None):
    '''
    Parse the command line and store the settings in the `data` namespace.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default), argparse reads
        sys.argv[1:], which is what happens when the script runs from the shell.

    Notes
    -----
    --trigger and --dataset are required. main() builds the output file name
    and the model key from them, so leaving either unset would silently
    produce a `None_None.json` output and a `None_TOS` model key.

    data.l_bdt_bin and data.l_shr_par are used as `choices` below and then
    overwritten with the user's selection (BDT bins converted to int). A
    second call in the same process would therefore see the narrowed and
    converted lists.
    '''
    parser = argparse.ArgumentParser(description='Used to calculate normalizations for combinatorial and PRec from data sidebands')
    parser.add_argument('-v', '--version' ,  type=str, help='Version of output, used to name directories', required=True)
    parser.add_argument('-t', '--trigger' ,  type=str, help='Trigger', choices=data.l_trig, required=True)
    parser.add_argument('-d', '--dataset' ,  type=str, help='Dataset', choices=data.l_dset, required=True)
    parser.add_argument('-b', '--bdt_bin' , nargs='+', help='BDT bins to include in fit', choices=data.l_bdt_bin, required=True)
    parser.add_argument('-s', '--shr_par' , nargs='+', help='Parameters shared between bins', choices=data.l_shr_par, default=[])
    parser.add_argument('-r', '--cmb_rep' ,  type=int, help='Reparametrize combinatorial yield with linear dependence across bins', choices=[0, 1], default=1)
    args = parser.parse_args(argv)

    data.version   = args.version
    data.dset      = args.dataset
    data.trig      = args.trigger
    data.cmb_rep   = args.cmb_rep
    data.l_shr_par = args.shr_par
    # BDT bins arrive as strings (checked against the string choices) and
    # are stored as ints for the model manager.
    data.l_bdt_bin = [ int(sbin) for sbin in args.bdt_bin ]
#-----------------------------
def get_model(dset):
    '''
    Build the sideband-fit models for one dataset key.

    Efficiencies and yields are read with np_reader (versions sys=v65,
    sta=v63, yld=v24, cached under ./v65_v63_v24). The yields are averaged
    with 'TIS' entries excluded and converted with
    rkut.reso_to_rare(kind='jpsi'). Everything is then passed to the model
    manager.

    Side effects: sets the class attributes `manager.reparametrize` (from
    data.cmb_rep) and `cmb_bld.d_shared` (from data.l_shr_par) before the
    manager is constructed.

    Parameters
    ----------
    dset : str
        Key forwarded to the manager as `dset` (main() passes
        '<dataset>_TOS').

    Returns
    -------
    Whatever `manager.get_model()` returns (main() treats it as the model
    dictionary handed to `normalizer`).
    '''
    reader           = np_rdr(sys='v65', sta='v63', yld='v24')
    reader.cache     = True
    reader.cache_dir = './v65_v63_v24'

    eff_dict    = reader.get_eff()
    avg_yields  = rkut.average_byields(reader.get_byields(), l_exclude=['TIS'])
    # NOTE(review): presumably resonant (J/psi) yields scaled to the rare mode — confirm
    rare_yields = rkut.reso_to_rare(avg_yields, kind='jpsi')

    # These are class-level settings and must be in place before manager(...)
    manager.reparametrize = (data.cmb_rep == 1)
    cmb_bld.d_shared      = dict((name, f'{name}_cb_shr') for name in data.l_shr_par)

    model_mgr         = manager(preffix='sb_fits', d_eff=eff_dict, d_nent=rare_yields, dset=dset)
    model_mgr.fake    = False
    model_mgr.bdt_bin = data.l_bdt_bin

    return model_mgr.get_model()
#-----------------------------
def main():
    '''
    Run the sideband fit for data.dset / data.trig and save the fitted
    parameters as JSON in data.out_dir.

    Skips everything when the JSON already exists. Parameters whose name
    starts with 'nsg_' are not saved. Afterwards the zfit parameter registry
    is cleared via delete_all_pars().
    '''
    json_path = f'{data.out_dir}/{data.dset}_{data.trig}.json'
    if os.path.isfile(json_path):
        log.info(f'Parameters already found for {data.trig}-{data.dset}, skipping')
        return

    models = get_model(f'{data.dset}_TOS')

    norm         = normalizer(dset=data.dset, trig=data.trig, d_model=models, d_val={}, d_var={})
    norm.out_dir = data.out_dir
    fit_result   = norm.get_fit_result()

    all_pars  = zut.res_to_dict(fit_result, frozen=True)
    # NOTE(review): 'nsg_' presumably marks signal yields — confirm why they are dropped
    kept_pars = {name : value for name, value in all_pars.items() if not name.startswith('nsg_')}
    utnr.dump_json(kept_pars, json_path)
    log.info(f'Saving to: {json_path}')

    delete_all_pars()
#-----------------------------
def delete_all_pars():
    '''
    Remove every entry from zfit's parameter registry
    (the private attribute `zfit.Parameter._existing_params`).

    NOTE(review): presumably done so that later fits in the same process can
    create parameters with the same names — confirm against the zfit version
    in use, since this relies on a private attribute.
    '''
    registry = zfit.Parameter._existing_params

    # Snapshot the keys first: deleting from a mapping while iterating it raises
    for name in tuple(registry):
        del registry[name]
#-----------------------------
if __name__ == '__main__':
    # Parse options -> create the output directory -> fit and save
    for step in (get_args, prepare_output, main):
        step()

