#!/usr/bin/env python3
from collections import OrderedDict
import argparse
import pickle
import sys

from questionary import Choice
import pandas as pd
import questionary as q

from labeler.labeler import Labeler, LabelError


class CustomParser(argparse.ArgumentParser):
    """ArgumentParser that shows the full help text on a usage error."""

    def error(self, message):
        """Report ``message`` on stderr, print the help text, exit with 2."""
        sys.stderr.write(f'error: {message}\n')
        # print_help() without a file argument writes to stdout.
        self.print_help()
        raise SystemExit(2)


class Options:
    """Mutable runtime settings shared by the menu handlers.

    Attributes:
        retrain_interval: number of newly labeled samples between two calls
            to ``labeler.train()`` in the labeling loop.
    """

    def __init__(self, retrain_interval: int):
        self.retrain_interval = retrain_interval


def main(options: Options, labeler: Labeler):
    """Run the interactive top-level menu.

    Every menu entry maps to a handler that is called as
    ``handler(options, labeler)``. The loop ends when a handler exits the
    process (the 'quit' entry calls ``sys.exit``) or when the menu prompt
    itself is cancelled.

    Args:
        options: runtime settings handed to every handler.
        labeler: labeling session handed to every handler.
    """
    # A list of pairs keeps the display order explicit; 'quit' stays last.
    menu = OrderedDict([
        ('label', label),
        ('label remainder automatically', label_remainder_automatically),
        ('save checkpoint', save_checkpoint),
        ('save labeled csv', save_labeled_csv),
        ('set options', set_options),
        ('quit', quit),
    ])
    choices = list(menu)
    while True:
        print(f'\nLabeled: {labeler.n_reviewed}/{labeler.n_distinct}\n')
        option = q.select('Menu:', choices=choices).ask()
        if option is None:
            # ask() returns None when the prompt is cancelled (e.g. Ctrl-C).
            # Leave the menu instead of failing with KeyError on menu[None].
            return
        menu[option](options, labeler)


def label(options: Options, labeler: Labeler):
    """Interactively label samples until none remain or the user leaves.

    Each round fetches one sample via ``labeler.review(n=1)`` and prompts for
    its label. Besides the labels the prompt offers two navigation entries:

    * 'back' - pick one of the samples labeled during this call (at most the
      last ``stack_size``) and label it again.
    * 'menu' - return to the caller.

    Args:
        options: ``options.retrain_interval`` controls how often
            ``labeler.train()`` is called; a value of 0 disables the
            periodic call instead of raising ZeroDivisionError.
        labeler: labeling session that supplies samples and stores labels.
    """
    stack = []  # samples labeled during this call, oldest first
    stack_size = 100
    n_labeled = 0
    navs = ['back', 'menu']
    while labeler.n_remaining > 0:
        print(f'\nLabeled: {labeler.n_reviewed}/{labeler.n_distinct}\n')
        sample = labeler.review(n=1).fillna(0).to_dict('records')[0]
        res = _label(sample, labeler, navs)
        if res == 'menu':
            return
        if res == 'back':
            if not stack:
                # Nothing labeled yet: there is no history entry to use as
                # the default, so do not open the history prompt at all.
                print('Nothing has been labeled yet, there is no history.')
                continue
            history = [x['text'] for x in stack]
            res = q.select('History:',
                           choices=history + ['cancel'],
                           default=history[-1]).ask()
            # None means the history prompt itself was cancelled (Ctrl-C).
            if res is not None and res != 'cancel':
                previous = next(s for s in stack if s['text'] == res)
                _label(previous, labeler, navs=['cancel'])
            continue
        stack.append(sample)
        stack = stack[-stack_size:]
        n_labeled += 1
        interval = options.retrain_interval
        if interval and n_labeled % interval == 0:
            labeler.train()


def _label(sample, labeler, navs):
    """Prompt for the label(s) of one sample and pass them to the labeler.

    With mutually exclusive labels a single-choice prompt is shown and the
    answer is written to ``sample['label']``. Otherwise a checkbox prompt is
    shown and every label key in ``sample`` is set to 1 or 0. In both cases
    ``sample`` is modified in place and handed to ``labeler.update`` as a
    one-row DataFrame.

    Args:
        sample: dict with a 'text' key plus 'label' (exclusive mode) or one
            key per label (multi-label mode).
        labeler: provides ``labels``, ``mutually_exclusive_labels`` and
            ``update``.
        navs: navigation entries appended to the label choices.

    Returns:
        The chosen navigation entry when the user picked one, otherwise
        ``None`` after the labeler has been updated. A cancelled prompt
        (``ask()`` returned None, e.g. Ctrl-C) is reported as the last entry
        of ``navs`` ('menu' / 'cancel' at the call sites in this file) and
        nothing is stored.
    """
    if labeler.mutually_exclusive_labels:
        default = sample['label'] if sample['label'] else labeler.labels[0]
        res = q.select(sample['text'],
                       choices=labeler.labels + navs,
                       default=default).ask()
        if res is None:
            # Cancelled: do not store None as a label.
            return navs[-1] if navs else None
        if res in navs:
            return res
        sample['label'] = res
    else:
        choices = [
            Choice(lbl, checked=bool(sample[lbl])) for lbl in labeler.labels
        ]
        res = q.checkbox(sample['text'],
                         choices=choices + navs,
                         validate=validate_multiple).ask()
        if res is None:
            # Cancelled: res[0] below would raise TypeError.
            return navs[-1] if navs else None
        # validate_multiple only accepts a navigation entry when it is the
        # sole selection, so checking the first element is sufficient.
        if res[0] in navs:
            return res[0]
        for lbl in labeler.labels:
            sample[lbl] = 1 if lbl in res else 0
    labeler.update(pd.DataFrame([sample]))


def validate_multiple(res):
    """Validate a checkbox answer; used as the ``validate`` callback.

    An answer is accepted when it is non-empty and does not combine a
    navigation entry ('back', 'cancel', 'menu') with any other selection.

    Args:
        res: the selected values.

    Returns:
        ``True`` if the selection is acceptable, ``False`` otherwise.
    """
    nav_entries = {'back', 'cancel', 'menu'}
    if not res:
        return False
    has_nav = bool(nav_entries.intersection(res))
    return not (has_nav and len(res) > 1)


def label_remainder_automatically(options: Options, labeler: Labeler):
    """Menu handler: delegate to ``labeler.label_remainder_automatically()``.

    A ``LabelError`` raised by the labeler is printed instead of propagated,
    so control returns to the caller either way. ``options`` is unused; it is
    accepted because all menu handlers share one signature.
    """
    try:
        labeler.label_remainder_automatically()
    except LabelError as failure:
        print(failure)


def save_checkpoint(options: Options, labeler: Labeler):
    """Menu handler: pickle the current session to a user-chosen file.

    The file holds a dict with the keys 'options' and 'labeler', which is the
    layout ``load_checkpoint`` reads. '.pkl' is appended when the entered
    path does not already end with it.

    Nothing is written when the path prompt is cancelled or left empty, and
    a file-system error is reported instead of aborting the session.
    """
    path = q.path('checkpoint file path:').ask()
    if not path:
        # None: prompt cancelled (Ctrl-C). '': nothing typed, which would
        # otherwise end up as a file literally named '.pkl'.
        print('Checkpoint not saved: no path given.')
        return
    if not path.endswith('.pkl'):
        path += '.pkl'
    try:
        with open(path, 'wb') as f:
            pickle.dump({'options': options, 'labeler': labeler}, f)
    except OSError as e:
        # A mistyped path should not cost the user the unsaved session.
        print(f'Could not save {path}: {e}')
        return
    # Report success only after the file has been closed.
    print(f'Saved: {path}')


def save_labeled_csv(options: Options, labeler: Labeler):
    """Menu handler: export through ``labeler.save_csv`` to a chosen path.

    '.csv' is appended when the entered path does not already end with it.
    Nothing is written when the path prompt is cancelled or left empty, and
    an ``OSError`` from the export is reported instead of aborting the
    session. ``options`` is unused; it is accepted because all menu handlers
    share one signature.
    """
    path = q.path('csv file path:').ask()
    if not path:
        # None: prompt cancelled (Ctrl-C). '': nothing typed, which would
        # otherwise end up as a file literally named '.csv'.
        print('CSV not saved: no path given.')
        return
    if not path.endswith('.csv'):
        path += '.csv'
    try:
        labeler.save_csv(path)
    except OSError as e:
        # A mistyped path should not cost the user the unsaved session.
        print(f'Could not save {path}: {e}')
        return
    print(f'Saved: {path}')


def set_options(options: Options, labeler: Labeler):
    """Menu handler: let the user change a runtime option.

    Only 'retrain interval' is configurable. Choosing 'back' or cancelling
    either prompt leaves ``options`` untouched. ``labeler`` is unused; it is
    accepted because all menu handlers share one signature.
    """
    selection = q.select('Option:', choices=[
        'retrain interval',
        'back',
    ]).ask()
    if selection != 'retrain interval':
        # 'back', or None when the prompt was cancelled (Ctrl-C).
        return
    current = options.retrain_interval
    # The current value can be any int (it comes from --retrain_interval),
    # but questionary rejects a default that is not one of the choices, so
    # make sure it is offered alongside the presets.
    values = sorted({10, 25, 50, 75, 100} | {current})
    answer = q.select('retrain interval:',
                      choices=[str(v) for v in values],
                      default=str(current)).ask()
    if answer is not None:
        options.retrain_interval = int(answer)


def load_checkpoint(path: str):
    """Restore a session from a pickled checkpoint file.

    The file must contain a dict with the keys 'options' and 'labeler'.

    NOTE(security): unpickling can execute arbitrary code; only load
    checkpoint files from a trusted source.

    Args:
        path: location of the checkpoint file.

    Returns:
        The tuple ``(options, labeler)`` stored in the checkpoint.
    """
    with open(path, 'rb') as checkpoint_file:
        state = pickle.load(checkpoint_file)
    options, labeler = state['options'], state['labeler']
    return options, labeler


def quit(options: Options, labeler: Labeler):
    """Menu handler: end the program with exit status 0.

    Both arguments are ignored; they are accepted because all menu handlers
    share one signature.
    """
    raise SystemExit(0)


def parse_args(argv):
    """Parse the command line into an ``argparse.Namespace``.

    Two sub-commands are defined; the chosen one is stored in ``command``:

    * ``start csv_path text_column_name labels... [-xor] [-ri N]``
    * ``resume checkpoint_pkl``

    Args:
        argv: complete argument vector with the program name at index 0
            (the layout of ``sys.argv``).

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: with status 1, after printing the help text to stderr,
            when nothing follows the program name. argparse itself also
            exits on invalid arguments.
    """
    argv = list(argv)
    parser = argparse.ArgumentParser(
        description='labels raw texts',
        # An empty argv lets argparse fall back to its default program name.
        prog=argv[0] if argv else None,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    sub = parser.add_subparsers(help='commands', dest='command')
    start = sub.add_parser(
        'start',
        help='start a new labeler instance',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    start.add_argument('csv_path', help='csv with text to label')
    start.add_argument('text_column_name', help='column in csv to label')
    start.add_argument('labels', nargs='+', help='label categories for text')
    start.add_argument('-xor',
                       '--mutually_exclusive_labels',
                       action='store_true',
                       help='make labels mutually exclusive')
    start.add_argument('-ri',
                       '--retrain_interval',
                       type=int,
                       default=10,
                       help='retrain a model every `n` labels')
    resume = sub.add_parser(
        'resume',
        help='resume a labeler instance',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    resume.add_argument('checkpoint_pkl')
    # Inspect the argv that was passed in, not the global sys.argv, so the
    # function behaves the same for any caller.
    if len(argv) <= 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    return parser.parse_args(argv[1:])


def float_zero_one(arg):
    """argparse ``type=`` converter for a float in the closed range [0, 1].

    Args:
        arg: the raw command-line string.

    Returns:
        ``arg`` converted to ``float``.

    Raises:
        argparse.ArgumentTypeError: if ``arg`` is not a floating point
            number, is NaN, or lies outside [0, 1].
    """
    try:
        value = float(arg)
    except ValueError:
        raise argparse.ArgumentTypeError(
            'Must be a floating point number') from None
    # The chained comparison is False for NaN, which the separate
    # `< 0` / `> 1` checks would have let through as a valid value.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError('Argument must be between 0 and 1')
    return value


if __name__ == '__main__':
    # Script entry point: restore a session ('resume') or build a new one
    # from a CSV column ('start'), then hand over to the interactive menu.
    args = parse_args(sys.argv)
    command = args.command
    if command == 'resume':
        options, labeler = load_checkpoint(args.checkpoint_pkl)
    elif command == 'start':
        options = Options(args.retrain_interval)
        frame = pd.read_csv(args.csv_path)
        texts = frame[args.text_column_name].values
        labeler = Labeler(
            texts=texts,
            labels=args.labels,
            mutually_exclusive_labels=args.mutually_exclusive_labels,
        )
    main(options, labeler)
