#!/usr/bin/env python
"""
Usage: $ parrot-optimize data_file output_network <flags>
  
Driver script for finding optimal hyperparameters for a bidirectional recurrent 
neural network on a given dataset, then training a network with those parameters
For more information on usage, use the '-h' flag.

.............................................................................
idptools-parrot was developed by the Holehouse lab
     Original release ---- 2020

Question/comments/concerns? Raise an issue on github:
https://github.com/idptools/parrot

Licensed under the MIT license. 
"""

import os
import sys

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import argparse

from parrot import process_input_data as pid
from parrot import brnn_architecture
from parrot import train_network
from parrot import brnn_plot
from parrot import bayesian_optimization
from parrot import encode_sequence

# Parse the command line arguments
parser = argparse.ArgumentParser(description='Train and test a bi-directional RNN using entire sequence.')
parser.add_argument('data_file', help='path to tsv file with format: <idx> <sequence> <data>')
parser.add_argument('output_network', help='location to save the trained network')
parser.add_argument('-d', '--datatype', metavar='dtype', type=str, required=True,
                    help="Required. Format of the input data file, must be 'sequence' or 'residues'")
parser.add_argument('-c', '--classes', type=int, metavar='num_classes', required=True,
                    help='Required. Number of output classes, for regression put 1')
parser.add_argument('-b', '--batch', default=32, type=int, metavar='batch_size',
                    help='size of training batch (def=32)')
parser.add_argument('-e', '--epochs', default=100, type=int, metavar='num_epochs',
                    help='number of training epochs (def=100)')
parser.add_argument('--max-iter', default=50, type=int, metavar='max_iter',
                    help='Maximum number of iterations for the optimization procedure (def=50)')
parser.add_argument('--split', default='', metavar='split_file', type=str,
                    help="file indicating how to split datafile into training, validation, and testing sets")
parser.add_argument('--set-fractions', nargs=3, default=[0.7, 0.15, 0.15], type=float,
                    dest='setFractions', metavar=('train', 'val', 'test'),
                    help='Proportion of dataset that should be divided into training, validation, and test sets')
parser.add_argument('--encode', default='onehot', type=str, metavar='encoding_scheme',
                                        help="'onehot' (default), 'biophysics', or specify a path to a user-created scheme")
parser.add_argument('--exclude-seq-id', dest='excludeSeqID', action='store_true',
                    help='use if data_file lacks sequence IDs in the first column of each line')
parser.add_argument('--probabilistic-classification', dest='probabilistic_classification',
                    action='store_true', help='Optional implementation for binary sequence classificaion')
parser.add_argument('--include-figs', dest='include_figs', action='store_true',
                    help='Generate figures from training results and save to same location as network')
parser.add_argument('--force-cpu', dest='forceCPU', action='store_true',
                    help='force network to train on CPU, even if GPU is available')
parser.add_argument('--verbose', '-v', action='count', default=0,
                    help='''how descriptive output to console should be. Excluding this flag will 
			cause no output, using this flag two or more times will cause maximum output''')

args = parser.parse_args()

# Hyper-parameters
batch_size = args.batch
num_epochs = args.epochs
max_iterations = args.max_iter

# Data format
dtype = args.datatype
num_classes = args.classes

# Other flags
split_file = args.split
encode = args.encode
verbosity = args.verbose
forceCPU = args.forceCPU
setFractions = args.setFractions
excludeSeqID = args.excludeSeqID
probabilistic_classification = args.probabilistic_classification
include_figs = args.include_figs

# Device configuration
if forceCPU:
    device = 'cpu'
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
################    Validate arguments and initialize:      ###################

# Ensure that provided data_file exists
data_file = os.path.abspath(args.data_file)
if not os.path.isfile(data_file):
    raise FileNotFoundError('Datafile does not exist.')

# Ensure that output network location is valid
saved_weights = os.path.abspath(args.output_network)
network_filename = saved_weights.split('/')[-1]
output_dir = saved_weights[:-len(network_filename)]
if not os.path.exists(output_dir):
    raise FileNotFoundError('Output network directory does not exist.')

# If provided, check that split_file exists
if split_file != '':
    split_file = os.path.abspath(split_file)
    if not os.path.isfile(split_file):
        raise FileNotFoundError('Split-file does not exist.')
else:
    split_file = None

# Set encoding scheme and/or validate user scheme
if encode == 'onehot':
    encoding_scheme = 'onehot'
    input_size = 20
    encoder = None
elif encode == 'biophysics':
    encoding_scheme = 'biophysics'
    input_size = 9
    encoder = None
else:
    encoding_scheme = 'user'
    encode_file = encode
    encoder = encode_sequence.UserEncoder(encode_file)
    input_size = len(encoder)

# Initialize network as classifier or regressor
if num_classes > 1:
    problem_type = 'classification'
elif num_classes == 1:
    problem_type = 'regression'
else:
    raise ValueError('Number of classes must be a positive integer.')

# Initialize network architecture depending on data format
if dtype == 'sequence':
    # Set collate function
    if problem_type == 'classification':
        collate_function = pid.seq_class_collate
    else:
        collate_function = pid.seq_regress_collate
elif dtype == 'residues':
    # Set collate function
    if problem_type == 'classification':
        collate_function = pid.res_class_collate
    else:
        collate_function = pid.res_regress_collate
else:
    raise ValueError('Invalid argument `--datatype`: must be "residues" or "sequence".')

# Ensure that  batch size and num epochs are both positive ints
if num_epochs < 1:
    raise ValueError('Number of epochs must be a positive integer.')
if batch_size < 1:
    raise ValueError('Batch size must be a positive integer.')

# Ensure that the sum of setFractions adds up to 1
for frac in setFractions:
    if 0 >= frac or frac >= 1:
        raise ValueError('All set fractions must be between 0 and 1.')
if sum(setFractions) != 1.0:
    raise ValueError('Set fractions must sum to 1.')

# Ensure that task is binary sequence classification if
# probabilistic_classfication is set
if probabilistic_classfication:
    if dtype != 'sequence' or num_classes != 2:
        raise ValueError('Proportional classification only implemented for binary sequence classification')

###############################################################################

# Split data
cvs, train, val, test = pid.split_data_cv(data_file, datatype=dtype, problem_type=problem_type,
                                          num_classes=num_classes, excludeSeqID=excludeSeqID,
                                          split_file=split_file, encoding_scheme=encoding_scheme,
                                          encoder=encoder,
                                          percent_val=setFractions[1], percent_test=setFractions[2])

# Convert CV datasets to dataloaders
cv_loaders = []
for cv_train, cv_val in cvs:
    cv_train_loader = torch.utils.data.DataLoader(dataset=cv_train, batch_size=batch_size,
                                                  collate_fn=collate_function, shuffle=True)
    cv_val_loader = torch.utils.data.DataLoader(dataset=cv_val, batch_size=batch_size,
                                                collate_fn=collate_function, shuffle=False)
    cv_loaders.append((cv_train_loader, cv_val_loader))

# Output to std out
if verbosity > 0:
    print("PARROT with hyperparameter optimization")
    print("---------------------------------------")
    if verbosity > 1:
        print('Train on:\t%s' % device)
        print("Datatype:\t%s" % dtype)
        print("ML Task:\t%s" % problem_type)
        print("Batch size:\t%d" % batch_size)
        print("Number of epochs:\t%d" % num_epochs)
        print("Number of optimization iterations:\t%d\n" % max_iterations)

# Optimization procedure
optimizer = bayesian_optimization.BayesianOptimizer(cv_loaders, input_size, num_epochs,
                                                    num_classes, dtype, saved_weights,
                                                    max_iterations, device, verbosity)

best_hyperparams = optimizer.optimize()
lr = 10**best_hyperparams[0]
nl = int(best_hyperparams[1])
hs = int(best_hyperparams[2])

# Save these hyperparamters to a file so that the user has a record
params_file = saved_weights[:-len(network_filename)] + 'optimal_hyperparams.txt'
with open(params_file, 'w') as f:
    f.write('Learning rate: %.5f\n' % lr)
    f.write('Num Layers: %d\n' % nl)
    f.write('Hidden vector size: %d\n' % hs)

# Use these best hyperparams to train the network from scratch using the entire train/val sets
# Add data to dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train,
                                           batch_size=batch_size,
                                           collate_fn=collate_function,
                                           shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val,
                                         batch_size=batch_size,
                                         collate_fn=collate_function,
                                         shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test,
                                          batch_size=1,		# Set test batch size to 1
                                          collate_fn=collate_function,
                                          shuffle=False)

# Initialize network:
if dtype == 'sequence':
    brnn_network = brnn_architecture.BRNN_MtO(input_size, hs, nl, num_classes, device).to(device)
else:  # dtype == 'residues'
    brnn_network = brnn_architecture.BRNN_MtM(input_size, hs, nl, num_classes, device).to(device)

# Train network
if verbosity > 0:
    print('Training with optimal hyperparams:')
train_loss, val_loss = train_network.train(brnn_network, train_loader, val_loader, datatype=dtype,
                                           problem_type=problem_type, weights_file=saved_weights, stop_condition='iter',
                                           device=device, learn_rate=lr, n_epochs=num_epochs*2, verbosity=verbosity)

if include_figs:  # Plot training & validation loss per epoch
    brnn_plot.training_loss(train_loss, val_loss, output_dir=output_dir)

# Test network
test_loss, test_set_predictions = train_network.test_labeled_data(brnn_network, test_loader,
                                                                  datatype=dtype, problem_type=problem_type,
                                                                  weights_file=saved_weights, num_classes=num_classes,
                                                                  probabilistic_classification=probabilistic_classification,
                                                                  include_figs=include_figs, device=device)
if verbosity > 0:
    print('\nTest Loss: %.4f' % test_loss)

# Output the test set predictions to a text file
brnn_plot.output_predictions_to_file(test_set_predictions, excludeSeqID, encoding_scheme,
                                     encoder, output_dir=output_dir)
