#!/usr/bin/env python
#
# Author: Yipeng Sun
# License: BSD 2-clause
# Last Change: Thu Apr 14, 2022 at 05:44 AM -0400

import sys
import uproot
import numpy as np

from argparse import ArgumentParser
from itertools import product
from tabulate import tabulate


################
# Configurable #
################

# Number of decimal places used when formatting bin edges in the row/column
# labels of the tables and in the per-slice titles of 3D histograms.
LBL_PREC = 2


################################
# Command line argument parser #
################################

def parse_input(descr='print histogram content, with errors.'):
    """Parse the command line arguments of this script.

    The arguments are read from ``sys.argv``.

    :param descr: text shown as the program description in ``--help``.
    :return: the ``argparse.Namespace`` holding the parsed options.
    """
    # The first positional parameter of ArgumentParser is 'prog', not
    # 'description'; pass the text by keyword so that the usage line keeps the
    # real program name and the description actually shows up in the help.
    parser = ArgumentParser(description=descr)

    parser.add_argument('ntp', help='specify path to ntuple.')

    parser.add_argument('-H', '--histo', required=True,
                        help='specify name of histogram.')

    parser.add_argument('-p', '--precision', default=4, type=int,
                        help='specify float print precision.')

    parser.add_argument('-t', '--tablefmt', default='fancy_grid',
                        help='specify table format.')

    parser.add_argument('-n', '--no-headers', action='store_true',
                        help='remove headers and indices.')

    parser.add_argument('-f', '--flow', action='store_true',
                        help='include under/overflow.')

    # Only the three axes of a 3D histogram (0: x, 1: y, 2: z) are meaningful,
    # so reject anything else here instead of failing later on.
    parser.add_argument('-r', '--roll', default=2, type=int,
                        choices=(0, 1, 2),
                        help='build 2D histogram for each value along the specified axis.')

    return parser.parse_args()


#####################
# Histogram printer #
#####################

def make_data(values, errors, precision):
    """Render every bin as a 'value ± error' string.

    :param values: array of bin contents.
    :param errors: array of bin errors, indexed the same way as ``values``.
    :param precision: number of decimal places for both numbers.
    :return: nested list with the shape of ``values`` holding the strings.
    """
    # An object array lets each cell hold a string of arbitrary length.
    cells = np.empty(values.shape, dtype=object)

    for pos in np.ndindex(*values.shape):
        val, err = values[pos], errors[pos]
        cells[pos] = f'{val:,.{precision}f} ± {err:,.{precision}f}'

    return cells.tolist()


def print_cell_range(lower, upper):
    """Return a two-line label for a bin: lower edge on top, upper edge below.

    Both edges are formatted with thousands separators and LBL_PREC decimals.
    """
    edges = (f'{edge:,.{LBL_PREC}f}' for edge in (lower, upper))
    return '\n'.join(edges)


def print_2d(values, errors, xbins, ybins,
             tablefmt='simple', precision=4, no_headers=False, legend='x / y'):
    """Print a 2D histogram as a table, one 'value ± error' cell per bin.

    Rows follow the first axis of ``values`` and are labelled from ``xbins``;
    columns follow the second axis and are labelled from ``ybins``.

    :param values: 2D array of bin contents.
    :param errors: 2D array of bin errors.
    :param xbins: bin edges used for the row labels.
    :param ybins: bin edges used for the column labels.
    :param tablefmt: table format name forwarded to ``tabulate``.
    :param precision: number of decimal places of the cell content.
    :param no_headers: if true, print the bare cells without any labels.
    :param legend: text placed in the top-left corner of the table.
    """
    col_labels = [legend]
    for lo, hi in zip(ybins[:-1], ybins[1:]):
        col_labels.append(print_cell_range(lo, hi))

    row_labels = []
    for lo, hi in zip(xbins[:-1], xbins[1:]):
        row_labels.append(print_cell_range(lo, hi))

    cells = make_data(values, errors, precision)

    # Labels are only handed to tabulate when they were asked for.
    table_opts = {'tablefmt': tablefmt}
    if not no_headers:
        table_opts['headers'] = col_labels
        table_opts['showindex'] = row_labels

    print(tabulate(cells, **table_opts))


########
# Main #
########

if __name__ == '__main__':
    args = parse_input()

    # Open the ntuple in a context manager so that the file is closed on every
    # exit path, including the early sys.exit() below.
    with uproot.open(args.ntp) as ntp:
        try:
            histo = ntp[args.histo]
        except KeyError:
            # Only a failed lookup means "not in file" (uproot's missing-key
            # error derives from KeyError); any other failure is left to
            # propagate instead of being reported with a misleading message.
            print(f'"{args.histo}" is not in {args.ntp}! Available histos:')
            for i in ntp:
                print(f'  {i}')
            sys.exit(1)

        # Extract everything we need while the file is still open.
        values, errors = histo.values(args.flow), histo.errors(args.flow)
        # The first element returned by to_numpy() is dropped (values are
        # already read above); the remaining ones are the bin edges, one array
        # per axis, so their count is the histogram dimension.
        _, *binning = histo.to_numpy(args.flow)

    if len(binning) == 1:
        print('1D histogram not supported!')
        sys.exit(1)

    elif len(binning) == 2:
        print()
        print_2d(values, errors, binning[0], binning[1],
                 args.tablefmt, args.precision, args.no_headers)
        print()

    elif len(binning) == 3:
        # rolled axis -> (its name, legend of the 2D slices,
        #                 indices of the two remaining axes)
        axis_lbl = {
            0: ('x', 'y / z', 1, 2),
            1: ('y', 'x / z', 0, 2),
            2: ('z', 'x / y', 0, 1),
        }

        if args.roll not in axis_lbl:
            print(f'Invalid roll axis {args.roll}! Choose from 0, 1, 2.')
            sys.exit(1)

        lbl_main, lbl_histo, bin_idx1, bin_idx2 = axis_lbl[args.roll]

        # Move the rolled axis to the front so that iterating over the arrays
        # yields one 2D slice per bin along that axis.
        for v, e, low, high in \
                zip(np.moveaxis(values, args.roll, 0),
                    np.moveaxis(errors, args.roll, 0),
                    binning[args.roll][:-1], binning[args.roll][1:]):
            if not args.no_headers:
                print()
                print(f'## {lbl_main} ∈ [{low:,.{LBL_PREC}f}, {high:,.{LBL_PREC}f})')
            print()
            print_2d(v, e, binning[bin_idx1], binning[bin_idx2],
                     args.tablefmt, args.precision, args.no_headers,
                     legend=lbl_histo)
            print()

    else:
        print(f'{len(binning)}D histogram not supported!')
        sys.exit(1)
