#!/usr/bin/env python
# BPNet command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import sys
import numpy
import torch
import argparse

from bpnetlite import BPNet
from bpnetlite.io import DataGenerator
from bpnetlite.io import extract_peaks

import json

desc = """BPNet is a neural network primarily composed of dilated residual
	convolution layers for modeling the associations between biological
	sequences and biochemical readouts. This tool will take in a fasta
	file for the sequence, a bed file for signal peak locations, and bigWig
	files for the signal to predict and the control signal, and train a
	BPNet model for you."""

# Read in the arguments. The parameter file is the only input to this tool,
# so it is marked as required: otherwise omitting -p leaves args.parameters
# as None and the open() call below fails with an opaque TypeError instead of
# argparse's usage message.
parser = argparse.ArgumentParser(description=desc)
parser.add_argument("-p", "--parameters", type=str, required=True,
	help="A JSON file containing the parameters for training the model.")

# Pull the arguments
args = parser.parse_args()

# Load the user-supplied training parameters. Keys that are absent from this
# file are filled in from `default_parameters` later in the script.
with open(args.parameters, "r") as infile:
	parameters = json.load(infile)

# The rest of the script indexes `parameters` by key and inserts defaults
# into it, so anything other than a JSON object cannot work. Fail here with a
# clear message rather than with a confusing error further down.
if not isinstance(parameters, dict):
	raise ValueError("The parameter file must contain a JSON object "
		"mapping parameter names to values.")

# Fallback settings for every key the script reads from `parameters`. A
# fallback of None marks a key that has no usable default: if the JSON file
# omits it, the merge below raises instead of filling it in.
# NOTE(review): `random_state` carries a None fallback and is therefore
# treated as mandatory just like the file paths -- confirm this is intended.
default_parameters = dict(
	# Model and optimisation settings
	n_filters=64,
	n_layers=8,
	batch_size=64,
	in_window=2114,
	out_window=1000,
	max_jitter=128,
	reverse_complement=True,
	max_epochs=250,
	validation_iter=100,
	lr=0.001,
	alpha=1,
	verbose=False,

	# Chromosome split between training and validation
	training_chroms=[
		'chr1', 'chr2', 'chr3', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9',
		'chr10', 'chr12', 'chr13', 'chr14', 'chr16', 'chr18', 'chr19',
		'chr20', 'chr22',
	],
	validation_chroms=['chr4', 'chr15', 'chr21'],

	# No default available: must come from the JSON file
	sequence_path=None,
	peak_path=None,
	plus_bw_path=None,
	minus_bw_path=None,
	plus_ctl_bw_path=None,
	minus_ctl_bw_path=None,
	random_state=None,
)

# Walk the defaults in declaration order and complete `parameters` in place.
# Values the user supplied are never overwritten.
for name in default_parameters:
	if name in parameters:
		continue

	fallback = default_parameters[name]
	if fallback is None:
		raise ValueError("Must provide value for '{}'".format(name))

	parameters[name] = fallback

###

# Arguments that are identical for the training and validation extractions:
# the genome, the signal/control bigWigs, the peak file and the verbosity.
# NOTE(review): `in_window` / `out_window` are not forwarded to
# extract_peaks here, so it runs with whatever defaults it defines --
# confirm that is compatible with non-default window sizes in the JSON file.
shared_extract_kwargs = {
	'sequences': parameters['sequence_path'],
	'plus_bw_path': parameters['plus_bw_path'],
	'minus_bw_path': parameters['minus_bw_path'],
	'plus_ctl_bw_path': parameters['plus_ctl_bw_path'],
	'minus_ctl_bw_path': parameters['minus_ctl_bw_path'],
	'peak_path': parameters['peak_path'],
	'verbose': parameters['verbose'],
}

# Training peaks: restricted to the training chromosomes and extracted with
# the configured jitter.
train_sequences, train_signals, train_controls = extract_peaks(
	chroms=parameters['training_chroms'],
	max_jitter=parameters['max_jitter'],
	**shared_extract_kwargs
)

# Validation peaks: restricted to the validation chromosomes and extracted
# with no jitter.
valid_sequences, valid_signals, valid_controls = extract_peaks(
	chroms=parameters['validation_chroms'],
	max_jitter=0,
	**shared_extract_kwargs
)

###

# Settings that are used more than once below.
in_window = parameters['in_window']
out_window = parameters['out_window']
batch_size = parameters['batch_size']

# Wrap the extracted training arrays in a dataset that is configured with
# the window sizes, jitter, reverse-complement flag and random seed.
training_peaks = DataGenerator(
	sequences=train_sequences,
	signals=train_signals,
	controls=train_controls,
	in_window=in_window,
	out_window=out_window,
	max_jitter=parameters['max_jitter'],
	reverse_complement=parameters['reverse_complement'],
	random_state=parameters['random_state'],
)

# Batch the dataset; pinned memory is requested for the host-side batches.
training_data = torch.utils.data.DataLoader(
	training_peaks,
	batch_size=batch_size,
	pin_memory=True,
)

# Half of the difference between the input and output window lengths,
# handed to the model as its `trimming` argument.
trimming = (in_window - out_window) // 2

# Build the network and move it to the GPU; `.cuda()` is unconditional, so
# this script needs a CUDA device.
model = BPNet(
	n_filters=parameters['n_filters'],
	n_layers=parameters['n_layers'],
	alpha=parameters['alpha'],
	trimming=trimming,
).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=parameters['lr'])

# Train on the batched training data, passing the validation arrays and the
# configured epoch limit and validation interval.
model.fit_generator(
	training_data,
	optimizer,
	X_valid=valid_sequences,
	X_ctl_valid=valid_controls,
	y_valid=valid_signals,
	max_epochs=parameters['max_epochs'],
	validation_iter=parameters['validation_iter'],
	batch_size=batch_size,
)
