#!/usr/bin/env python

import argparse
import logging
import os

import numpy
import tqdm

import limbo.data

# Command-line interface: one or more Limbo dataset directories.
parser = argparse.ArgumentParser(description="Print information about Limbo dataset(s).")
# nargs="+" makes this positional mandatory (argparse errors out when it is
# missing), so a default value would never be used and is omitted.
parser.add_argument("datadir", nargs="+", help="Limbo dataset director(ies).")
arguments = parser.parse_args()

# Emit bare messages at INFO level, but limit the "imagecat" logger to
# warnings and above (presumably it is chatty during dataset loading --
# TODO confirm which component logs under that name).
logging.basicConfig(level=logging.INFO, format="%(message)s")
logging.getLogger("imagecat").setLevel(logging.WARNING)

# Announce which directories are about to be scanned, one per indented line.
directories = arguments.datadir
logging.info("Extracting statistics from:")
for path in directories:
    logging.info(f"  {path}")

# Open every requested directory together as a single Limbo dataset.
dataset = limbo.data.Dataset(directories)

# Traverse the dataset exactly once and keep only each sample's category
# labels. Iterating a dataset may load sample payloads (TODO confirm for
# limbo.data.Dataset), so a second full pass just to fill the matrix is avoided.
sample_categories = [
    tuple(sample.categories)
    for sample in tqdm.tqdm(dataset, desc="Samples", unit="sample")
]

# All distinct category labels, sorted; this order defines the matrix columns.
categories = sorted({category for labels in sample_categories for category in labels})

# Category label -> column index in `samples`.
column_map = {category: index for index, category in enumerate(categories)}

# Binary indicator matrix with one row per sample and one column per category:
# samples[row, column] == 1 when that sample is tagged categories[column].
samples = numpy.zeros((len(sample_categories), len(categories)), dtype=numpy.int32)
for row, labels in enumerate(sample_categories):
    for category in labels:
        samples[row, column_map[category]] = 1

# Overall summary.
logging.info(f"Total samples: {len(dataset)}.")
logging.info(f"Categories: {', '.join(categories)}.")

# Column sums of the indicator matrix give the number of samples per category.
category_counts = samples.sum(axis=0)
for category, count in zip(categories, category_counts):
    logging.info(f"Samples in category {category}: {count}")

# Each distinct row of the indicator matrix is one combination of categories.
# return_counts yields how many samples share each combination directly,
# instead of rescanning the whole matrix once per combination.
configurations, configuration_counts = numpy.unique(samples, axis=0, return_counts=True)
for configuration, count in zip(configurations, configuration_counts):
    names = [categories[index] for index in numpy.flatnonzero(configuration)]
    # An all-zero row means the sample has no category at all; label it
    # explicitly rather than printing an empty string.
    label = " + ".join(names) if names else "(none)"
    logging.info(f"Samples tagged {label}: {count}")
