#!/opt/miniconda3/bin/python

# TODO: * add search pattern / anti-pattern on datasets/groups
#       * nicer verbose output for scalars and h5 references

import h5py
import argparse
from termcolor import colored
from collections import defaultdict

# globals

# running tallies of the groups and datasets printed so far
total_groups, total_datasets = 0, 0

# termcolor color names, one per kind of entry
group_color, dataset_color, attr_color = 'green', 'yellow', 'red'
scalar_color = attr_color

# horizontal padding placed between a name and its trailing details
short_gap = '  '
long_gap = '    '

# box-drawing pieces that make up the tree skeleton (each four columns wide)
T_branch = '├── '
L_branch = '└── '
I_branch = '│   '
blank = '    '

# depth -> True once the last child at that depth has been printed;
# depths that were never set read as False
terminated = defaultdict(bool)

# create the parser options
parser = argparse.ArgumentParser(prog='h5tree')
# positional: the HDF5 file, optionally followed by a path inside the file
parser.add_argument("path", type=str, help="filepath/grouppath")
parser.add_argument('-v', '--verbose', action='store_true', help='verbose output')
parser.add_argument('-a', '--attributes', action='store_true', help='show attributes')
parser.add_argument('-g', '--groups', action='store_true', help='only show groups')
# nargs='?' lets -L appear without a value, which leaves the level at None
parser.add_argument('-L', '--level', nargs='?', type=int, help='maximum number of directories to recurse into', default=None)
parser.add_argument('-p', '--pattern', nargs='?', type=str, help='pattern', default=None)

# argcomplete.autocomplete(parser)
# parsed once at import time; the rest of the script reads options from `args`
args = parser.parse_args()

# parse the parsed input
# args.path has the form "<file>.h5[/<group path>]": everything up to and
# including the first ".h5" is the file on disk, the rest is the path inside it
locater = '.h5'
ext_index = args.path.find(locater)
if ext_index == -1:
    # no ".h5" anywhere in the argument: adding len(locater) to str.find's -1
    # would cut the file name down to its first two characters, so treat the
    # whole argument as the file path instead
    sep_index = len(args.path)
else:
    sep_index = ext_index + len(locater)

filepath = args.path[:sep_index]
# +1 skips the "/" that separates the file name from the group path
grouppath = args.path[sep_index+1:]
if not grouppath:
    grouppath = "/"

def str_count(n, name):
    """ return "<n> <name>", with a plural "s" appended unless n equals 1 """
    plural = "" if n == 1 else "s"
    return "{} {}{}".format(n, name, plural)

def display_header(verbose = False):
    """ display the tree header

    Prints the file path, followed by the group path when the tree is not
    rooted at "/", in the group color. When verbose is set, the number of
    objects in the module-level `group` (and of its attributes, if it has
    any) is appended in parentheses.
    """
    title = filepath if grouppath == "/" else "{}/{}".format(filepath, grouppath)

    if verbose:
        # collect the summary pieces first, then join them with ", "
        summary = [str_count(len(group), "object")]
        if group.attrs:
            summary.append(str_count(len(group.attrs), "attribute"))
        title += short_gap + "({})".format(", ".join(summary))

    print(colored(title, group_color))

def display_attributes(group, n, verbose = False):
    """ display each attribute of an object on its own line

    group:   h5py object whose attributes are listed; despite the name it is
             not always a group, since display() passes every object it prints
    n:       number of tree columns drawn in front of each attribute line
    verbose: if True, append the attribute's value after its name
    """
    # one column per level: blank once that level has printed its last
    # entry, a vertical bar otherwise
    front = "".join(blank if terminated[i] else I_branch for i in range(n))

    # only groups can have children; datasets have no keys(), so asking them
    # for their children would raise AttributeError
    has_children = isinstance(group, h5py.Group) and len(group) > 0

    last = len(group.attrs) - 1
    for i, attr in enumerate(group.attrs):
        # the final attribute closes the branch when no child lines follow it
        if i == last and (not has_children or args.groups):
            branch = L_branch
        else:
            branch = T_branch
        attr_output = front + branch + colored(attr, attr_color)

        if verbose:
            attr_output += colored(short_gap + str(group.attrs[attr]), None)
        print(attr_output)

def display(name, obj):
    """ display the group or dataset on a single line

    Callback passed to visititems().

    name: path of the object relative to the group being walked
    obj:  the h5py object found at that path

    Updates the module-level totals and the `terminated` bookkeeping as a
    side effect. Always returns None (h5py's visititems stops at the first
    non-None return value).
    """
    global total_groups, total_datasets

    # substring filter on the relative path
    if args.pattern and args.pattern not in name:
        return

    depth = name.count('/')
    # abort if below level; compared against None so that an explicit
    # "-L 0" is a real limit instead of being mistaken for "no limit"
    if args.level is not None and depth >= args.level:
        return

    # reset terminated dict: anything deeper than this line belongs to a
    # subtree that has already been left
    for d in terminated:
        if d > depth:
            terminated[d] = False

    # construct text at the front of line
    subname = name[name.rfind('/')+1:]
    front = "".join(blank if terminated[i] else I_branch for i in range(depth))

    # the last key of the parent closes the branch at this depth
    if list(obj.parent.keys())[-1] == subname:
        front += L_branch
        terminated[depth] = True
    else:
        front += T_branch

    # is group
    if isinstance(obj, h5py.Group):
        output = front + colored(subname, group_color)
        if args.verbose:
            message = str_count(len(obj), "object")
            if obj.attrs:
                message += ", " + str_count(len(obj.attrs), "attribute")
            output += colored(short_gap + '({})'.format(message), group_color)

        total_groups += 1
        print(output)

    # is dataset
    elif not args.groups:
        # an empty shape marks a scalar dataset, which gets its own color
        color = dataset_color if obj.shape else scalar_color
        output = front + colored(subname, color)

        if args.verbose:
            output += short_gap + '{}, {}'.format(obj.shape, obj.dtype)

        total_datasets += 1
        print(output)

    # dataset hidden by --groups: listing its attributes without its own
    # line would make them look like they belong to the entry above
    else:
        return

    # include attributes
    if args.attributes:
        display_attributes(obj, depth + 1, args.verbose)

# open file and print tree
with h5py.File(filepath, 'r') as f:
    group = f[grouppath]
    display_header(args.verbose)
    if args.attributes:
        # top-level children are drawn with no leading columns (depth 0),
        # so the root's own attributes must be drawn at that level too
        display_attributes(group, 0, args.verbose)

    # grouppath may name a dataset, which has no children to walk and no
    # visititems() method; in that case the header is all there is to show
    if isinstance(group, h5py.Group):
        group.visititems(display)

# summary line, separated from the tree by one blank line
if args.verbose:
    footer = '{}, {}'.format(str_count(total_groups, "group"), str_count(total_datasets, "dataset"))
    print('\n', footer, sep='')
