#!/usr/bin/env python3

from logzero             import logger     as log
from dataclasses         import dataclass
from importlib.resources import files
from rkex_model          import model
from iminuit             import Minuit
from iminuit.cost        import LeastSquares

import matplotlib.pyplot as plt
import utils_noroot      as utnr
import pandas            as pnd
import jacobi            as jac
import argparse
import numpy 
import math
import os
import re

#-----------------------------------
class data:
    '''
    Module-level configuration shared by all functions; filled by get_args().
    '''
    l_vers  = [1, 2, 3, 4, 5]   # versions used as key suffixes in the JSON; one row per version
    vers    = None              # version directory under sb_fits/
    dset    = None              # dataset name, part of the JSON file name
    trig    = None              # trigger name, part of the JSON file name
    out_dir = None              # output directory; plots are written to {out_dir}/plots
    x_calc  = None              # point at which the fitted line is evaluated
#-----------------------------------
def get_args():
    '''
    Parse the command line and store the values as attributes of `data`.

    Sets data.vers, data.dset, data.trig, data.out_dir and data.x_calc.
    The extrapolation point x_calc defaults to 5 (the previously hard-coded value)
    and can now be overridden with -x/--xcalc.
    '''
    parser = argparse.ArgumentParser(description='Used to plot PDF parameters in function of bins for SB fits')
    parser.add_argument('-v','--vers' , type =str  , help='Version'           , required=True)
    parser.add_argument('-t','--trig' , type =str  , help='Trigger'           , required=True)
    parser.add_argument('-d','--dset' , type =str  , help='Dataset'           , required=True)
    parser.add_argument('-p','--path' , type =str  , help='Output directory'  , required=True)
    parser.add_argument('-x','--xcalc', type =float, help='Point at which the fitted line is evaluated', default=5)
    args = parser.parse_args()

    data.vers   = args.vers
    data.dset   = args.dset
    data.trig   = args.trig
    data.out_dir= args.path
    # argparse does not apply `type` to non-string defaults, so the default stays the int 5
    data.x_calc = args.xcalc
#-----------------------------------
def get_vers_df(vers):
    '''
    Load the SB-fit parameters belonging to one version.

    Parameters
    ----------
    vers : Version identifier. Only keys ending in `_{vers}` are kept, and that
           trailing suffix is removed to form the column name.

    Returns
    -------
    (df_val, df_err): Single-row DataFrames holding, per parameter, element 0
                      (value) and element 1 (error) of the JSON entry.

    Raises
    ------
    ValueError: If no key carries the `_{vers}` suffix. An empty frame would
                otherwise be dropped on concatenation and shift the bin numbering.
    '''
    log.debug(vers)

    json_path = files('extractor_data').joinpath(f'sb_fits/{data.vers}/{data.dset}_{data.trig}.json')
    d_par     = utnr.load_json(json_path)

    suffix    = f'_{vers}'
    # Strip only the trailing suffix; str.replace would also remove earlier
    # occurrences, e.g. 'mu_1_1' -> 'mu' instead of 'mu_1'
    d_par     = { key[:-len(suffix)] : value for key, value in d_par.items() if key.endswith(suffix) }
    if not d_par:
        raise ValueError(f'No parameters with suffix "{suffix}" found in: {json_path}')

    d_par_val = {var : [l_val[0]] for var, l_val in d_par.items()}
    d_par_err = {var : [l_val[1]] for var, l_val in d_par.items()}

    df_val = pnd.DataFrame(d_par_val)
    df_err = pnd.DataFrame(d_par_err)

    return df_val, df_err
#-----------------------------------
def get_df():
    '''
    Collect parameter values and errors for every version in data.l_vers.

    Returns
    -------
    (df_v, df_e): DataFrames with one row per version, in the order of
                  data.l_vers and indexed starting at 1, one column per parameter.
    '''
    def stack(frames):
        # Row-wise concatenation with a fresh index renumbered to 1..N
        stacked = pnd.concat(frames, axis=0, ignore_index=True)
        stacked.index += 1
        return stacked

    val_frames = []
    err_frames = []
    for version in data.l_vers:
        frame_val, frame_err = get_vers_df(version)
        val_frames.append(frame_val)
        err_frames.append(frame_err)

    return stack(val_frames), stack(err_frames)
#-----------------------------------
def plot_pars(df_v, df_e):
    '''
    Plot each parameter (column of df_v) against the row index and save one PNG per parameter.

    Parameters
    ----------
    df_v, df_e : DataFrames of values and errors with matching columns; the
                 index is used as the x coordinate.

    Plots are written to `{data.out_dir}/plots/{var}.png`.
    '''
    plot_dir = f'{data.out_dir}/plots'
    os.makedirs(plot_dir, exist_ok=True)
    for var in df_v.columns:
        plot_par(df_v[var], df_e[var])

        plot_path = f'{plot_dir}/{var}.png'
        log.info(f'Saving to: {plot_path}')
        plt.grid()
        # The x axis holds the bin index, not the parameter, so the name goes in the title
        plt.title(var)
        plt.ylabel('Value')
        plt.xlabel('BDT bin')
        plt.savefig(plot_path)
        plt.close('all')
#-----------------------------------
def plot_par(sr_val, sr_err):
    '''
    Draw one parameter as points with error bars on the current matplotlib axes.

    sr_val: Series of values; its index provides the x coordinates.
    sr_err: Series of errors, drawn as vertical error bars.
    '''
    x_points = sr_val.index
    plt.errorbar(x_points, sr_val, yerr=sr_err, fmt='o', capsize=3)
#-----------------------------------
def line(x, a, b):
    '''
    Straight-line model: intercept `a` plus slope `b` times `x`.

    Works element-wise when `x` is an array. Keep the parameter names:
    the fit in this file passes `a` and `b` to Minuit by name.
    '''
    slope_term = x * b
    return a + slope_term
#-----------------------------------
def fit_cmb(df_v, df_e, col='ncb_ee_all_TOS_sb_fits'):
    '''
    Fit a straight line to one column as a function of the row index.

    The last row is left out of the fit (presumably the bin to be
    extrapolated to — NOTE(review): confirm).

    Parameters
    ----------
    df_v, df_e : DataFrames of values and errors indexed by bin number.
    col        : Column to fit; defaults to the previously hard-coded name.

    Returns
    -------
    Minuit object after migrad and hesse.

    Raises
    ------
    ValueError: If fewer than two points remain, so a line is not constrained.
    '''
    df_v    = df_v.iloc[:-1]
    df_e    = df_e.iloc[:-1]

    if len(df_v) < 2:
        raise ValueError(f'Need at least 2 points to fit a line, found {len(df_v)}')

    arr_var = df_v.index.to_numpy()
    arr_val = df_v[col].to_numpy()
    arr_err = df_e[col].to_numpy()

    least_squares = LeastSquares(arr_var, arr_val, arr_err, line)

    mini=Minuit(least_squares, a=0, b=0)
    mini.migrad()
    mini.hesse()

    if not mini.valid:
        log.warning(f'Line fit to {col} did not converge to a valid minimum')

    return mini
#-----------------------------------
def calculate_ncmb(mini, x_calc=None):
    '''
    Evaluate the fitted line at `x_calc` and propagate the fit covariance to it.

    Parameters
    ----------
    mini   : Minuit object from the line fit; its covariance must be available.
    x_calc : Point at which to evaluate the line; defaults to data.x_calc.

    Returns
    -------
    (ncm, err): Value of the line at x_calc and its propagated uncertainty.

    Raises
    ------
    RuntimeError: If the fit has no covariance matrix, e.g. because it failed.
    '''
    if x_calc is None:
        x_calc = data.x_calc

    print(mini)
    cov = mini.covariance
    if cov is None:
        raise RuntimeError('Fit covariance is not available, cannot propagate the uncertainty')

    val = mini.values
    # Reuse the fit model so the estimate always matches the fitted function
    ncm, var = jac.propagate(lambda par : line(x_calc, *par), val, cov)
    err = math.sqrt(var)

    log.info(f'Estimate at {x_calc}: {ncm:.0f}+/-{err:.0f}')

    return ncm, err
#-----------------------------------
def plot_fit(df_v, df_e, mini, ncm, err, col='ncb_ee_all_TOS_sb_fits'):
    '''
    Plot the per-bin values, the fitted line and the estimate, then save the figure.

    Parameters
    ----------
    df_v, df_e : DataFrames of values and errors indexed by bin number.
    mini       : Minuit object of the line fit.
    ncm, err   : Estimate and its uncertainty, drawn at data.x_calc.
    col        : Column to plot; defaults to the previously hard-coded name.

    The figure is written to `{data.out_dir}/plots/fit.png`.
    '''
    arr_x = df_v.index
    arr_y = df_v[col].values
    arr_e = df_e[col].values

    plt.plot(arr_x, line(arr_x, *mini.values), label="fit")
    plt.errorbar(data.x_calc, ncm, yerr=err, marker='o', markersize=10, capsize=4, color='r', label='estimate')
    plt.errorbar(arr_x, arr_y, arr_e, fmt="ok", label="data")

    fit_info = [
        f"$\\chi^2$/$n_\\mathrm{{dof}}$ = {mini.fval:.1f} / {mini.ndof:.0f} = {mini.fmin.reduced_chi2:.1f}",
    ]

    for p, v, e in zip(mini.parameters, mini.values, mini.errors):
        fit_info.append(f"{p} = ${v:.3f} \\pm {e:.3f}$")

    # Draw the legend (otherwise the labels above are unused) with the fit summary as its title
    plt.legend(title="\n".join(fit_info), frameon=False)

    plt.xlabel('BDT bin')
    plt.ylabel('$N_{cmb}$')
    plot_dir = f'{data.out_dir}/plots'
    os.makedirs(plot_dir, exist_ok=True)
    plot_path = f'{plot_dir}/fit.png'
    log.info(f'Saving to: {plot_path}')
    plt.savefig(plot_path)
    plt.close('all')
#-----------------------------------
def main():
    '''
    Script entry point.

    Steps: parse arguments, load the per-version tables, plot every parameter,
    fit the line, compute the estimate and plot the fit.
    '''
    get_args()

    frames = get_df()
    plot_pars(*frames)

    fit      = fit_cmb(*frames)
    estimate = calculate_ncmb(fit)
    plot_fit(*frames, fit, *estimate)
#-----------------------------------
if __name__ == '__main__':
    # Run only when executed as a script, not when imported as a module
    main()

