#!/usr/bin/env python
"""
Usage: $ parrot-train data_file output_network <flags>
  
Driver script for training a bidirectional recurrent neural network with
user-specified parameters. For more information on usage, use the '-h' flag.

.............................................................................
idptools-parrot was developed by the Holehouse lab
     Original release ---- 2020

Questions/comments/concerns? Raise an issue on GitHub:
https://github.com/idptools/parrot

Licensed under the MIT license. 
"""

import os
import sys

import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np 
import argparse

from parrot import process_input_data as pid
from parrot import brnn_architecture
from parrot import train_network
from parrot import brnn_plot
from parrot import encode_sequence

# Build the command line interface. Arguments are registered in the order in
# which they should appear in the '-h' output.
parser = argparse.ArgumentParser(
	description='Train and test a bi-directional RNN using entire sequence.')

# Positional arguments: the input data and where to write the trained network
parser.add_argument(
	'data_file',
	help='path to tsv file with format: <idx> <sequence> <data>')
parser.add_argument(
	'output_network',
	help='path for the returned trained network')

# Mandatory flags describing the data format and the learning task
parser.add_argument(
	'-d', '--datatype',
	type=str,
	default='sequence',
	required=True,
	metavar='dtype',
	help="Required. Format of the input data file, must be 'sequence' or 'residues'")
parser.add_argument(
	'-c', '--classes',
	type=int,
	default=1,
	required=True,
	metavar='num_classes',
	help='Required. Number of output classes, for regression put 1')

# Optional network and training hyper-parameters
parser.add_argument(
	'-hs', '--hidden-size',
	type=int,
	default=5,
	metavar='hidden_size',
	help='hidden vector size (def=5)')
parser.add_argument(
	'-nl', '--num-layers',
	type=int,
	default=1,
	metavar='num_layers',
	help='number of layers per direction (def=1)')
parser.add_argument(
	'-lr', '--learning-rate',
	type=float,
	default=0.001,
	metavar='learning_rate',
	help='(def=0.001)')
parser.add_argument(
	'-b', '--batch',
	type=int,
	default=32,
	metavar='batch_size',
	help='size of training batch (def=32)')
parser.add_argument(
	'-e', '--epochs',
	type=int,
	default=50,
	metavar='num_epochs',
	help='number of training epochs (def=50)')

# Optional flags controlling data splitting, stopping and sequence encoding
parser.add_argument(
	'--split',
	type=str,
	default='',
	metavar='split_file',
	help="file indicating how to split datafile into training, validation, and test sets")
parser.add_argument(
	'--stop',
	type=str,
	default='iter',
	metavar='stop_condition',
	help="training stop condition: either 'auto' or 'iter' (default 'iter')")
parser.add_argument(
	'--set-fractions',
	dest='setFractions',
	type=float,
	nargs=3,
	default=[0.7, 0.15, 0.15],
	metavar=('train', 'val', 'test'),
	help='proportion of dataset that should be divided into training, validation, and test sets')
parser.add_argument(
	'--encode',
	type=str,
	default='onehot',
	metavar='encoding_scheme',
	help="'onehot' (default), 'biophysics', or specify a path to a user-created scheme")

# Boolean switches and the verbosity counter
parser.add_argument(
	'--exclude-seq-id',
	dest='excludeSeqID',
	action='store_true',
	help='use if data_file lacks sequence IDs in the first column of each line')
parser.add_argument(
	'--force-cpu',
	dest='forceCPU',
	action='store_true',
	help='force network to train on CPU, even if GPU is available')
parser.add_argument(
	'--verbose', '-v',
	action='count',
	default=0,
	help='''how descriptive output to console should be. Excluding this flag will 
			cause no output, using this flag two or more times will cause maximum output''')

# Read the arguments supplied on the command line
args = parser.parse_args()

# Hyper-parameters controlling the network size and the training loop
hidden_size, num_layers = args.hidden_size, args.num_layers
learning_rate, batch_size, num_epochs = args.learning_rate, args.batch, args.epochs

# Format of the input data and number of output classes
dtype, num_classes = args.datatype, args.classes

# Remaining optional flags
split_file, stop_cond, encode = args.split, args.stop, args.encode
verbosity, forceCPU = args.verbose, args.forceCPU
setFractions, excludeSeqID = args.setFractions, args.excludeSeqID

# Device configuration: stay on the CPU when requested, otherwise use CUDA
# whenever torch reports that it is available
device = 'cpu' if forceCPU else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
################    Validate arguments and initialize:      ###################

# Ensure that provided data_file exists
data_file = os.path.abspath(args.data_file)
if not os.path.isfile(data_file):
	raise FileNotFoundError('Datafile does not exist.')

# Ensure that output network location is valid. The file name and its parent
# directory are separated with os.path instead of splitting on '/', so the
# check also works on platforms whose path separator is not '/'.
saved_weights = os.path.abspath(args.output_network)
network_filename = os.path.basename(saved_weights)
# os.path.join(..., '') appends a trailing separator, so output_dir keeps the
# same "directory + separator" form that the rest of the script receives.
output_dir = os.path.join(os.path.dirname(saved_weights), '')
# An empty file name means the path is a bare root directory, which cannot
# hold the network weights.
if not network_filename or not os.path.isdir(output_dir):
	raise FileNotFoundError('Output network directory does not exist.')

# If provided, check that split_file exists; an empty value means that no
# split file was given
if split_file:
	split_file = os.path.abspath(split_file)
	if not os.path.isfile(split_file):
		raise FileNotFoundError('Split-file does not exist.')
else:
	split_file = None

# Set encoding scheme and/or validate user scheme.
# Built-in schemes are mapped to the size of their encoding vectors; any other
# value is treated as the path to a user-created scheme.
builtin_input_sizes = {'onehot': 20, 'biophysics': 9}
if encode in builtin_input_sizes:
	encoding_scheme = encode
	input_size = builtin_input_sizes[encode]
	encoder = None
else:
	encoding_scheme = 'user'
	encode_file = encode
	encoder = encode_sequence.UserEncoder(encode_file)
	# The input size is whatever length the user encoder reports
	input_size = len(encoder)
	
# Derive the learning task from the number of classes: exactly one class means
# regression, more than one means classification
if num_classes < 1:
	raise ValueError('Number of classes must be a positive integer.')
problem_type = 'regression' if num_classes == 1 else 'classification'

# The learning rate has to lie strictly between 0 and 1
if learning_rate <= 0 or learning_rate >= 1:
	raise ValueError('Learning rate must be between 0 and 1.')

# Only 'auto' and 'iter' are accepted as stop conditions
if stop_cond not in ('auto', 'iter'):
	raise ValueError('Invalid argument for `--stop` -- must be "auto" or "iter".')
if stop_cond == 'auto' and num_epochs > 10:
	print("Warning: Stop condition is set to 'auto' and num_epochs > 10."
		  " Network training may take a long time.\n")

# Hidden size, number of layers, number of epochs and batch size must all be
# positive integers; they are checked in this order
positive_int_checks = (
	('Hidden vector size', hidden_size),
	('Number of layers', num_layers),
	('Number of epochs', num_epochs),
	('Batch size', batch_size),
)
for label, value in positive_int_checks:
	if value < 1:
		raise ValueError('%s must be a positive integer.' % label)

# Ensure that every set fraction lies strictly between 0 and 1. The chained
# comparison also rejects NaN, which fails every ordering comparison.
for frac in setFractions:
	if not 0 < frac < 1:
		raise ValueError('All set fractions must be between 0 and 1.')

# Ensure that the set fractions add up to 1. The sum is compared with a small
# tolerance because most decimal fractions have no exact binary floating point
# representation: a valid choice such as 0.7 0.2 0.1 can add up to
# 0.9999999999999999, which an exact comparison against 1.0 would reject.
if abs(sum(setFractions) - 1.0) > 1e-9:
	raise ValueError('Set fractions must sum to 1.')

# Initialize network architecture depending on data format.
# Reject unknown data formats before anything is constructed.
if dtype not in ('sequence', 'residues'):
	raise ValueError('Invalid argument `--datatype`: must be "residues" or "sequence".')

is_classifier = problem_type == 'classification'
if dtype == 'sequence':
	# Many-to-one architecture with the sequence-level collate functions
	network_class = brnn_architecture.BRNN_MtO
	collate_function = pid.seq_class_collate if is_classifier else pid.seq_regress_collate
else:
	# Many-to-many architecture with the residue-level collate functions
	network_class = brnn_architecture.BRNN_MtM
	collate_function = pid.res_class_collate if is_classifier else pid.res_regress_collate

# Build the selected network and move it to the chosen device
brnn_network = network_class(input_size, hidden_size, num_layers, num_classes, device).to(device)

###############################################################################

# Split the data file into training, validation and test sets. Only the
# validation and test fractions are handed over to split_data.
train, val, test = pid.split_data(
	data_file,
	datatype=dtype,
	problem_type=problem_type,
	num_classes=num_classes,
	excludeSeqID=excludeSeqID,
	split_file=split_file,
	encoding_scheme=encoding_scheme,
	encoder=encoder,
	percent_val=setFractions[1],
	percent_test=setFractions[2],
)

# Wrap each set in a DataLoader. Only the training set is shuffled, and the
# test set is served with a batch size of 1.
train_loader = DataLoader(dataset=train, batch_size=batch_size,
						  collate_fn=collate_function, shuffle=True)
val_loader = DataLoader(dataset=val, batch_size=batch_size,
						collate_fn=collate_function, shuffle=False)
test_loader = DataLoader(dataset=test, batch_size=1,
						 collate_fn=collate_function, shuffle=False)

# Report the run configuration on stdout. One '-v' prints only the banner and
# the loss header; two or more also list the individual settings.
if verbosity > 0:
	print("PARROT with user-specified parameters")
	print("-------------------------------------")
	if verbosity > 1:
		# (format string, value) pairs, printed in this order
		run_summary = (
			('Train on:\t%s', device),
			("Datatype:\t%s", dtype),
			("ML Task:\t%s", problem_type),
			("Learning rate:\t%f", learning_rate),
			("Number of layers:\t%d", num_layers),
			("Hidden vector size:\t%d", hidden_size),
			("Batch size:\t%d\n", batch_size),
		)
		for template, setting in run_summary:
			print(template % setting)

	print("Validation set loss per epoch:")

# Keyword arguments that the training and the testing call have in common
common_kwargs = {
	'datatype': dtype,
	'problem_type': problem_type,
	'weights_file': saved_weights,
	'device': device,
}

# Train network, then pass the returned training and validation losses to
# brnn_plot.training_loss together with the output directory
train_loss, val_loss = train_network.train(
	brnn_network, train_loader, val_loader,
	stop_condition=stop_cond, learn_rate=learning_rate,
	n_epochs=num_epochs, verbosity=verbosity, **common_kwargs)
brnn_plot.training_loss(train_loss, val_loss, output_dir=output_dir)

# Test network on the held-out test set
test_loss, test_set_predictions = train_network.test_labeled_data(
	brnn_network, test_loader, num_classes=num_classes, **common_kwargs)

if verbosity > 0:
	print('\nTest Loss: %.4f' % test_loss)

# Output the test set predictions to a text file
brnn_plot.output_predictions_to_file(
	test_set_predictions, excludeSeqID, encoding_scheme, encoder,
	output_dir=output_dir)
