#!/usr/bin/env python
# tf-modisco command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import h5py
import hdf5plugin
import argparse
import modiscolite

import numpy as np

desc = """TF-MoDISco is a motif detection algorithm that takes in nucleotide
	sequence and the attributions from a neural network model and returns motifs
	that are repeatedly enriched for attribution score across the examples.
	This tool will take in one-hot encoded sequence, the corresponding
	attribution scores, and a few other parameters, and return the motifs."""

# Read in the arguments. Each subcommand has its own sub-parser; the name of
# the chosen subcommand is stored in `args.cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(
	help="Must be one of 'motifs', 'report', 'convert', or 'convert-backward'.",
	required=True, dest='cmd')

# `motifs`: run TF-MoDISco. Inputs come either from -s/-a (.npy/.npz arrays)
# or from a single legacy HDF5 file given with -i.
motifs_parser = subparsers.add_parser("motifs", help="Run TF-MoDISco and extract the motifs.")
motifs_parser.add_argument("-s", "--sequences", type=str,
	help="A .npy or .npz file containing the one-hot encoded sequences. Required unless -i/--h5py is given.")
motifs_parser.add_argument("-a", "--attributions", type=str,
	help="A .npy or .npz file containing the hypothetical attributions, i.e., the attributions for all nucleotides at all positions. Required unless -i/--h5py is given.")
motifs_parser.add_argument("-i", "--h5py", type=str,
	help="A legacy h5py file containing the one-hot encoded sequences and shap scores.")
motifs_parser.add_argument("-n", "--max_seqlets", type=int, required=True,
	help="The maximum number of seqlets per metacluster.")
motifs_parser.add_argument("-l", "--n_leiden", type=int, default=2,
	help="The number of Leiden clusterings to perform with different random seeds.")
motifs_parser.add_argument("-w", "--window", type=int, default=400,
	help="The window surrounding the peak center that will be considered for motif discovery.")
motifs_parser.add_argument("-o", "--output", type=str, default="modisco_results.h5",
	help="The path to the output file.")
motifs_parser.add_argument("-v", "--verbose", action="store_true", default=False,
	help="Controls the amount of output from the code.")

# `report`: build an HTML report from a results file.
report_parser = subparsers.add_parser("report", help="Create a HTML report of the results.")
report_parser.add_argument("-i", "--h5py", type=str, required=True,
	help="An HDF5 file containing the output from modiscolite.")
report_parser.add_argument("-o", "--output", type=str, required=True,
	help="A directory to put the output results including the html report.")
# NOTE: despite its name, --suffix is described as being added to the
# beginning of image paths, i.e. it behaves as a prefix.
report_parser.add_argument("-s", "--suffix", type=str, default="./",
	help="The suffix to add to the beginning of images. Should be equal to the output if using a Jupyter notebook.")
report_parser.add_argument("-m", "--meme_db", type=str, required=True,
	help="A MEME file containing motifs.")
report_parser.add_argument("-n", "--n_matches", type=int, default=3,
	help="The number of top TOMTOM matches to include in the report.")

# `convert`: old HDF5 layout -> new layout.
convert_parser = subparsers.add_parser("convert", help="Convert an old h5py to the new format.")
convert_parser.add_argument("-i", "--h5py", type=str, required=True,
	help="An HDF5 file formatted in the original way.")
convert_parser.add_argument("-o", "--output", type=str, required=True,
	help="An HDF5 file formatted in the new way.")

# `convert-backward`: new HDF5 layout -> old layout.
convertback_parser = subparsers.add_parser("convert-backward", help="Convert a new h5py to the old format.")
convertback_parser.add_argument("-i", "--h5py", type=str, required=True,
	help="An HDF5 file formatted in the new way.")
convertback_parser.add_argument("-o", "--output", type=str, required=True,
	help="An HDF5 file formatted in the old way.")

def _window_bounds(length, window):
	"""Return the (start, end) slice bounds of a window centered within `length`.

	The returned window always spans exactly `window` positions. For even
	windows the bounds equal ``center - window // 2`` and
	``center + window // 2`` with ``center = length // 2``. For odd windows
	the extra position goes at the end, instead of the window silently
	shrinking by one.

	Raises
	------
	ValueError
		If `window` is not positive or is longer than `length`.
	"""
	if window <= 0:
		raise ValueError("Window ({}) must be positive.".format(window))
	if window > length:
		raise ValueError("Window ({}) cannot be ".format(window) +
			"longer than the sequences ({}).".format(length))

	start = length // 2 - window // 2
	return start, start + window


def _load_array(path, option):
	"""Load an array from a .npy file, or the 'arr_0' entry of a .npz file.

	`option` is the long flag name (e.g. 'sequences') used in error messages.
	A missing path or an unsupported extension is reported through
	`motifs_parser.error`, which prints usage and exits. Without this, the
	script would fail later with an opaque TypeError or NameError.
	"""
	if path is None:
		motifs_parser.error("--{} is required unless -i/--h5py is given.".format(option))

	if path[-3:] == 'npy':
		return np.load(path)
	if path[-3:] == 'npz':
		# The returned NpzFile holds the archive open; close it once read.
		with np.load(path) as archive:
			return archive['arr_0']

	motifs_parser.error("--{} must be a .npy or .npz file, got '{}'.".format(option, path))


# Pull the arguments
args = parser.parse_args()

if args.cmd == "motifs":
	if args.h5py is not None:
		# Load the scores; the context manager closes the file on every path.
		with h5py.File(args.h5py, 'r') as scores:
			if 'hyp_scores' in scores and 'input_seqs' in scores:
				# Layout with positions on axis 1: slice that axis directly.
				length = scores['hyp_scores'].shape[1]
				start, end = _window_bounds(length, args.window)

				attributions = scores['hyp_scores'][:, start:end, :]
				sequences = scores['input_seqs'][:, start:end, :]
			else:
				# Layout with positions on the last axis. Slice it, then
				# transpose so positions end up on axis 1 as in the branch above.
				length = scores['shap']['seq'].shape[2]
				start, end = _window_bounds(length, args.window)

				attributions = scores['shap']['seq'][:, :, start:end].transpose(0, 2, 1)
				sequences = scores['raw']['seq'][:, :, start:end].transpose(0, 2, 1)

	else:
		sequences = _load_array(args.sequences, 'sequences')
		attributions = _load_array(args.attributions, 'attributions')

		# These arrays carry positions on the last axis (the center is taken
		# from shape[2]). Slice that axis, then move positions onto axis 1.
		start, end = _window_bounds(sequences.shape[2], args.window)

		sequences = sequences[:, :, start:end].transpose(0, 2, 1)
		attributions = attributions[:, :, start:end].transpose(0, 2, 1)

	sequences = sequences.astype('float32')
	attributions = attributions.astype('float32')

	pos_patterns, neg_patterns = modiscolite.tfmodisco.TFMoDISco(
		hypothetical_contribs=attributions, 
		one_hot=sequences,
		max_seqlets_per_metacluster=args.max_seqlets,
		sliding_window_size=20,
		flank_size=5,
		target_seqlet_fdr=0.05,
		n_leiden_runs=args.n_leiden,
		verbose=args.verbose)

	modiscolite.io.save_hdf5(args.output, pos_patterns, neg_patterns)

elif args.cmd == 'report':
	modiscolite.report.report_motifs(args.h5py, args.output, suffix=args.suffix, 
		top_n_matches=args.n_matches, meme_motif_db=args.meme_db)

elif args.cmd == 'convert':
	modiscolite.io.convert(args.h5py, args.output)

elif args.cmd == 'convert-backward':
	modiscolite.io.convert_new_to_old(args.h5py, args.output)