#!/usr/bin/env python
"""
Usage: $ parrot-train data_file output_network <flags>
  
Driver script for training a bidirectional recurrent neural network with
user-specified parameters. For more information on usage, use the '-h' flag.

.............................................................................
idptools-parrot was developed by the Holehouse lab
     Original release ---- 2020

Questions/comments/concerns? Raise an issue on GitHub:
https://github.com/idptools/parrot

Licensed under the MIT license. 
"""

import os
import sys

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import argparse

from parrot import process_input_data as pid
from parrot import brnn_architecture
from parrot import train_network
from parrot import brnn_plot
from parrot import encode_sequence

# Parse the command line arguments. Positional arguments are the input data
# file and the location the trained network is written to; everything else
# is an optional flag and is validated further below.
parser = argparse.ArgumentParser(description='Train and test a bi-directional RNN using entire sequence.')
parser.add_argument('data_file', help='path to tsv file with format: <idx> <sequence> <data>')
parser.add_argument('output_network', help='location to save the trained network')
parser.add_argument('-d', '--datatype', metavar='dtype', type=str, required=True,
                    help="Required. Format of the input data file, must be 'sequence' or 'residues'")
parser.add_argument('-c', '--classes', type=int, metavar='num_classes', required=True,
                    help='Required. Number of output classes, for regression put 1')
parser.add_argument('-hs', '--hidden-size', default=10, type=int, metavar='hidden_size',
                    help='hidden vector size (def=10)')
parser.add_argument('-nl', '--num-layers', default=1, type=int, metavar='num_layers',
                    help='number of layers per direction (def=1)')
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float,
                    metavar='learning_rate', help='(def=0.001)')
parser.add_argument('-b', '--batch', default=32, type=int, metavar='batch_size',
                    help='size of training batch (def=32)')
parser.add_argument('-e', '--epochs', default=100, type=int, metavar='num_epochs',
                    help='number of training epochs (def=100)')
parser.add_argument('--split', default='', metavar='split_file', type=str,
                    help="file indicating how to split datafile into training, validation, and test sets")
parser.add_argument('--stop', default='iter', metavar='stop_condition',
                    type=str, help="training stop condition: either 'auto' or 'iter' (default 'iter')")
parser.add_argument('--set-fractions', nargs=3, default=[0.7, 0.15, 0.15], type=float,
                    dest='setFractions', metavar=('train', 'val', 'test'),
                    help='proportion of dataset that should be divided into training, validation, and test sets')
parser.add_argument('--encode', default='onehot', type=str, metavar='encoding_scheme',
                    help="'onehot' (default), 'biophysics', or specify a path to a user-created scheme")
parser.add_argument('--exclude-seq-id', dest='excludeSeqID', action='store_true',
                    help='use if data_file lacks sequence IDs in the first column of each line')
parser.add_argument('--probabilistic-classification', dest='probabilistic_classification',
                    action='store_true', help='Optional implementation for binary sequence classification')
parser.add_argument('--include-figs', dest='include_figs', action='store_true',
                    help='Generate figures from training results and save to same location as network')
parser.add_argument('--force-cpu', dest='forceCPU', action='store_true',
                    help='force network to train on CPU, even if GPU is available')
# action='count' yields 0 when the flag is absent and N when given N times.
# The help text is a single line; argparse reflows it when printing.
parser.add_argument('--verbose', '-v', action='count', default=0,
                    help='how descriptive output to console should be. Excluding this flag will '
                         'cause no output, using this flag two or more times will cause maximum output')

args = parser.parse_args()

# Hyper-parameters
hidden_size = args.hidden_size
num_layers = args.num_layers
learning_rate = args.learning_rate
batch_size = args.batch
num_epochs = args.epochs

# Data format
dtype = args.datatype
num_classes = args.classes

# Other flags
split_file = args.split
stop_cond = args.stop
encode = args.encode
verbosity = args.verbose
forceCPU = args.forceCPU
setFractions = args.setFractions
excludeSeqID = args.excludeSeqID
probabilistic_classification = args.probabilistic_classification
include_figs = args.include_figs

# Device configuration: use CUDA when it is available, unless the user forced
# CPU training. Both outcomes are torch.device objects, so the network and the
# training helpers always receive the same type (previously the forced-CPU
# path produced a plain 'cpu' string). Short-circuiting on forceCPU means CUDA
# availability is not queried when CPU is forced, as before.
if forceCPU or not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda')

###############################################################################
################    Validate arguments and initialize:      ###################

# Ensure that provided data_file exists
data_file = os.path.abspath(args.data_file)
if not os.path.isfile(data_file):
    raise FileNotFoundError('Datafile does not exist.')

# Ensure that output network location is valid.
# Split with os.path rather than on '/', so Windows (backslash) paths work too.
# output_dir keeps a trailing separator, as the previous string slicing did,
# in case callees build file names by concatenating onto it.
saved_weights = os.path.abspath(args.output_network)
network_dir, network_filename = os.path.split(saved_weights)
output_dir = os.path.join(network_dir, '')
if not os.path.exists(output_dir):
    raise FileNotFoundError('Output network directory does not exist.')
# Fail before training starts if the weights path names a directory, since the
# network weights could never be written there.
if not network_filename or os.path.isdir(saved_weights):
    raise IsADirectoryError('Output network path must be a file, not a directory.')

# If provided, check that split_file exists; an empty string means "none".
if split_file != '':
    split_file = os.path.abspath(split_file)
    if not os.path.isfile(split_file):
        raise FileNotFoundError('Split-file does not exist.')
else:
    split_file = None

# Choose the sequence encoding. The built-in schemes have fixed per-residue
# input widths; any other value of --encode is treated as the path to a
# user-defined scheme, whose width comes from the loaded encoder.
builtin_input_sizes = {'onehot': 20, 'biophysics': 9}
if encode in builtin_input_sizes:
    encoding_scheme = encode
    input_size = builtin_input_sizes[encode]
    encoder = None
else:
    encoding_scheme = 'user'
    encode_file = encode
    encoder = encode_sequence.UserEncoder(encode_file)
    input_size = len(encoder)

# One output class selects regression; two or more select classification.
if num_classes < 1:
    raise ValueError('Number of classes must be a positive integer.')
problem_type = 'regression' if num_classes == 1 else 'classification'

# Ensure that learning rate is strictly between 0 and 1. The chained
# comparison also rejects NaN, which passes a pair of >=/<= tests unnoticed.
if not 0 < learning_rate < 1:
    raise ValueError('Learning rate must be between 0 and 1.')

# Ensure that stop condition is 'iter' or 'auto'
if stop_cond == 'auto':
    if num_epochs > 10:
        print("Warning: Stop condition is set to 'auto' and num_epochs > 10." +
              " Network training may take a long time.\n")
elif stop_cond != 'iter':
    raise ValueError('Invalid argument for `--stop` -- must be "auto" or "iter".')

# Ensure that hidden size, num layers, batch size, and num epochs are all positive ints
if hidden_size < 1:
    raise ValueError('Hidden vector size must be a positive integer.')
if num_layers < 1:
    raise ValueError('Number of layers must be a positive integer.')
if num_epochs < 1:
    raise ValueError('Number of epochs must be a positive integer.')
if batch_size < 1:
    raise ValueError('Batch size must be a positive integer.')

# Ensure that each set fraction lies strictly between 0 and 1 and that the
# three fractions sum to 1. The sum is compared with a tolerance: exact float
# equality rejects valid input such as 0.7 0.2 0.1 (which sums to
# 0.9999999999999999 in binary floating point).
for frac in setFractions:
    if not 0 < frac < 1:
        raise ValueError('All set fractions must be between 0 and 1.')
if abs(sum(setFractions) - 1.0) > 1e-6:
    raise ValueError('Set fractions must sum to 1.')

# Ensure that task is binary sequence classification if
# probabilistic_classification is set
if probabilistic_classification:
    if dtype != 'sequence' or num_classes != 2:
        raise ValueError('Probabilistic classification only implemented for binary sequence classification')

# Pick the network architecture and batch collate function for the data format.
# 'sequence' data uses the many-to-one BRNN; 'residues' data uses the
# many-to-many BRNN. Within each, the collate function depends on the task.
if dtype not in ('sequence', 'residues'):
    raise ValueError('Invalid argument `--datatype`: must be "residues" or "sequence".')

is_classifier = problem_type == 'classification'
if dtype == 'sequence':
    network_class = brnn_architecture.BRNN_MtO
    collate_function = pid.seq_class_collate if is_classifier else pid.seq_regress_collate
else:
    network_class = brnn_architecture.BRNN_MtM
    collate_function = pid.res_class_collate if is_classifier else pid.res_regress_collate

brnn_network = network_class(input_size, hidden_size, num_layers,
                             num_classes, device).to(device)

###############################################################################
################################  Main code  ##################################

# Partition the data file into training, validation, and test sets. The
# training fraction is implied by the validation and test fractions.
train, val, test = pid.split_data(
    data_file,
    datatype=dtype,
    problem_type=problem_type,
    num_classes=num_classes,
    excludeSeqID=excludeSeqID,
    split_file=split_file,
    encoding_scheme=encoding_scheme,
    encoder=encoder,
    percent_val=setFractions[1],
    percent_test=setFractions[2],
)

# Wrap each split in a DataLoader. Only the training split is shuffled, and
# the test split is served one example per batch.
train_loader = torch.utils.data.DataLoader(dataset=train, shuffle=True,
                                           batch_size=batch_size, collate_fn=collate_function)
val_loader = torch.utils.data.DataLoader(dataset=val, shuffle=False,
                                         batch_size=batch_size, collate_fn=collate_function)
test_loader = torch.utils.data.DataLoader(dataset=test, shuffle=False,
                                          batch_size=1, collate_fn=collate_function)

# Report the run configuration to stdout. One -v prints the header; two or
# more also print the device and hyper-parameters.
if verbosity > 0:
    print("PARROT with user-specified parameters")
    print("-------------------------------------")
    if verbosity > 1:
        print('Train on:\t%s' % device)
        print("Datatype:\t%s" % dtype)
        print("ML Task:\t%s" % problem_type)
        # %g instead of %f: with %f, very small rates (e.g. 1e-7) print as 0.000000
        print("Learning rate:\t%g" % learning_rate)
        print("Number of layers:\t%d" % num_layers)
        print("Hidden vector size:\t%d" % hidden_size)
        print("Batch size:\t%d\n" % batch_size)

    # Header for the per-epoch losses; presumably train_network.train prints
    # them according to `verbosity` (not visible here).
    print("Validation set loss per epoch:")

# Train network; weights are handed to train_network.train via saved_weights
train_loss, val_loss = train_network.train(brnn_network, train_loader, val_loader, datatype=dtype,
                                           problem_type=problem_type, weights_file=saved_weights, stop_condition=stop_cond,
                                           device=device, learn_rate=learning_rate, n_epochs=num_epochs, verbosity=verbosity)

if include_figs:  # Plot training & validation loss per epoch
    brnn_plot.training_loss(train_loss, val_loss, output_dir=output_dir)

# Evaluate on the held-out test set using the weights file written during training
test_loss, test_set_predictions = train_network.test_labeled_data(brnn_network, test_loader,
                                                                  datatype=dtype, problem_type=problem_type,
                                                                  weights_file=saved_weights, num_classes=num_classes,
                                                                  probabilistic_classification=probabilistic_classification,
                                                                  include_figs=include_figs, device=device)

if verbosity > 0:
    print('\nTest Loss: %.4f' % test_loss)

# Output the test set predictions to a text file in the network's directory
brnn_plot.output_predictions_to_file(test_set_predictions, excludeSeqID, encoding_scheme,
                                     encoder, output_dir=output_dir)
