#!/usr/bin/env python

from b4msa.utils import tweet_iterator
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, accuracy_score, precision_score
import json
import os
import sys

def performance(gold, predicted, output=None, avgf1=None):
    """Score a prediction file against a gold-standard file and report it.

    Both files are read with ``tweet_iterator`` and the ``klass`` field of
    every record (converted to ``str``) is used as its label.  The report
    contains the macro and micro F1, the accuracy, the name of the predicted
    file and one ``f1_<label>`` entry per label of the gold standard.

    Args:
        gold: Name of the gold-standard file.
        predicted: Name of the file holding the predicted labels.
        output: Name of the file where the report is written as one JSON
            line; when ``None`` the report is pretty-printed to stdout.
        avgf1: Optional ``:``-separated list of labels; when given, the mean
            of their F1 scores is stored under ``avg_f1_<avgf1>``.

    Raises:
        ValueError: If ``avgf1`` names a label that does not occur in the
            gold standard.  The label encoder also raises it when the
            predicted file contains a label absent from the gold standard.
    """
    y = [str(tweet['klass']) for tweet in tweet_iterator(gold)]
    le = LabelEncoder().fit(y)
    y = le.transform(y)

    hy = le.transform([str(tweet['klass']) for tweet in tweet_iterator(predicted)])

    # The encoder was fitted on the gold labels, so every encoded class occurs
    # in ``y`` and the i-th per-class score belongs to ``le.classes_[i]``.
    f1 = {
        str(label): float(score)
        for label, score in zip(le.classes_, f1_score(y, hy, average=None))
    }

    report = {
        'macrof1': f1_score(y, hy, average='macro'),
        'microf1': f1_score(y, hy, average='micro'),
        'accuracy': accuracy_score(y, hy),
        'filename': predicted,
    }

    for label, score in f1.items():
        report["f1_{0}".format(label)] = score

    if avgf1:
        selected = avgf1.split(':')
        # Fail with a clear message instead of a TypeError from summing None.
        unknown = [label for label in selected if label not in f1]
        if unknown:
            raise ValueError(
                "avgf1 refers to classes missing from the gold standard: {0}".format(
                    ", ".join(unknown)
                )
            )
        report['avg_f1_{0}'.format(avgf1)] = sum(f1[label] for label in selected) / len(selected)

    if output is None:
        print(json.dumps(report, sort_keys=True, indent=2) + "\n")
    else:
        with open(output, 'w') as f:
            f.write(json.dumps(report, sort_keys=True) + "\n")


if __name__ == '__main__':
    # Command-line entry point: parse the two input files plus the optional
    # report settings and hand them to performance().
    from argparse import ArgumentParser

    parser = ArgumentParser(description='microtc')
    # Plain positionals are required and parsed as single strings, so no
    # nargs=1 list wrapping (and later [0] unwrapping) is needed.
    parser.add_argument('gold', help='Gold standard set')
    parser.add_argument('predicted', help='Predicted set')
    parser.add_argument(
        '-a', '--avgf1', dest='avgf1', default=None, type=str,
        help='Indicates if microtc-perf must show the average of some selected classes (using `:` as separator)'
    )
    parser.add_argument(
        '-o', '--output', dest='output', default=None, type=str,
        help='Indicates the name of the output file, defaults to stdout'
    )
    args = parser.parse_args()
    performance(args.gold, args.predicted, args.output, args.avgf1)
