#!/usr/bin/env python

# Copyright (C) 2012-2013 Educational Testing Service

# This file is part of SciKit-Learn Lab.

# SciKit-Learn Lab is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# SciKit-Learn Lab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with SciKit-Learn Lab.  If not, see <http://www.gnu.org/licenses/>.

'''
Script that converts from CSV to MegaM format

:author: Dan Blanchard (dblanchard@ets.org)
:date: June 2012
'''

from __future__ import unicode_literals, print_function

import argparse
import sys
from decimal import Decimal, InvalidOperation

from bs4 import UnicodeDammit

from skll.version import __version__


def sanitize_name(feature_name):
    '''
    Make a header field safe to use as a MegaM feature name.

    Features are written as space-separated ``name value`` pairs and ``#``
    introduces comment lines, so spaces become underscores and every ``#``
    becomes the literal text ``HASH``.
    '''
    substitutions = ((" ", "_"), ("#", "HASH"))
    for bad_text, safe_text in substitutions:
        feature_name = feature_name.replace(bad_text, safe_text)
    return feature_name


if __name__ == '__main__':
    # Get command line arguments
    parser = argparse.ArgumentParser(description="Takes a delimited file with \
                                                  a header line and converts it\
                                                  to MegaM.",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('infile', help='MegaM input file',
                        type=argparse.FileType('rb'), default='-', nargs='?')
    parser.add_argument('-c', '--classfield',
                        help='Index of class field in CSV file. Note: fields \
                              are numbered starting at 0.',
                        default=-1, type=int)
    parser.add_argument('-d', '--delimiter',
                        help='The column delimiter.', default=',')
    parser.add_argument('-i', '--idfield',
                        help='Index of ID field in CSV file (if there is one).\
                              This will be included as a comment before each \
                              line.  Note: fields are numbered starting at 0.',
                        type=int)
    parser.add_argument('--version', action='version',
                        version='%(prog)s {0}'.format(__version__))
    args = parser.parse_args()

    if args.infile.isatty():
        print(("You are running this script interactively. Press CTRL-D at " +
               "the start of a blank line to signal the end of your input. " +
               "For help, run it with --help\n"),
              file=sys.stderr)

    # Initialize variables
    classes = set()
    instances = []
    fields = []

    # Iterate through input file
    first = True
    for line_num, line in enumerate(args.infile):
        stripped_line = UnicodeDammit(line.strip(),
                                      ['utf-8', 'windows-1252']).unicode_markup
        split_line = stripped_line.split(args.delimiter)
        # Skip blank lines
        if split_line:
            # Process header
            if first:
                # Check for weird commented-out header in tbl files
                fields = split_line[1:] if split_line[0] == '#' else split_line
                fields = [sanitize_name(field) for field in fields]
                # To fix sorting issues, make field indexes positive
                if args.idfield < 0:
                    args.idfield += len(fields)
                if args.classfield < 0:
                    args.classfield += len(fields)
                # Delete extra fields
                if args.idfield is not None:
                    # Have to sort descending so that we don't screw up the
                    # indices
                    for i in sorted((args.idfield, args.classfield),
                                    reverse=True):
                        del fields[i]
                else:
                    del fields[args.classfield]
                first = False
            else:
                # Delete extra fields
                if args.idfield is not None:
                    # Print id field
                    print("# {}".format(split_line[args.idfield]).encode('utf-8'))
                    # Print class
                    print('{}'.format(split_line[args.classfield]).encode('utf-8'),
                          end='\t')

                    # Have to sort descending so that we don't screw up the
                    # indices
                    for i in sorted((args.idfield, args.classfield),
                                    reverse=True):
                        try:
                            del split_line[i]
                        except IndexError as e:
                            print(("ERROR: Could not delete element at index " +
                                   "{} from list {}.").format(i, split_line),
                                  file=sys.stderr)
                            sys.exit(2)
                else:
                    # Print class
                    print('{}'.format(split_line[args.classfield]).encode('utf-8'),
                          end='\t')
                    del split_line[args.classfield]

                # Print features
                try:
                    print(' '.join(['{} {}'.format(field, value) for field, value in
                                    zip(fields, split_line) if (value not in ['.', '?'] and
                                                                Decimal(value) != 0)]).encode('utf-8'))
                except InvalidOperation:
                    for value in split_line:
                        if value not in ['.', '?']:
                            try:
                                Decimal(value)
                            except InvalidOperation as e:
                                print(("Could not convert '{}' to Decimal on " +
                                       "line {}.").format(value,
                                                          line_num).encode('utf-8'),
                                      file=sys.stderr)
                                sys.exit(2)
