#!/usr/bin/env python

# Copyright (C) 2012-2013 Educational Testing Service

# This file is part of SciKit-Learn Lab.

# SciKit-Learn Lab is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# SciKit-Learn Lab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with SciKit-Learn Lab.  If not, see <http://www.gnu.org/licenses/>.

'''
Script that converts feature files from one format to another

:author: Nitin Madnani (nmadnani@ets.org)
:date: September 2013
'''

from __future__ import print_function, unicode_literals

import argparse
import os
import sys
from functools import partial

from skll.data import (_tsv_dict_iter, _json_dict_iter, _megam_dict_iter,
                       write_feature_file)
from skll.version import __version__

if __name__ == '__main__':
    # Command-line entry point: read every example from ``infile`` with the
    # reader matching its extension, then hand the collected IDs, classes and
    # feature dictionaries to ``write_feature_file`` for ``outfile``.

    # Get command line arguments
    parser = argparse.ArgumentParser(description="Takes an input feature file \
                                                  and converts it to another \
                                                  format. Formats are \
                                                  determined automatically from\
                                                  file extensions.",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('infile',
                        help='input feature file (ends in .jsonlines, .tsv, or \
                              .megam)')
    parser.add_argument('outfile',
                        help='output feature file (ends in .jsonlines, .tsv, or\
                              .megam)')
    parser.add_argument('--tsv_label',
                        help='Name of the column which contains \
                              the class labels in TSV files.',
                        default='y')
    parser.add_argument('--version', action='version',
                        version='%(prog)s {0}'.format(__version__))
    args = parser.parse_args()

    # make sure the input file extension is one we can process
    input_extension = os.path.splitext(args.infile)[1]

    # Pick the reader for the input format. Only the TSV reader needs extra
    # configuration (the label column name), so it is wrapped with partial()
    # to give all three readers the same one-argument call signature.
    if input_extension.endswith(".tsv"):
        example_gen_func = partial(_tsv_dict_iter, tsv_label=args.tsv_label)
    elif input_extension.endswith(".jsonlines"):
        example_gen_func = _json_dict_iter
    elif input_extension.endswith(".megam"):
        example_gen_func = _megam_dict_iter
    else:
        print('Input file must be in either .tsv, .megam, or ' +
              '.jsonlines format. You specified: {}'.format(input_extension),
              file=sys.stderr)
        # Stop here with a failure status: no reader was selected, so
        # carrying on would only crash with a NameError on example_gen_func
        # further down instead of ending cleanly after the message above.
        sys.exit(1)

    # Iterate through input file and collect the information we need.
    # Each item produced by the reader is unpacked as an
    # (example_id, class_name, feature_dict) triple; the three lists are kept
    # in the same order so that index i refers to the same example in each.
    ids = []
    classes = []
    feature_dicts = []
    for example_id, class_name, feature_dict in example_gen_func(args.infile):
        feature_dicts.append(feature_dict)
        classes.append(class_name)
        ids.append(example_id)

    # write out the file in the requested output format
    # NOTE(review): the output format is presumably chosen by
    # write_feature_file from the extension of args.outfile, as the parser
    # description states -- confirm against skll.data.
    write_feature_file(args.outfile, ids, classes, feature_dicts)
