#!/usr/bin/env python3
#
# Author: Yipeng Sun
# License: BSD 2-clause
# Last Change: Wed Feb 16, 2022 at 03:31 PM -0500

import sys
import mplhep as hep
import numpy as np

from argparse import Action

from pyTuplingUtils.argparse import (
    single_branch_parser_no_output, DataPairAction, split_ntp_tree)

from pyTuplingUtils.utils import gen_histo
from pyTuplingUtils.plot import plot_top, plot_histo, plot_step, plot_vlines
from pyTuplingUtils.plot import ax_add_args_histo, ax_add_args_step, \
    ax_add_args_vlines
from pyTuplingUtils.boolean.eval import BooleanEvaluator


################################
# Command line argument parser #
################################

class VlineAction(Action):
    """Parse vertical-line specifications given on the command line.

    Each value must have the form ``x,ymin,ymax,color[,label]``. ``x``,
    ``ymin`` and ``ymax`` are converted to ``float``; ``color`` is kept as a
    string. When no label is given, ``'_nolegend_'`` is stored instead.
    Everything after the fourth comma is treated as the label, so a label may
    itself contain commas.

    The parsed result, a list of dicts with keys ``x``, ``ymin``, ``ymax``,
    ``color`` and ``lbl``, is stored on the namespace under ``self.dest``.

    Raises:
        argparse.ArgumentError: if a value has fewer than four fields or one
            of the numeric fields is not a valid float.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Imported locally so the class only relies on 'Action' being
        # available at file level.
        from argparse import ArgumentError

        vlines = []
        for vl in values:
            # maxsplit=4: keep any further commas as part of the label
            fields = vl.split(',', 4)
            if len(fields) < 4:
                raise ArgumentError(
                    self,
                    "invalid vertical line '{}': expected "
                    "'x,ymin,ymax,color[,label]'".format(vl))

            x, ymin, ymax, color = fields[:4]
            lbl = fields[4] if len(fields) > 4 else '_nolegend_'

            try:
                line = {'x': float(x), 'ymin': float(ymin),
                        'ymax': float(ymax), 'color': color, 'lbl': lbl}
            except ValueError as err:
                # Report through argparse instead of a raw traceback
                raise ArgumentError(
                    self,
                    "invalid vertical line '{}': x, ymin and ymax must be "
                    "numbers".format(vl)) from err

            vlines.append(line)

        setattr(namespace, self.dest, vlines)


def parse_input(descr='plot branches from ntuples.'):
    """Build the command line parser for this script and parse ``sys.argv``.

    The base parser comes from ``single_branch_parser_no_output(descr)``; the
    options below are registered on top of it, in the order listed.

    Args:
        descr: description string handed to the base parser factory.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = single_branch_parser_no_output(descr)

    # Shared settings for options that take two values via DataPairAction
    data_range = {'nargs': 2, 'action': DataPairAction, 'default': None}
    # Shared settings for optional options that may be repeated, each
    # occurrence contributing one list of values
    grouped = {'default': None, 'action': 'append', 'nargs': '+'}

    option_table = [
        (('-o', '--output'),
         dict(required=True, help='specify output file.')),
        (('-XD', '--x-data-range'),
         dict(data_range, help='specify x-axis range.')),
        (('-YD', '--y-data-range'),
         dict(data_range, help='specify y-axis range.')),
        (('-XL', '--xlabel'),
         dict(default=None, help='specify x-axis label.')),
        (('-YL', '--ylabel'),
         dict(default='Number of candidates', help='specify y-axis label.')),
        (('--title',),
         dict(default=None, help='specify title.')),
        (('-l', '--labels'),
         dict(required=True, action='append', nargs='+',
              help='specify plot legend labels.')),
        (('-c', '--colors'),
         dict(grouped, help='specify plot colors.')),
        (('--cuts',),
         dict(grouped, help='specify optional cuts for each branch.')),
        (('--weights',),
         dict(grouped, help='specify optional weight for each branch.')),
        # NOTE(review): nargs='?' without 'const' means a bare '--yscale'
        #               stores None -- confirm whether that is intended.
        (('--yscale',),
         dict(nargs='?', default='linear',
              help='y axis scale (linear or log).')),
        (('--normalize',),
         dict(action='store_true', help='normalize plots to 1.')),
        (('--debug',),
         dict(action='store_true', help='enable debug mode.')),
        (('--vlines',),
         dict(nargs='+', default=[], action=VlineAction,
              help='specify optional vertical lines.')),
    ]

    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()


########
# Main #
########

if __name__ == '__main__':
    # Script entry point: parse the command line, histogram every requested
    # branch, queue one deferred plotting callable per histogram/vertical
    # line, then let plot_top draw and save everything.
    args = parse_input()
    hep.style.use('LHCb2')

    first_plot = True
    plotters = []

    # Add vertical lines if defined
    for vl in args.vlines:
        vline_args = ax_add_args_vlines(
            color=vl['color'], linestyles='dashed', linewidth=2,
            label=vl['lbl'])
        # NOTE: Closure binds late, so per-line values are bound as lambda
        #       defaults here.
        plotters.append(
            lambda fig, ax, x=vl['x'], y=(vl['ymin'], vl['ymax']),
            add=vline_args: plot_vlines(
                x, y, add, figure=fig, axis=ax, show_legend=False,
                convert_x=False))

    # Without a user-supplied x range, the range is accumulated from the
    # histogram bins below. Start from "unset" rather than (0, 0) so the
    # axis is not forced to include 0 when all data lie away from it.
    auto_xrange = not args.x_data_range
    xmin, xmax = (None, None) if auto_xrange else args.x_data_range

    if not args.y_data_range:
        ymin = ymax = 0 if args.yscale == 'linear' else 1
    else:
        ymin, ymax = args.y_data_range

    if not args.colors:
        # Hand out default colors in the listed order, one per branch
        good_colors = ['cornflowerblue', 'black', 'crimson', 'darkgoldenrod']
        good_colors.reverse()
        try:
            args.colors = [[good_colors.pop() for _ in br]
                           for br in args.ref_branch]
        except IndexError:
            print('No color specified and we are running out of default colors.')
            print('Please specify plot colors manually.')
            sys.exit(255)

    if not args.cuts:
        args.cuts = [[None]*len(br) for br in args.ref_branch]

    if not args.weights:
        args.weights = [[None]*len(br) for br in args.ref_branch]

    for ntp_tree, branches, colors, labels, cuts, weights in \
            zip(args.ref, args.ref_branch, args.colors, args.labels,
                args.cuts, args.weights):
        ntp_name, tree = split_ntp_tree(ntp_tree)
        cutter = BooleanEvaluator(ntp_name, tree)

        if args.debug:
            print('Working on: {}, tree: {}'.format(ntp_name, tree))

        for br_name, clr, lbl, cut, weight in \
                zip(branches, colors, labels, cuts, weights):
            br = cutter.eval(br_name)

            if args.debug:
                print('  with branch: {}, color: {}, label: {}.'.format(
                    br_name, clr, lbl))
                print('    number of candidates: {}'.format(br.size))

            # The literal string 'None' on the command line also means
            # "no weight" for this branch
            if weight is None or weight == 'None':
                br_wt = np.ones(br.size)
            else:
                br_wt = cutter.eval(weight)
                if args.debug:
                    print('    apply weights: {} on: {}'.format(
                        weight, br_name))
                    print('    total sum of weights: {}'.format(np.sum(br_wt)))

            if cut:
                # Evaluate the cut once and apply it to both the branch and
                # its weights
                passed = cutter.eval(cut)
                br = br[passed]
                br_wt = br_wt[passed]
                if args.debug:
                    print('    apply cuts: {} on: {}'.format(cut, br_name))
                    print('    after cuts: {}'.format(br.size))

            histo, bins = gen_histo(
                br, args.bins, data_range=args.x_data_range,
                density=args.normalize, weights=br_wt.astype(np.double))

            if auto_xrange:
                # Widen the x range to cover the first/last entry of 'bins'
                xmin = bins[0] if xmin is None else min(xmin, bins[0])
                xmax = bins[-1] if xmax is None else max(xmax, bins[-1])

            if not args.y_data_range:
                # Leave 10% headroom above the tallest histogram seen so far
                peak = histo.max()
                if ymax < peak:
                    ymax = 1.1*peak

            if args.debug:
                print('    xmin: {}, xmax: {}'.format(xmin, xmax))
                print('    ymin: {}, ymax: {}'.format(ymin, ymax))

            if first_plot:
                first_plot = False
                histo_args = ax_add_args_histo(lbl, clr)
                # NOTE: Closure binds late, so if we don't bind these temp
                #       variables (e.g. 'histo') to the lambda, they'll all
                #       point to the last histo and all plots will be the same
                plotters.append(
                    lambda fig, ax, b=bins, h=histo, add=histo_args: plot_histo(
                        b, h, add, figure=fig, axis=ax, show_legend=False))

            else:
                step_args = ax_add_args_step(lbl, clr)
                plotters.append(
                    lambda fig, ax, b=bins, h=histo, add=step_args: plot_step(
                        b, h, add, figure=fig, axis=ax, show_legend=False))

    if xmin is None:
        # Nothing was histogrammed, so there is no range to derive; fall back
        # to the degenerate (0, 0) range.
        xmin, xmax = 0, 0

    # With a log scale and no explicit y range, no y limits are passed on
    if not args.y_data_range and args.yscale == 'log':
        ylim = None
    else:
        ylim = (ymin, ymax)

    plot_top(plotters, output=args.output, title=args.title,
             xlim=(xmin, xmax), ylim=ylim,
             xlabel=args.xlabel, ylabel=args.ylabel, yscale=args.yscale)

    if args.debug:
        print('Plot saved to: {}'.format(args.output))
