#!/usr/bin/env python
# -*- coding: utf-8 -*-
import seutils
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
    'lfns', type=str, nargs='+',
    help='Paths to datasets (directories that contain .root files)'
    )
parser.add_argument(
    '-t', '--tree', type=str,
    help='Print one tree only'
    )
parser.add_argument(
    '-f', '--force', action='store_true',
    help='Forces a recount of events (i.e. cache is ignored)'
    )
args = parser.parse_args()

def main():
    """Print per-tree entry counts for every dataset given on the command line.

    Reads the module-level ``args`` namespace:

    * ``args.lfns``  -- dataset paths; any path containing ``*`` is expanded
      with ``seutils.ls_wildcard``.
    * ``args.tree``  -- if set, only the tree with exactly this name is printed.
    * ``args.force`` -- if set, ``read_cache=False`` is passed to
      ``seutils.root.count_dataset``.

    For each dataset the total entry count per tree is printed (in red),
    followed by one line per root file with that file's count for the tree.
    """
    # First preprocess any wildcards
    lfns = []
    for lfn in args.lfns:
        if '*' in lfn:
            lfns.extend(seutils.ls_wildcard(lfn))
        else:
            lfns.append(lfn)

    # Only label each report when there is more than one dataset to show.
    print_dataset_header = len(lfns) > 1

    # Do the counting
    for lfn in lfns:
        lfn = seutils.cli_flexible_format(lfn, seutils.cli_detect_fnal())
        if print_dataset_header:
            print('Dataset {}'.format(lfn))
        output = seutils.root.count_dataset(lfn, read_cache=not args.force)

        # Compute the tree total counts, summed over all root files.
        tree_totals = {}
        for values in output.values():
            for key, value in values.items():
                # '_mtime' is bookkeeping stored next to the counts, not a
                # tree (presumably the cache timestamp -- confirm in seutils).
                if key == '_mtime':
                    continue
                tree_totals[key] = tree_totals.get(key, 0) + value

        for tree, count in sorted(tree_totals.items()):
            # Honour the --tree filter, if one was given.
            if args.tree and tree != args.tree:
                continue
            print(
                '\033[31mTTree {0} ({1} entries)\033[0m'
                .format(tree, count)
                )
            for rootfile, entry in sorted(output.items()):
                # tree_totals is the union of tree names over all files, so a
                # given file may not contain this tree at all; report 0 for
                # it instead of raising KeyError.
                print('  {:6d}: {}'.format(entry.get(tree, 0), rootfile))


# Run the report only when this file is executed as a script.
if __name__ == '__main__':
    main()