#!/usr/bin/env python
import codecs
import semspaces.space
import argparse


# Command-line interface. The script reads a semantic space, optionally
# restricts it to the words listed in a vocabulary file, and writes the
# result out, possibly in a different format.
parser = argparse.ArgumentParser(description="Subset semantic space")

# Input and output paths are mandatory: without them the script would only
# fail later, when None is handed to the space loader or saver.
parser.add_argument('-i', '--inputfile', required=True,
                    help='file with semantic space')
parser.add_argument('-o', '--outfile', required=True,
                    help='output file')
# Only these two formats are handled further down in the script, so reject
# anything else at parse time with a proper usage message.
parser.add_argument('-f', '--format', default='csv',
                    choices=('csv', 'ssmarket'),
                    help='input space format')
parser.add_argument('-u', '--outformat', default=None,
                    choices=('csv', 'ssmarket'),
                    help='output space format')
parser.add_argument('-t', '--title', default=None,
                    help='title of the new space')
parser.add_argument('-v', '--vocab',
                    help='file with allowed words, default: allow all words')

args = parser.parse_args()

# Module-level aliases used by the rest of the script.
output_file = args.outfile
space_file = args.inputfile
space_format = args.format
vocab_file = args.vocab
title = args.title

# When no output format is requested, write in the same format as the input.
space_format_out = (space_format if args.outformat is None
                    else args.outformat)


def load_vocab(vocabfile):
    """Read a UTF-8 vocabulary file and return its entries as a set.

    Each line of the file is one entry; surrounding whitespace (including
    the line terminator) is stripped. Duplicate lines collapse into a single
    entry. A blank line yields the empty string as an entry.

    Args:
        vocabfile: path of the vocabulary file to read.

    Returns:
        A set of the stripped lines of the file.
    """
    # The context manager closes the file even if decoding fails part-way.
    with codecs.open(vocabfile, 'r', encoding='utf-8') as fin:
        return {line.strip() for line in fin}

# Main script body: load the vocabulary and the semantic space, restrict the
# space to the vocabulary, record what was done in the readme, and save.

# Reject unsupported formats before any loading is done. A real exception
# class is required here; raising a plain string is itself a TypeError.
_KNOWN_FORMATS = ('csv', 'ssmarket')
if space_format not in _KNOWN_FORMATS:
    raise ValueError('Give correct input format!')
if space_format_out not in _KNOWN_FORMATS:
    raise ValueError('Give correct output format!')

# The vocabulary file is optional: without it every word is allowed and the
# space is written out unfiltered.
if vocab_file is None:
    vocab = None
    vocab_comment = 'Vocabulary file entries: none (all words allowed)'
else:
    vocab = load_vocab(vocab_file)
    vocab_comment = 'Vocabulary file entries: %s' % len(vocab)
# Single-argument print(...) gives the same output under Python 2 and 3.
print(vocab_comment)

if space_format == 'csv':
    sspace = semspaces.space.SemanticSpace.from_csv(space_file)
else:
    sspace = semspaces.space.SemanticSpace.from_ssmarket(space_file)

orig_entries = 'Original semantic space entries: %s' % sspace.vectors.shape[0]
print(orig_entries)

if vocab is not None:
    # Keep only the vocabulary words that the space reports as included.
    vocab = list(set(sspace.included_words()).intersection(vocab))
    sspace = sspace.subset(vocab)

common_entries = 'Common entries: %s' % sspace.vectors.shape[0]
print(common_entries)

# Two spaces after the colon reproduce the output of the original
# comma-separated print statement.
print('New space shape:  %s' % (sspace.vectors.shape,))

if title:
    sspace.title = title

# Append the subsetting statistics to the readme, keeping any existing text.
readme_addition = '\n'.join([vocab_comment, orig_entries, common_entries])

if sspace.readme:
    sspace.readme = '%s\n\n%s' % (sspace.readme, readme_addition)
else:
    sspace.readme = readme_addition

# The output format was validated above, so only two cases remain.
if space_format_out == 'csv':
    sspace.save_csv(output_file)
else:
    sspace.save_ssmarket(output_file)
