#!/usr/bin/env python3

import os
import ROOT
import numpy
import pprint
import argparse
import pandas             as pnd
import utils_noroot       as utnr
import matplotlib.pyplot  as plt
import version_management  as vm
import selection.utilities as slut

from scipy.stats import gaussian_kde
from log_store   import log_store

from importlib.resources import files

import extset

# Logger shared by the functions below, registered under the name 'rk_extractor:check_dist'
log=log_store.add_logger('rk_extractor:check_dist')
#------------------------------------
class data:
    """Namespace holding the settings shared by the functions of this script.

    The ``l_*`` lists are used as the allowed command-line choices. ``get_args``
    then overwrites ``l_year`` and ``l_trig`` with the values chosen by the user
    and fills ``kind`` and ``plt_dir``.
    """
    # Kinds of check that can be requested
    l_kind = ['bdt_cmb_sig', 'mass_qsq', 'bdt_tran']
    # Data-taking years that can be requested
    l_year = ['2011', '2012', '2015', '2016', '2017', '2018']
    # Triggers that can be requested
    l_trig = ['MTOS', 'ETOS']

    # Kind of check selected by the user; None until the arguments are parsed
    kind   = None
    # Directory where plots are saved; None until the arguments are parsed
    plt_dir= None
#------------------------------------
def get_args():
    """Parse the command line and store the selection in the ``data`` class.

    All four options are required. The parsed values are written to
    ``data.kind``, ``data.l_year``, ``data.l_trig`` and ``data.plt_dir``; the
    plotting directory is passed through ``utnr.make_dir_path`` first.
    """
    cli = argparse.ArgumentParser(description='Used to make plots of different distributions')

    # (flags, keyword options) for each command-line argument
    option_table = [
        (('-k', '--kind'), dict(type=str , help='Kind of plot', choices=data.l_kind)),
        (('-y', '--year'), dict(nargs='+', help='Datasets'    , choices=data.l_year)),
        (('-t', '--trig'), dict(nargs='+', help='Triggers'    , choices=data.l_trig)),
        (('-p', '--pdir'), dict(type=str , help='Plotting directory')),
    ]
    for flags, options in option_table:
        cli.add_argument(*flags, required=True, **options)

    parsed = cli.parse_args()

    data.kind    = parsed.kind
    data.l_year  = parsed.year
    data.l_trig  = parsed.trig
    data.plt_dir = utnr.make_dir_path(parsed.pdir)
#------------------------------------
def check_dist(kind):
    """Run the check that corresponds to ``kind``.

    Parameters
    ----------
    kind : str
        One of ``'bdt_cmb_sig'``, ``'mass_qsq'`` or ``'bdt_tran'``.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the supported values.
    """
    if   kind == 'bdt_cmb_sig':
        check_bdt_cmb_sig()
    elif kind == 'mass_qsq':
        check_mass_qsq()
    elif kind == 'bdt_tran':
        check_bdt_tran()
    else:
        message = f'Kind not recognized: {kind}'
        log.error(message)
        # There is no active exception to re-raise at this point, so raise
        # an explicit one that carries the reason.
        raise ValueError(message)
#------------------------------------
def check_bdt_tran():
    """Plot ``slut.transform_bdt`` over [-1, 1] and save it as ``bdt_tran.png``.

    The curve is sampled at 201 points in steps of 0.01 and the BDT bounds are
    overlaid as horizontal lines before the figure is written to
    ``data.plt_dir``.
    """
    originals   = [0.01 * step for step in range(-100, 101)]
    transformed = list(map(slut.transform_bdt, originals))

    plt.plot(originals, transformed)
    plt.xlabel('Original')
    plt.ylabel('Transformed')

    # Bounds refer to the transformed variable, which is on the y axis here
    overlay_bdt_bounds(vertical=False, recalculated=False)

    out_path = f'{data.plt_dir}/bdt_tran.png'
    plt.savefig(out_path)
    plt.close('all')
#------------------------------------
def check_mass_qsq():
    """Make a density-coloured scatter plot of B mass vs q^2 per year and trigger.

    For every combination of ``data.l_year`` and ``data.l_trig`` the ROOT files
    under ``$CASDIR/tools/apply_selection/blind_fits/data`` are read, q^2 is
    computed as ``Jpsi_M ** 2`` and the points are coloured by a Gaussian KDE
    estimate of their density. One PNG per combination is written to
    ``data.plt_dir``.

    Raises
    ------
    KeyError
        If the ``CASDIR`` environment variable is not set.
    """
    cas_dir = os.environ['CASDIR']
    dat_dir = f'{cas_dir}/tools/apply_selection/blind_fits/data'
    for year in data.l_year:
        for trig in data.l_trig:
            dat_wc    = f'{dat_dir}/v10.21p2/{year}_{trig}/*.root'
            # The tree is named after the trigger
            rdf       = ROOT.RDataFrame(trig, dat_wc)
            d_dat     = rdf.AsNumpy(['Jpsi_M', 'B_M'])
            df        = pnd.DataFrame(d_dat)
            df['qsq'] = df.Jpsi_M ** 2

            # Evaluate the KDE at the sample points themselves to colour them
            dat = numpy.vstack([df.B_M, df.qsq])
            z   = gaussian_kde(dat)(dat)

            plt.scatter(df.B_M, df.qsq, c=z, s=1)
            # Year and trigger go in the file name so that each combination
            # gets its own plot instead of overwriting the previous one.
            plt.savefig(f'{data.plt_dir}/bdt_qsq_{year}_{trig}.png')
            plt.close('all')
#------------------------------------
def get_bounds(recalculated=False):
    """Return the list of BDT bin boundaries.

    Parameters
    ----------
    recalculated : bool
        If False, return the ``'BDT_cmb'`` entry of the settings of bins 1 to 5
        as provided by ``extset.get_bdt_bin_settings``. If True, take only the
        boundaries of bin 5, pass them through ``slut.inverse_transform_bdt``
        and build four more bins below it, whose widths are 1, 2, 3 and 4 times
        the width of bin 5.

    Returns
    -------
    list
        Boundaries, one entry per bin. In the recalculated case each entry is
        a ``[min, max]`` list, starting with bin 5 and moving downwards.
    """
    if recalculated:
        settings, _, _ = extset.get_bdt_bin_settings(5)
        [low, high]    = settings['BDT_cmb']
        low            = slut.inverse_transform_bdt(low)
        high           = slut.inverse_transform_bdt(high)
        width          = high - low

        bounds = [[low, high]]
        for scale in range(1, 5):
            # Each new bin ends where the previous one started
            high = low
            low  = high - scale * width
            bounds.append([low, high])

        return bounds

    bounds = []
    for index in range(1, 6):
        settings, _, _ = extset.get_bdt_bin_settings(index)
        bounds.append(settings['BDT_cmb'])

    return bounds
#------------------------------------
def overlay_bdt_bounds(vertical=True, recalculated=False):
    """Draw the BDT bin boundaries on the current matplotlib figure.

    Parameters
    ----------
    vertical : bool
        Draw vertical lines if True, horizontal lines otherwise.
    recalculated : bool
        Forwarded to ``get_bounds`` to choose which boundaries are drawn.
    """
    line_style = dict(linestyle='-.', linewidth=0.3, color='r')

    for [lower, upper] in get_bounds(recalculated=recalculated):
        for edge in (lower, upper):
            if vertical:
                plt.axvline(x=edge, **line_style)
            else:
                plt.axhline(y=edge, **line_style)
#------------------------------------
def check_bdt_cmb_sig():
    """Plot the weighted signal BDT distributions and save ``bdt_cmb_sig.png``.

    For every trigger in ``data.l_trig`` and year in ``data.l_year`` the file
    ``{trig}_{year}.json`` is read from the path returned by
    ``vm.get_last_version`` for the ``sig_wgt`` directory of ``extractor_data``.
    The files must provide the columns ``BDT_cmb`` and ``wgt``. Both
    ``BDT_cmb`` and its image under ``slut.inverse_transform_bdt`` are
    histogrammed with the weights, on a logarithmic y axis, with the BDT
    bounds overlaid as vertical lines.
    """
    dir_path = files('extractor_data').joinpath('sig_wgt')
    ver_path = vm.get_last_version(dir_path=dir_path, version_only=False)
    for trig in data.l_trig:
        for year in data.l_year:
            json_path = f'{ver_path}/{trig}_{year}.json'
            df = pnd.read_json(json_path)
            df['BDT_org'] = df.BDT_cmb.apply(slut.inverse_transform_bdt)
            # Each label states the variable and the dataset, so the legend
            # stays correct when several triggers or years are drawn together.
            plt.hist(df.BDT_cmb, bins=50, weights=df.wgt, label=f'Transformed; {trig}; {year}', histtype='step')
            plt.hist(df.BDT_org, bins=50, weights=df.wgt, label=f'Original; {trig}; {year}', histtype='step')

    # The bounds do not depend on trigger or year: draw them once
    overlay_bdt_bounds()

    # Use the labels attached to the histograms above
    plt.legend()
    plt.yscale('log')
    plt.savefig(f'{data.plt_dir}/bdt_cmb_sig.png')
    plt.close('all')
#------------------------------------
def main():
    """Script entry point: parse the arguments, then run the requested check."""
    get_args()

    requested_kind = data.kind
    check_dist(requested_kind)
#------------------------------------
# Run the checks only when this file is executed as a script, not when imported
if __name__ == '__main__':
    main()

