#!/usr/bin/env python3

import os
import re
import ROOT
import zfit
import glob
import math
import numpy
import pprint
import argparse
import pandas              as pnd
import jacobi              as jac
import logzero
import tarfile
import rk.utilities        as rkut
import utils_noroot        as utnr
import concurrent.futures  as cf

from rk_model   import rk_model  as model
from np_reader  import np_reader as np_rdr
from extractor  import extractor as ext
from logzero    import logger    as log
from zutils     import utils     as zut
from cmb_ck     import combiner  as cmb_ck

#--------------------------------
class data:
    """
    Namespace with the settings shared by the functions of this script.

    Attributes set to None here are filled at runtime by `get_args`,
    `initialize` and `check_job_kind`.
    """
    # Fixed settings
    rk      = 1
    out_dir = 'results'
    d_wp    = dict(BDT_cmb=0.977, BDT_prc=0.480751)

    # Filled from the command line arguments and the seed files
    l_dset  = None
    l_fixed = None
    l_model = None
    l_seed  = None
    log_lvl = None

    # Flags describing the kind of job, decided from the settings above
    for_pull= None
    for_syst= None
#--------------------------------
def get_ne_args(d_pos, d_pre, ck_name, nsg_mm_name):
    """
    Collect the values and covariance needed to propagate errors to an electron yield.

    Parameters
    ----------
    d_pos      : Postfit dictionary. Read here: `d_pos[name]` as a (value, error)-like pair,
                 `d_pos['par']` as the sequence of parameter names and `d_pos['cov']` as the
                 covariance matrix whose rows/columns follow `d_pos['par']`.
    d_pre      : Prefit dictionary, only used to take the value of `ck_name` when it is
                 missing from `d_pos`.
    ck_name    : Name of the ck parameter.
    nsg_mm_name: Name of the muon signal yield parameter.

    Returns
    -------
    Tuple `(l_val, cov)` where `l_val = [ck, nsg_mm, rk]` and `cov` is the 3x3 float
    covariance matrix in that same order. When ck is taken from `d_pre` its row and
    column are filled with zeros.

    Raises
    ------
    ValueError: If any of the needed parameters is missing from `d_pos['par']`.
    """
    nsg_mm, _ = d_pos[nsg_mm_name]
    rk, _     = d_pos['rk']

    ck_in_fit = ck_name in d_pos
    ck, _     = d_pos[ck_name] if ck_in_fit else d_pre[ck_name]
    l_par     = [ck_name, nsg_mm_name, 'rk'] if ck_in_fit else [nsg_mm_name, 'rk']
    l_val     = [ck, nsg_mm, rk]

    # Pick rows and columns in the order of l_par, not in the order used by the fit,
    # so that the covariance is always aligned with l_val
    l_name    = list(d_pos['par'])
    l_ind     = [ l_name.index(par) for par in l_par ]
    cov       = numpy.array(d_pos['cov']).astype(float)
    cov       = cov[numpy.ix_(l_ind, l_ind)]

    if not ck_in_fit:
        # ck goes first in l_val, add a leading row and column of zeros for it
        cov = numpy.pad(cov, (1, 0))

    return l_val, cov
#--------------------------------
def get_ne(suffix, d_pos, d_pre):
    """
    Return `[value, error]` for the electron yield associated to `suffix`.

    The yield is computed as ck * nsg_mm / rk, with the muon yield always taken
    from the TOS version of the suffix. The error comes from propagating the
    covariance returned by `get_ne_args`.
    """
    def _electron_yield(x):
        # x holds [ck, nsg_mm, rk]
        return (x[0] * x[1]) / x[2]

    tos_suffix = suffix.replace('_TIS_', '_TOS_')
    l_val, cov = get_ne_args(d_pos, d_pre, f'ck_{suffix}', f'nsg_mm_{tos_suffix}')

    val, var   = jac.propagate(_electron_yield, l_val, cov)
    err        = math.sqrt(var)

    return [float(val), float(err)]
#--------------------------------
def add_ne(d_pos, d_pre):
    """
    Add electron signal yields to the postfit dictionary and return it.

    For every key of `d_pos` matching `nsg_mm_<...>_TOS_<...>` an entry
    `nsg_ee_<suffix>` is added for the TOS suffix. The TIS counterpart is added
    only if the corresponding `ck_<TIS suffix>` key is present in `d_pos`.
    `d_pos` is updated in place.
    """
    pattern = re.compile('nsg_mm_(.*_TOS_.*)')

    # New entries are collected aside, d_pos cannot change size while being iterated
    d_new   = {}
    for mtch in map(pattern.match, d_pos):
        if mtch is None:
            continue

        tos = mtch.group(1)
        tis = tos.replace('_TOS_', '_TIS_')

        d_new[f'nsg_ee_{tos}'] = get_ne(tos, d_pos, d_pre)
        if f'ck_{tis}' not in d_pos:
            log.warning('TIS ck not found, skiping electron TIS yield')
            continue

        d_new[f'nsg_ee_{tis}'] = get_ne(tis, d_pos, d_pre)

    d_pos.update(d_new)

    return d_pos
#--------------------------------
def get_data(d_eff=None, d_nent=None, rseed=None):
    """
    Build a separate model (prefix `toys_gen`) with the configured datasets and
    working points, and return the data it provides for the seed `rseed`.
    """
    generator        = model(preffix='toys_gen', d_eff=d_eff, d_nent=d_nent, l_dset=data.l_dset)
    generator.bdt_wp = data.d_wp

    return generator.get_data(rseed=rseed)
#--------------------------------
def fit(rseed=None, fix_var=None, mod_var=None):
    """
    Run one toy fit and return its result together with prefit/postfit parameters.

    Parameters
    ----------
    rseed  : Seed forwarded to the toy data generation.
             NOTE(review): it is formatted with an integer format spec below, so the
             default None would raise before any work is done; callers in this file
             always pass an integer.
    fix_var: Name handed to the extractor as the (single) entry of its `fix` list,
             None leaves that list unset.
    mod_var: Value assigned to the `kind` attribute of the model used in the fit.
             When it is not None the data is taken from `get_data`, which builds a
             separate model without setting `kind`.
             NOTE(review): the command line help describes the variants as used to
             generate the toys, confirm this is the intended direction.

    Returns
    -------
    Tuple with the fit result and the dictionary {'pre' : prefit, 'pos' : postfit},
    where the postfit dictionary has been extended by `add_ne`.
    """
    log.info(f'Seed: {rseed:04}')
    log.info(f'Variable fixed : {fix_var}')
    log.info(f'Model variation: {mod_var}')

    # Inputs read through np_reader: covariances, efficiencies, rjpsi and yields
    rdr          = np_rdr(sys='v65', sta='v63', yld='v24')
    rdr.cache    = True
    rdr.cache_dir= 'v65_v63_v24'
    cv_sys       = rdr.get_cov(kind='sys')
    cv_sta       = rdr.get_cov(kind='sta')
    d_eff        = rdr.get_eff()
    d_rjpsi      = rdr.get_rjpsi()
    d_byld       = rdr.get_byields()
    d_nent       = rkut.average_byields(d_byld, l_exclude=['TIS'])
    d_rare_yld   = rkut.reso_to_rare(d_nent, kind='jpsi')

    # Model used in the fit, it also provides the constraints and prefit parameters
    mod          = model(rk=data.rk, preffix='toys_fit', d_eff=d_eff, d_nent=d_rare_yld, l_dset=data.l_dset)
    mod.bdt_wp   = data.d_wp
    mod.kind     = mod_var
    d_mod        = mod.get_model()
    d_val, d_var = mod.get_cons()
    d_pre        = mod.get_prefit_pars(d_var=d_var, ck_cov=cv_sys+cv_sta)

    # Without a model variation the fitting model provides the data,
    # otherwise the data comes from the model built in get_data
    if mod_var is None:
        d_dat    = mod.get_data(rseed=rseed)
    else:
        d_dat    = get_data(d_eff=d_eff, d_nent=d_rare_yld, rseed=rseed)

    # For these dataset selections rjpsi, efficiencies and covariance are replaced by
    # the output of the combiner. The model above was already built with the
    # efficiencies from the reader.
    if data.l_dset == ['all_TOS'] or data.l_dset == ['all_TOS', 'all_TIS']:
        cmb                 = cmb_ck(rk=data.rk, eff=d_eff, yld=d_rare_yld)
        cmb.out_dir         = 'plots/combination'
        t_comb              = cmb.get_combination()
        d_rjpsi, d_eff, cov = t_comb
    else:
        cov = cv_sys + cv_sta

    # Configure the extractor and run the fit
    obj          = ext(dset=data.l_dset, drop_correlations=False)
    obj.plt_dir  = f'plots/fits_{rseed:03}'
    obj.rjpsi    = d_rjpsi
    obj.eff      = d_eff
    obj.data     = d_dat
    obj.model    = d_mod
    obj.fix      = None if fix_var is None else [fix_var]
    obj.cov      = cov
    obj.const    = d_val, d_var
    result       = obj.get_fit_result()

    log.info(f'Calculating errors')
    result.hesse()
    d_pos = rkut.result_to_dict(result)
    d_pos = add_ne(d_pos, d_pre)
    # The result is converted to a dictionary before being frozen
    result.freeze()

    # Remove the zfit parameters registered during this fit
    cleanup_env()

    return result, {'pre' : d_pre, 'pos' : d_pos}
#--------------------------------
def initialize():
    """
    Prepare the job: set the logging level, load the seeds, decide the kind of
    job and make sure the output directory exists.
    """
    log.setLevel(data.log_lvl)

    # The job kind depends on the number of seeds, so they are loaded first
    l_seed      = get_seeds()
    data.l_seed = l_seed
    check_job_kind()

    out_dir = data.out_dir
    os.makedirs(out_dir, exist_ok=True)
#--------------------------------
def check_job_kind():
    """
    Decide whether this is a systematics or a pulls job and store the decision.

    Sets `data.for_syst` and `data.for_pull` from the number of seeds, fixed
    variables and model variations:

    - Systematics: exactly one seed and at least one fixed variable or model variation.
    - Pulls      : more than one seed, no fixed variables and no model variations.

    Raises
    ------
    RuntimeError: If the configuration corresponds to neither kind of job.
    """
    nseed = len(data.l_seed)
    nfix  = len(data.l_fixed)
    nmod  = len(data.l_model)

    data.for_syst = nseed == 1 and (nfix >  0 or  nmod  > 0)
    data.for_pull = nseed >  1 and  nfix == 0 and nmod == 0

    if data.for_syst:
        log.info('Running systematics job')
        return

    if data.for_pull:
        log.info('Running pulls job')
        return

    # A bare raise has no active exception here, raise explicitly with the reason
    message = f'Misconfigured job, seeds/fixed/models = {nseed}/{nfix}/{nmod}'
    log.error(message)
    raise RuntimeError(message)
#--------------------------------
def cleanup_env():
    """
    Remove every entry from zfit's registry of existing parameters.
    """
    registry = zfit.Parameter._existing_params

    # Take a snapshot of the keys, the registry shrinks while entries are deleted
    for name in list(registry.keys()):
        del registry[name]
#--------------------------------
def get_args():
    """
    Parse the command line arguments and store them in `data`.

    Sets:
    - `data.log_lvl`: Logging level.
    - `data.l_dset` : List of datasets, or None when only `all` (the default) was given.
    - `data.l_fixed`: List of variables to fix, empty when only `none` (the default) was given.
    - `data.l_model`: List of model variations, empty when only `none` (the default) was given.
    """
    parser = argparse.ArgumentParser(description='Used run toy fits on model used to extract RK')
    parser.add_argument('-l', '--level' , type =int, help='Logging level', choices=[logzero.DEBUG, logzero.INFO, logzero.WARNING], default=logzero.INFO)
    # The default has to be a list like the parsed values: a bare string would be
    # iterated character by character below and end up as ['a', 'l', 'l']
    parser.add_argument('-d', '--dset'  , nargs='+', help='Datasets to use', default=['all'])
    parser.add_argument('-v', '--vfix'  , nargs='+', help='Model parameters, a fit will be done with each of them fixed', default=['none'])
    parser.add_argument('-m', '--mode'  , nargs='+', help='Variants of the model, used to generate toys, fitted with nominal', default=['none'])
    args = parser.parse_args()

    l_dset       = [ dst for dst in args.dset if dst != 'all'  ]

    data.log_lvl = args.level
    data.l_dset  = l_dset or None
    data.l_fixed = [ var for var in args.vfix if var != 'none' ]
    data.l_model = [ mod for mod in args.mode if mod != 'none' ]
#--------------------------------
def print_args():
    """
    Log the settings stored in `data` as a small table.
    """
    separator = '-' * 40
    t_row     = (
        ('Level'     , data.log_lvl),
        ('Datasets'  , data.l_dset ),
        ('Vars fixed', data.l_fixed),
        ('Models'    , data.l_model),
    )

    log.info(separator)
    log.info(f'Args for {__file__}:')
    log.info(separator)
    for label, value in t_row:
        log.info(f'{label:<20}{value}')
    log.info(separator)
#--------------------------------
def run_pull_fits():
    """
    For a pulls job, run one fit per seed and save each result as pickle and JSON.

    Does nothing when `data.for_pull` is not set.
    """
    if data.for_pull:
        for seed in data.l_seed:
            # Each fit runs in its own single-worker process pool
            with cf.ProcessPoolExecutor(max_workers=1) as executor:
                future     = executor.submit(fit, rseed=seed)
                res, d_inf = future.result()

            print(res)
            pkl_path = f'{data.out_dir}/result_pkl/result_{seed:04}.pkl'
            jsn_path = f'{data.out_dir}/result_jsn/result_{seed:04}.json'
            utnr.dump_pickle(res, pkl_path)
            utnr.dump_json(d_inf, jsn_path)
#--------------------------------
def run_syst_fits():
    """
    For a systematics job, run the fits for the single seed and save the results.

    A first fit is done with only the seed and saved under the zero-padded seed.
    Then one fit is done per fixed variable and one per model variation, each
    saved under the name of the variable or variation.
    Does nothing when `data.for_syst` is not set.
    """
    if not data.for_syst:
        return

    def _fit_and_save(label, **kwargs):
        # Run the fit in its own single-worker process pool and dump the outputs
        with cf.ProcessPoolExecutor(max_workers=1) as executor:
            res, d_inf = executor.submit(fit, **kwargs).result()

        print(res)
        utnr.dump_pickle(res, f'{data.out_dir}/result_pkl/result_{label}.pkl')
        utnr.dump_json(d_inf, f'{data.out_dir}/result_jsn/result_{label}.json')

    [rseed] = data.l_seed
    _fit_and_save(f'{rseed:04}', rseed=rseed)

    for fix_var in data.l_fixed:
        _fit_and_save(fix_var, rseed=rseed, fix_var=fix_var)

    for mod_var in data.l_model:
        _fit_and_save(mod_var, rseed=rseed, mod_var=mod_var)
#--------------------------------
def main():
    """
    Script entry point: read the arguments, run the fits and pack the outputs.

    The `result_pkl` and `result_jsn` directories inside the output directory
    are each compressed into a `.tar.gz` file placed next to them.
    """
    get_args()
    print_args()
    initialize()
    run_pull_fits()
    run_syst_fits()

    for dir_name in ('result_pkl', 'result_jsn'):
        with tarfile.open(f'{data.out_dir}/{dir_name}.tar.gz', 'w:gz') as tar:
            tar.add(f'{data.out_dir}/{dir_name}', arcname=dir_name)
#--------------------------------
def get_file_seeds(seed_file):
    """
    Return the lines of `seed_file` as a list of strings, without line endings.
    """
    with open(seed_file) as ifile:
        text = ifile.read()

    return text.splitlines()
#--------------------------------
def get_seeds():
    """
    Read the seeds from every `*.sd` file in the current directory.

    Each non-blank line of those files is taken as one seed. Blank lines are
    skipped, they would otherwise fail the conversion to integer.

    Returns
    -------
    List of integer seeds.

    Raises
    ------
    RuntimeError: If no seed is found.
    ValueError  : If a non-blank line cannot be converted to an integer.
    """
    l_seed = []
    for seed_file in glob.glob('*.sd'):
        l_seed += [ line for line in get_file_seeds(seed_file) if line.strip() ]

    if not l_seed:
        # A bare raise has no active exception here, raise explicitly with the reason
        message = 'No seeds found'
        log.error(message)
        raise RuntimeError(message)

    log.debug(f'Using seeds: {l_seed}')

    return [ int(rseed) for rseed in l_seed ]
#--------------------------------
# Run the job only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()

