#!/usr/bin/env python3
#
# Author: Yipeng Sun
# License: BSD 2-clause
# Last Change: Wed Feb 16, 2022 at 04:15 PM -0500

import sys
import numpy as np
import mplhep as hep

from pyTuplingUtils.argparse import (
    diff_branch_parser_no_output, DataPairAction, split_ntp_tree)

from pyTuplingUtils.utils import gen_histo
from pyTuplingUtils.utils import extract_uid
from pyTuplingUtils.plot import plot_top, plot_histo, plot_step
from pyTuplingUtils.plot import ax_add_args_histo, ax_add_args_step
from pyTuplingUtils.boolean.eval import BooleanEvaluator


################################
# Command line argument parser #
################################

def parse_input(
        descr='plot differences between branches from ntuples, UID-matched.'):
    """Define this script's command line options and parse the command line.

    The base parser comes from ``diff_branch_parser_no_output``; it is
    presumably an ``argparse`` parser that already declares the
    reference/comparison ntuple, branch and binning options read by the main
    body (``ref``, ``ref_branch``, ``comp``, ``comp_branch``, ``bins``) --
    TODO confirm against pyTuplingUtils.

    :param descr: description text handed to the base parser.
    :return: whatever ``parser.parse_args()`` returns; the main body reads
        the options as attributes of that object.
    """
    parser = diff_branch_parser_no_output(descr)

    parser.add_argument('-o', '--output',
                        required=True,
                        help='specify output file.')

    parser.add_argument('-XD', '--x-data-range',
                        nargs=2,
                        action=DataPairAction,
                        default=None,
                        help='specify x-axis range.')

    parser.add_argument('-YD', '--y-data-range',
                        nargs=2,
                        action=DataPairAction,
                        default=None,
                        help='specify y-axis range.')

    parser.add_argument('-XL', '--xlabel',
                        default=None,
                        help='specify x-axis label.')

    parser.add_argument('-YL', '--ylabel',
                        default='Number of candidates',
                        help='specify y-axis label.')

    parser.add_argument('--title',
                        default=None,
                        help='specify title.')

    # The options below use action='append' together with nargs='+', so each
    # occurrence of the flag contributes one inner list to a list of lists.
    parser.add_argument('-l', '--labels',
                        required=True,
                        action='append',
                        nargs='+',
                        help='specify plot legend labels.')

    parser.add_argument('-c', '--colors',
                        default=None,
                        action='append',
                        nargs='+',
                        help='specify plot colors.')

    parser.add_argument('--ref-cuts',
                        default=None,
                        action='append',
                        nargs='+',
                        help='specify optional cuts for each reference branch.')

    parser.add_argument('--comp-cuts',
                        default=None,
                        action='append',
                        nargs='+',
                        help='specify optional cuts for each comparison branch.')

    # With nargs='?' a bare '--yscale' stores 'const', which defaults to None.
    # Setting const explicitly keeps the value a valid scale name, so the
    # 'linear' comparison and the plotting call never receive None.
    parser.add_argument('--yscale',
                        nargs='?',
                        const='linear',
                        default='linear',
                        help='y axis scale (linear or log).')

    parser.add_argument('--normalize',
                        action='store_true',
                        help='normalize plots to 1.')

    parser.add_argument('--debug',
                        action='store_true',
                        help='enable debug mode.')

    return parser.parse_args()


########
# Main #
########

if __name__ == '__main__':
    args = parse_input()
    hep.style.use('LHCb2')

    # One drawing callable per histogram; all of them go to plot_top at the end.
    plotters = []

    # Axis limits: use the command line values when present, otherwise start
    # from a degenerate range that is widened while looping over the data.
    if args.x_data_range:
        xmin, xmax = args.x_data_range
    else:
        xmin, xmax = 0, 0

    if args.y_data_range:
        ymin, ymax = args.y_data_range
    elif args.yscale == 'linear':
        ymin = ymax = 0
    else:
        ymin = ymax = 1

    if not args.colors:
        # Default colors are handed out front to back, one per branch.
        palette = ['cornflowerblue', 'black', 'crimson', 'darkgoldenrod']
        try:
            # NOTE: The size and shape between ref_branch and comp_branch
            #       should be identical
            args.colors = [[palette.pop(0) for _ in names]
                           for names in args.ref_branch]
        except IndexError:
            # More branches than default colors.
            print('No color specified and we are running out of default colors.')
            print('Please specify plot colors manually.')
            sys.exit(255)

    # Without user-supplied cuts, give every branch a "no cut" placeholder so
    # that the zips below still line up.
    if not args.ref_cuts:
        args.ref_cuts = [[None for _ in names] for names in args.ref_branch]
    if not args.comp_cuts:
        args.comp_cuts = [[None for _ in names] for names in args.comp_branch]

    groups = zip(args.ref, args.ref_branch,
                 args.comp, args.comp_branch,
                 args.ref_cuts, args.comp_cuts,
                 args.colors, args.labels)

    for (ref_spec, ref_names, comp_spec, comp_names,
         ref_sels, comp_sels, group_colors, group_labels) in groups:
        ref_file, ref_tree = split_ntp_tree(ref_spec)
        ref_eval = BooleanEvaluator(ref_file, ref_tree)

        if comp_spec != ref_spec:
            comp_file, comp_tree = split_ntp_tree(comp_spec)
            comp_eval = BooleanEvaluator(comp_file, comp_tree)
        else:
            # Same ntuple/tree on both sides: reuse the reference evaluator.
            comp_file, comp_tree = ref_file, ref_tree
            comp_eval = ref_eval

        if args.debug:
            print('Working on reference: {}, tree: {}'.format(
                ref_file, ref_tree))
            print('Working on comparison: {}, tree: {}'.format(
                comp_file, comp_tree))

        per_branch = zip(ref_names, comp_names, ref_sels, comp_sels,
                         group_colors, group_labels)

        for ref_name, comp_name, ref_cut, comp_cut, clr, lbl in per_branch:
            if args.debug:
                print('  with branch: {}-{}, color: {}, label: {}.'.format(
                    ref_name, comp_name, clr, lbl))

            ref_vals = ref_eval.eval(ref_name)
            comp_vals = comp_eval.eval(comp_name)
            if args.debug:
                print('    before cuts, ref: {}'.format(ref_vals.size))
                print('    before cuts, comp: {}'.format(comp_vals.size))

            # A missing cut is forwarded unchanged as the conditional.
            ref_cond = ref_cut
            if ref_cut:
                ref_cond = ref_eval.eval(ref_cut)
            ref_uid, ref_idx, *_ = extract_uid(
                ref_file, ref_tree, conditional=ref_cond)

            comp_cond = comp_cut
            if comp_cut:
                comp_cond = comp_eval.eval(comp_cut)
            comp_uid, comp_idx, *_ = extract_uid(
                comp_file, comp_tree, conditional=comp_cond)

            # Keep only the entries whose UID shows up on both sides
            _, ref_pos, comp_pos = np.intersect1d(
                ref_uid, comp_uid, assume_unique=True, return_indices=True)
            ref_vals = ref_vals[ref_idx[ref_pos]]
            comp_vals = comp_vals[comp_idx[comp_pos]]

            if args.debug:
                if ref_cut is not None:
                    print('    apply cuts: {} on: {}'.format(
                        ref_cut, ref_name))
                    print('    after cuts: {}'.format(ref_vals.size))
                if comp_cut is not None:
                    print('    apply cuts: {} on: {}'.format(
                        comp_cut, comp_name))
                    print('    after cuts: {}'.format(comp_vals.size))

            diff = ref_vals - comp_vals
            if args.debug:
                print('    number of common candidates: {}'.format(diff.size))

            histo, bins = gen_histo(
                diff, args.bins, data_range=args.x_data_range,
                density=args.normalize)

            if not args.x_data_range:
                # Widen the x range to cover the outermost bin edges so far.
                xmin = min(xmin, bins[0])
                xmax = max(xmax, bins[-1])

            if not args.y_data_range:
                # Leave 10% headroom above the tallest bin seen so far.
                peak = histo.max()
                if ymax < peak:
                    ymax = 1.1*peak

            if args.debug:
                print('    xmin: {}, xmax: {}'.format(xmin, xmax))
                print('    ymin: {}, ymax: {}'.format(ymin, ymax))

            # The very first histogram uses plot_histo, all later ones use
            # plot_step. Bins, contents and style are bound as lambda defaults
            # so each callable keeps the values of its own iteration.
            if plotters:
                step_args = ax_add_args_step(lbl, clr, linewidth=3)
                plotters.append(
                    lambda fig, ax, b=bins, h=histo, add=step_args: plot_step(
                        b, h, add, figure=fig, axis=ax, show_legend=False))
            else:
                histo_args = ax_add_args_histo(lbl, clr)
                plotters.append(
                    lambda fig, ax, b=bins, h=histo, add=histo_args: plot_histo(
                        b, h, add, figure=fig, axis=ax, show_legend=False))

    plot_top(plotters, output=args.output, title=args.title,
             xlim=(xmin, xmax), ylim=(ymin, ymax),
             xlabel=args.xlabel, ylabel=args.ylabel, yscale=args.yscale)

    if args.debug:
        print('Plot saved to: {}'.format(args.output))
