#!/usr/bin/env python
# Copyright (C) 2012-2013 Educational Testing Service

# This file is part of SciKit-Learn Lab.

# SciKit-Learn Lab is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# SciKit-Learn Lab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with SciKit-Learn Lab.  If not, see <http://www.gnu.org/licenses/>.

'''
Loads a trained model and outputs predictions based on input feature files.

:author: Dan Blanchard
:contact: dblanchard@ets.org
:organization: ETS
:date: February 2013
'''

from __future__ import print_function, unicode_literals

import argparse

from skll import Learner, load_examples
from skll.learner import _REGRESSION_MODELS
from skll.version import __version__


class Predictor(object):
    """
    Little wrapper around a ``Learner`` to load models and get
    predictions for feature strings.
    """

    def __init__(self, model_prefix, threshold=None, positive_class=1):
        '''
        Initialize the predictor.

        :param model_prefix: Prefix to use when loading trained model (and its
                             vocab). The model is read from
                             ``<model_prefix>.model``.
        :type model_prefix: basestring
        :param threshold: If the model we're using is generating probabilities
                          of the positive class, return 1 if it meets/exceeds
                          the given threshold and 0 otherwise.
        :type threshold: float
        :param positive_class: If the model is only being used to predict the
                               probability of a particular class, this
                               specifies the index of the class we're
                               predicting. 1 = second class, which is default
                               for binary classification.
        :type positive_class: int
        '''
        self._learner = Learner()
        self._learner.load('{}.model'.format(model_prefix))
        self._pos_index = positive_class
        self.threshold = threshold

    @staticmethod
    def _label_index(pred):
        '''
        Return the integer label index encoded by a single prediction.

        After ``tolist()`` a prediction is either a bare number (the learner
        returned a 1-D array) or a one-element row (2-D array); accept both
        instead of blindly subscripting, which fails on bare numbers.

        :param pred: One entry of the prediction list.
        :type pred: number, or list/tuple whose first item is a number
        :returns: The index to look up in the learner's ``label_list``.
        :rtype: int
        '''
        if isinstance(pred, (list, tuple)):
            pred = pred[0]
        return int(pred)

    def predict(self, data):
        '''
        Return a list of predictions for the given examples.

        :param data: Examples in whatever form ``Learner.predict`` accepts
                     (e.g. the return value of ``load_examples``).
        :returns: For probabilistic models, the probability stored at index
                  ``positive_class`` of each prediction, or 0/1 if a
                  threshold was given. For regression models, the raw
                  predictions. Otherwise, the entry of the learner's
                  ``label_list`` selected by each prediction.
        :rtype: list
        '''
        # Convert to plain Python objects so the results print cleanly.
        preds = self._learner.predict(data).tolist()

        if self._learner.probability:
            if self.threshold is None:
                return [pred[self._pos_index] for pred in preds]
            # Binarize: 1 when the positive class meets/exceeds the threshold.
            return [int(pred[self._pos_index] >= self.threshold)
                    for pred in preds]

        if self._learner.model_type in _REGRESSION_MODELS:
            return preds

        # Classifier without probabilities: map label indices back to labels.
        labels = self._learner.label_list
        return [labels[self._label_index(pred)] for pred in preds]


def main(argv=None):
    '''
    Parse the command line, load the trained model, and print one
    prediction per line for every example in each input file.

    :param argv: Command-line arguments to parse, without the program name.
                 If ``None`` (the default), argparse reads ``sys.argv[1:]``,
                 so calling ``main()`` behaves exactly as before.
    :type argv: list of str
    '''
    # Get command line arguments
    parser = argparse.ArgumentParser(
        description="Loads a trained model and outputs predictions based \
                     on input feature files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        conflict_handler='resolve')
    parser.add_argument('model_prefix', help='Prefix to use when loading \
                                              trained model (and its vocab).')
    parser.add_argument('input_file',
                        help='A csv file, json file, or megam file \
                              (with or without the label column), \
                              with the appropriate suffix.',
                        nargs='+')
    parser.add_argument('-l', '--has_labels',
                        help="Indicates that the input file includes \
                              labels and that the features start at the \
                              2nd column for csv and megam files.",
                        action='store_true')
    parser.add_argument('-p', '--positive_class',
                        help="If the model is only being used to predict the \
                              probability of a particular class, this \
                              specifies the index of the class we're \
                              predicting. 1 = second class, which is default \
                              for binary classification. Keep in mind that \
                              classes are sorted lexicographically.",
                        default=1, type=int)
    parser.add_argument('-t', '--threshold',
                        help="If the model we're using is generating \
                              probabilities of the positive class, return 1 \
                              if it meets/exceeds the given threshold and 0 \
                              otherwise.",
                        type=float)
    parser.add_argument('--version', action='version',
                        version='%(prog)s {0}'.format(__version__))
    # Passing argv through lets callers (and tests) drive the CLI directly.
    args = parser.parse_args(argv)

    # Create the classifier and load the model
    predictor = Predictor(args.model_prefix,
                          positive_class=args.positive_class,
                          threshold=args.threshold)

    # Files are processed in the order given; predictions go to stdout.
    for input_file in args.input_file:
        data = load_examples(input_file, has_labels=args.has_labels)
        for pred in predictor.predict(data):
            print(pred)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
