#!/usr/bin/env python3

from model_analyzer import analyzer  as mana
from np_reader      import np_reader as np_rdr
from logzero        import logger    as log
from rk_model       import rk_model

import os
import math
import zfit
import argparse
import rk.utilities as rkut

#--------------------------------
class data:
    '''
    Module-level configuration shared by the functions of this script.

    The entries initialised to ``None`` are filled from the command line by ``get_args``.
    '''
    # --- Filled by get_args ---
    jsn_dir = None   # Directory receiving the ini/val/err JSON files written by check_model
    out_dir = None   # Optional directory handed to the analyzer for diagnostics
    ijob    = None   # Job index: zfit seed and suffix of the output JSON files
    nfit    = None   # Number of toy datasets to fit

    # --- Defaults ---
    nev_fac = 1      # Scale factor for toy statistics; also used as the CLI default
    # Observable used by the standalone PDFs built in get_expo / get_gaus
    obs     = zfit.Space('x', limits=(4500, 6000))
#--------------------------------
def prepare_check():
    '''
    Create the directory where check_model writes its JSON outputs.

    The directory is ``data.jsn_dir``; it is fine if it already exists.

    Raises
    ---------------
    OSError: If the directory cannot be created. The failure is logged and re-raised.
    '''
    try:
        os.makedirs(data.jsn_dir, exist_ok=True)
    except OSError:
        # Narrowed from a bare ``except:`` so that e.g. KeyboardInterrupt is not
        # misreported as a failure to create the directory.
        log.error(f'Cannot make {data.jsn_dir}')
        raise
#--------------------------------
def fix_shapes(model):
    '''
    Freeze every parameter of ``model`` whose name does not start with ``n``.

    In this file yield parameters are named with a leading ``n`` (``nbk``, ``nsg``,
    ``nsg_*``, ``npr_*``). Those are left untouched, while all other (shape)
    parameters get ``floating = False``.

    The model is modified in place and also returned.
    '''
    l_shape_par = [par for par in model.get_params() if not par.name.startswith('n')]
    for shape_par in l_shape_par:
        shape_par.floating = False

    return model
#--------------------------------
def skip_pdf(yld_name, keep_prefix='npr_'):
    '''
    Decide whether a model component should be dropped, based on its yield name.

    Parameters
    ---------------
    yld_name (str)            : Name of the component's yield parameter, e.g. ``npr_...`` or ``nsg_...``
    keep_prefix (str or tuple): Prefix, or tuple of prefixes, of components to keep.
                                Defaults to ``'npr_'``, the only component kept so far.

    Returns
    ---------------
    bool: True if the component should be skipped, i.e. its yield name does not start
          with ``keep_prefix``.
    '''
    # str.startswith accepts either a single prefix or a tuple of them.
    return not yld_name.startswith(keep_prefix)
#--------------------------------
def modify_model(model):
    '''
    Freeze the shape parameters of ``model`` and return the single component kept by skip_pdf.

    Parameters
    ---------------
    model: Composite PDF exposing ``.pdfs``, each component having ``get_yield()``.
           Presumably a zfit SumPDF of extended PDFs; TODO confirm.

    Returns
    ---------------
    The first component whose yield name is not rejected by ``skip_pdf``.
    With the current ``skip_pdf`` this is the ``npr_`` component.

    Raises
    ---------------
    ValueError: If no component survives ``skip_pdf``.
    '''
    model = fix_shapes(model)

    l_pdf = [pdf for pdf in model.pdfs if not skip_pdf(pdf.get_yield().name)]

    if len(l_pdf) == 0:
        # Previously this surfaced as a bare IndexError from l_pdf[0].
        log.error('No component of the model survived skip_pdf')
        raise ValueError('No PDF left after filtering the model components with skip_pdf')

    if len(l_pdf) > 1:
        log.warning(f'Found {len(l_pdf)} components surviving skip_pdf, only the first one is used')

    return l_pdf[0]
#--------------------------------
def get_rk_model():
    '''
    Build the RK model for the ``all_TOS`` dataset and the constraints on its nuisance parameters.

    Efficiencies and yields are read through np_reader. The ``B`` yields are averaged
    (excluding ``TIS``) and converted from the resonant (jpsi) to the rare mode.

    Returns
    ---------------
    mod_ee : Second element of ``d_mod['all_TOS']``, after modify_model
    d_const: dict mapping constraint name -> [value, sqrt(variance)], built from
             ``rk_model.get_cons()``
    '''
    log.info('Getting nuisance parameters')
    rdr          = np_rdr(sys='v65', sta='v63', yld='v24')
    rdr.cache    = True
    d_eff        = rdr.get_eff()
    d_byld       = rdr.get_byields()
    d_nent       = rkut.average_byields(d_byld, l_exclude=['TIS'])
    d_rare_yld   = rkut.reso_to_rare(d_nent, kind='jpsi')

    log.info('Building model')
    mod         = rk_model(preffix='all_tos', d_eff=d_eff, d_nent=d_rare_yld, l_dset=['all_TOS'])
    mod.bdt_wp  = {'BDT_cmb' : 0.977, 'BDT_prc' : 0.480751}
    d_mod       = mod.get_model()
    d_val, d_var= mod.get_cons()
    _, mod_ee   = d_mod['all_TOS']
    mod_ee      = modify_model(mod_ee)

    # Pair each value with its variance by key. Zipping the two .values() views
    # silently mismatched them if the dicts had different key order, and silently
    # dropped entries if d_var was shorter.
    s_missing = set(d_val) - set(d_var)
    if s_missing:
        log.error(f'No variance found for constraints: {sorted(s_missing)}')
        raise KeyError(f'Missing variances for: {sorted(s_missing)}')

    d_const = { key : [val, math.sqrt(d_var[key])] for key, val in d_val.items()}

    return mod_ee, d_const
#--------------------------------
def get_expo():
    '''
    Build an extended exponential on ``data.obs``.

    The slope parameter is ``lam`` and the yield parameter is ``nbk``.
    '''
    slope = zfit.Parameter('lam', -0.0005,  -0.010, 0)
    shape = zfit.pdf.Exponential(lam=slope, obs=data.obs)
    yld   = zfit.Parameter('nbk', 2000, 0, 10000)

    return shape.create_extended(yld)
#--------------------------------
def get_gaus():
    '''
    Build an extended Gaussian on ``data.obs``.

    The parameters are the mean ``mu``, the width ``sg`` and the yield ``nsg``.
    '''
    mean  = zfit.Parameter("mu", 5280, 5200, 5350)
    width = zfit.Parameter("sg", 50,  30, 60)
    shape = zfit.pdf.Gauss(obs=data.obs, mu=mean, sigma=width)
    yld   = zfit.Parameter('nsg',  370, 0, 10000)

    return shape.create_extended(yld)
#--------------------------------
def get_sm_model():
    '''
    Build a simple Gaussian + exponential sum PDF.

    Returns
    ---------------
    (pdf, {}): The sum PDF and an empty constraint dict. This mirrors the
               ``(model, d_const)`` return of get_rk_model.
    '''
    # The exponential is built before the Gaussian, as in the original ordering.
    bkg = get_expo()
    sig = get_gaus()

    return zfit.pdf.SumPDF([sig, bkg]), {}
#--------------------------------
def check_model():
    '''
    Fit ``data.nfit`` toy datasets with the RK model and save the results as JSON.

    ``data.ijob`` is used as the zfit seed and as the file suffix. The three
    dataframes returned by the analyzer are written to
    ``{data.jsn_dir}/{ini,val,err}_NNN.json``.
    '''
    seed = data.ijob
    log.info(f'Analyzing model, using seed: {seed}')
    zfit.settings.set_seed(seed)

    pdf, d_const = get_rk_model()

    ana         = mana(pdf=pdf, d_const=d_const, nev_fac=data.nev_fac)
    ana.out_dir = data.out_dir
    df_ini, df_val, df_err = ana.fit(nfit=data.nfit)

    for kind, df in [('ini', df_ini), ('val', df_val), ('err', df_err)]:
        df.to_json(f'{data.jsn_dir}/{kind}_{seed:03}.json', indent=4)
#--------------------------------
def get_args():
    '''
    Parse the command line and store the values in ``data``.

    Sets ``data.ijob``, ``data.jsn_dir``, ``data.nfit``, ``data.out_dir`` and
    ``data.nev_fac``.
    '''
    # The description was a copy-paste leftover about PFNs/eventIDs; it now
    # describes what this script actually does.
    parser = argparse.ArgumentParser(description='Runs toy fits with the RK model and saves the fit results as JSON files')
    parser.add_argument('-i', '--job_ind', type=int, help='Job index, for naming purposes', required=True)
    parser.add_argument('-j', '--jsn_dir', type=str, help='Directory where JSON files go' , required=True)
    parser.add_argument('-f', '--nfit'   , type=int, help='Number of toy datasets'        , required=True)

    parser.add_argument('-o', '--out_dir', type=str, help='Directory where output diagnostics go, if not specified, will not save output')
    parser.add_argument('-n', '--nev_fac', type=int, help='Scale statistics of toys by this factor', default=data.nev_fac)
    args = parser.parse_args()

    data.ijob    = args.job_ind
    data.jsn_dir = args.jsn_dir
    data.nfit    = args.nfit

    data.out_dir = args.out_dir
    data.nev_fac = args.nev_fac
#--------------------------------
if __name__ == '__main__':
    # Order matters: parse the CLI into `data`, create the JSON directory, then run the fits.
    for step in (get_args, prepare_check, check_model):
        step()

