#!/usr/bin/env python
"""
Usage: $ parrot-predict data_file saved_network output_file <flags>
  
Driver script for generating predictions on a list of sequences using a trained
bidirectional recurrent neural network. For more information on usage, use the '-h'
flag.

.............................................................................
idptools-parrot was developed by the Holehouse lab
     Original release ---- 2020

Questions/comments/concerns? Raise an issue on GitHub:
https://github.com/idptools/parrot

Licensed under the MIT license. 
"""

import os
import sys

import numpy as np
import torch
import argparse

from parrot import brnn_architecture
from parrot import process_input_data as pid
from parrot import train_network
from parrot import encode_sequence


# Build the command line interface: three positional paths followed by
# the flags describing the data format and the saved network's shape.
parser = argparse.ArgumentParser(
    description='Make predictions with a bi-directional RNN.')

# Positional arguments
parser.add_argument(
    'seq_file',
    help='path to tsv file with format: <idx> <sequence> <data>')
parser.add_argument(
    'saved_network',
    help='path to the trained network weights file')
parser.add_argument(
    'output_file',
    help='file with sequences and their predicted values')

# Required flags. Because required=True makes argparse exit with an error
# when the flag is missing, the `default` values below are never applied.
parser.add_argument(
    '-d', '--datatype',
    metavar='dtype',
    default='sequence',
    type=str,
    required=True,
    help="Required. Format of the input data file, must be 'sequence' or 'residues'")
parser.add_argument(
    '-c', '--classes',
    default=1,
    type=int,
    metavar='num_classes',
    required=True,
    help='Required. Number of output classes, for regression put 1')

# Optional flags
parser.add_argument(
    '-hs', '--hidden-size',
    default=5,
    type=int,
    metavar='hidden_size',
    help='hidden vector size (def=5)')
parser.add_argument(
    '-nl', '--num-layers',
    default=1,
    type=int,
    metavar='num_layers',
    help='number of layers per direction (def=1)')
parser.add_argument(
    '--encode',
    default='onehot',
    type=str,
    metavar='encoding_scheme',
    help="'onehot' (default), 'biophysics', or specify a path to a user-created scheme")
parser.add_argument(
    '--exclude-seq-id',
    dest='excludeSeqID',
    action='store_true',
    help='use if data_file lacks sequence IDs in the first column of each line')

args = parser.parse_args()

# This script always builds the network and loads its weights on the CPU.
device = 'cpu'

# Hyper-parameters of the network
hidden_size, num_layers = args.hidden_size, args.num_layers

# Data format: per-sequence or per-residue values, and number of classes
dtype, num_classes = args.datatype, args.classes

# Other flags: encoding scheme and whether the input file has an ID column
encode, excludeSeqID = args.encode, args.excludeSeqID

###############################################################################
################    Validate arguments and initialize:      ###################

# Resolve both input paths to absolute form, then confirm that each one
# points at an existing regular file. The sequence file is checked first,
# so its error takes precedence when both are missing.
seq_file = os.path.abspath(args.seq_file)
saved_weights = os.path.abspath(args.saved_network)

for _input_path, _missing_message in (
        (seq_file, 'Sequence file does not exist.'),
        (saved_weights, 'Saved network does not exist.')):
    if not os.path.isfile(_input_path):
        raise FileNotFoundError(_missing_message)

# Ensure that the directory that will contain the output file exists.
# The path is split with os.path rather than on a literal '/', so the
# check also works where the path separator differs (e.g. Windows, where
# abspath() returns backslash-separated paths).
output_file = os.path.abspath(args.output_file)
output_filename = os.path.basename(output_file)
output_dir = os.path.dirname(output_file)
if not os.path.isdir(output_dir):
    raise FileNotFoundError('Output directory does not exist.')

# Select the encoding scheme. The two built-in schemes have a fixed input
# vector size and need no encoder object; any other value of `--encode` is
# treated as the path to a user-defined scheme.
if encode in ('onehot', 'biophysics'):
    encoding_scheme = encode
    input_size = 20 if encode == 'onehot' else 9
    encoder = None
else:
    encoding_scheme = 'user'
    encode_file = encode
    # The user encoder is built from the scheme file; the network input
    # size is whatever length the encoder reports.
    encoder = encode_sequence.UserEncoder(encode_file)
    input_size = len(encoder)

# Decide between classification and regression from the class count:
# one class means regression, more than one means classification.
if num_classes < 1:
    raise ValueError('Number of classes must be a positive integer.')
problem_type = 'classification' if num_classes > 1 else 'regression'

# Hidden vector size and layer count must both be positive integers.
for _hyperparameter, _error_message in (
        (hidden_size, 'Hidden vector size must be a positive integer.'),
        (num_layers, 'Number of layers must be a positive integer.')):
    if _hyperparameter < 1:
        raise ValueError(_error_message)

# Pick the network architecture that matches the data format:
#   'sequence' -> many-to-one  (BRNN_MtO)
#   'residues' -> many-to-many (BRNN_MtM)
if dtype not in ('sequence', 'residues'):
    raise ValueError('Invalid argument `--datatype`: must be "residues" or "sequence".')

if dtype == 'sequence':
    _network_class = brnn_architecture.BRNN_MtO
else:
    _network_class = brnn_architecture.BRNN_MtM

brnn_network = _network_class(
    input_size, hidden_size, num_layers, num_classes, device).to(device)

###############################################################################

# Read the saved parameters onto the target device and copy them into the
# freshly initialized network.
# NOTE: torch.load unpickles the file; only load weights from trusted sources.
_saved_state = torch.load(saved_weights, map_location=torch.device(device))
brnn_network.load_state_dict(_saved_state)

# Convert sequence file to list of sequences.
#
# With --exclude-seq-id every non-blank line is taken verbatim (stripped) as
# a sequence. Otherwise each line is split on whitespace: the first column
# is the sequence ID and the second column is the sequence. `seq_id_dict`
# maps each sequence back to its ID for writing the output file.
sequences = []
seq_id_dict = {}
with open(seq_file) as f:
    for _line_number, _raw_line in enumerate(f, start=1):
        _line = _raw_line.strip()

        # Blank lines (e.g. a trailing newline at the end of the file)
        # carry no sequence, so skip them instead of failing or emitting
        # an empty sequence.
        if not _line:
            continue

        if excludeSeqID:
            sequences.append(_line)
            continue

        _fields = _line.split()
        if len(_fields) < 2:
            # Report the offending line instead of a bare IndexError.
            raise ValueError(
                'Line %d of sequence file must contain a sequence ID followed '
                'by a sequence (use --exclude-seq-id if the file has no IDs).'
                % _line_number)

        sequences.append(_fields[1])
        seq_id_dict[_fields[1]] = _fields[0]

# Run the network on every sequence. The result is a dict keyed by sequence.
pred_dict = train_network.test_unlabeled_data(
    brnn_network, sequences, device,
    encoding_scheme=encoding_scheme, encoder=encoder)

# Reduce each raw prediction to the flat list of values that gets written
# to the output file.
# NOTE(review): assumes values[0] is the output array for the sequence (one
# row per residue for 'residues' data) -- confirm against
# train_network.test_unlabeled_data.
if problem_type == 'classification':
    if dtype == 'sequence':
        # One class label per sequence: index of the largest output value
        pred_dict = {seq: [np.argmax(values[0])]
                     for seq, values in pred_dict.items()}
    else:
        # One class label per row of the output
        pred_dict = {seq: [np.argmax(row) for row in values[0]]
                     for seq, values in pred_dict.items()}
else:
    if dtype == 'sequence':
        pred_dict = {seq: list(values[0])
                     for seq, values in pred_dict.items()}
    else:
        # Flatten the output into a single list of values
        pred_dict = {seq: list(values[0].flatten())
                     for seq, values in pred_dict.items()}


# Write one line per sequence: "[<id> ]<sequence> <value> <value> ...".
# The ID column is emitted only when the input file supplied one.
with open(output_file, 'w') as f:
    for seq, values in pred_dict.items():
        if excludeSeqID:
            line_prefix = seq
        else:
            line_prefix = seq_id_dict[seq] + ' ' + seq
        value_text = ' '.join(str(value) for value in values)
        f.write(line_prefix + ' ' + value_text + '\n')
