#!/usr/bin/env python
# tf-modisco command-line tool
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

import json
import h5py
import hdf5plugin
import random
import argparse

import modiscolite
import modiscolite.report

import numpy as np

random.seed(0)
np.random.seed(0)

def save_hierarchy(h, grp):
	"""Write every root node of the hierarchy `h` into the HDF5 group `grp`.

	Each entry of `h.root_nodes` is stored by `save_node` in its own
	sub-group named "root_node<i>", where <i> is the position of the node.
	The names of those sub-groups are then recorded, in order, in a
	variable-length bytes dataset called "root_node_names".
	"""
	names = []
	for position, root in enumerate(h.root_nodes):
		name = "root_node" + str(position)
		names.append(name)
		save_node(root, grp.create_group(name))

	# Index dataset listing the sub-groups created above.
	names_dset = grp.create_dataset(
		"root_node_names",
		(len(names),),
		dtype=h5py.special_dtype(vlen=bytes),
	)
	names_dset[:] = names

def save_node(node, grp):
	"""Recursively write one hierarchy node into the HDF5 group `grp`.

	When `node.indices_merged` is set, it is stored together with
	`node.submat_crosscontam` and `node.submat_alignersim` as datasets of
	the same names. The node's pattern always goes into a "pattern"
	sub-group. When `node.child_nodes` is set, each child is written to a
	"child_node<i>" sub-group and those names are listed in the
	"child_node_names" dataset.
	"""
	if node.indices_merged is not None:
		# All three merge fields are only written when indices_merged exists.
		for field in ("indices_merged", "submat_crosscontam", "submat_alignersim"):
			grp.create_dataset(field, data=np.array(getattr(node, field)))

	save_pattern(node.pattern, grp.create_group("pattern"))

	children = node.child_nodes
	if children is None:
		return

	names = []
	for position, child in enumerate(children):
		name = "child_node" + str(position)
		names.append(name)
		save_node(child, grp.create_group(name))

	names_dset = grp.create_dataset(
		"child_node_names",
		(len(names),),
		dtype=h5py.special_dtype(vlen=bytes),
	)
	names_dset[:] = names

def save_pattern(pattern, grp):
	"""Write a pattern, and recursively its subpatterns, into the group `grp`.

	Three track sub-groups are created ("sequence", "task0_contrib_scores"
	and "task0_hypothetical_contribs"), each with a "fwd" dataset, a "rev"
	dataset and a `has_pos_axis` attribute. The string form of every seqlet
	is stored under "seqlets_and_alnmts/seqlets", next to an all-zero
	"alnmts" dataset of the same length. If `pattern.subclusters` is set,
	it is saved too, and every subpattern is written with `save_pattern`
	into "subcluster_to_subpattern/subcluster_<key>".
	"""
	# (group name, pattern attribute) pairs, in the order they are written.
	tracks = (
		("sequence", "sequence"),
		("task0_contrib_scores", "contrib_scores"),
		("task0_hypothetical_contribs", "hypothetical_contribs"),
	)
	for group_name, attr_name in tracks:
		track_grp = grp.create_group(group_name)
		values = getattr(pattern, attr_name)
		track_grp.create_dataset("fwd", data=values)
		# The reverse track is the array flipped along both of its axes.
		track_grp.create_dataset("rev", data=values[::-1, ::-1])
		track_grp.attrs["has_pos_axis"] = True

	alnmt_grp = grp.create_group("seqlets_and_alnmts")
	n_seqlets = len(pattern.seqlets)

	seqlet_dset = alnmt_grp.create_dataset(
		"seqlets",
		(n_seqlets,),
		dtype=h5py.special_dtype(vlen=bytes),
	)
	seqlet_dset[:] = list(map(str, pattern.seqlets))

	# Alignments are always written as zeros, one per seqlet.
	alnmt_grp.create_dataset("alnmts", data=np.zeros(n_seqlets))

	if pattern.subclusters is None:
		return

	grp.create_dataset("subclusters", data=pattern.subclusters)
	sub_grp = grp.create_group("subcluster_to_subpattern")

	mapping = pattern.subcluster_to_subpattern
	names = ["subcluster_" + str(key) for key in mapping]

	names_dset = sub_grp.create_dataset(
		"subcluster_names",
		(len(names),),
		dtype=h5py.special_dtype(vlen=bytes),
	)
	names_dset[:] = names

	# `names` was built from the mapping's keys, so it lines up with values().
	for name, subpattern in zip(names, mapping.values()):
		save_pattern(subpattern, sub_grp.create_group(name))

def save_patterns(patterns, grp):
	"""Write each pattern in `patterns` to its own sub-group of `grp`.

	The sub-groups are named "pattern_<i>" by position and filled in by
	`save_pattern`. Their names are stored, in order, in the
	"all_pattern_names" dataset.
	"""
	names = []
	for position, item in enumerate(patterns):
		name = "pattern_" + str(position)
		names.append(name)
		save_pattern(item, grp.create_group(name))

	names_dset = grp.create_dataset(
		"all_pattern_names",
		(len(names),),
		dtype=h5py.special_dtype(vlen=bytes),
	)
	names_dset[:] = names

def save_string_list(string_list, dset_name, grp):
	"""Store `string_list` in `grp` as a 1-D variable-length bytes dataset.

	The dataset is named `dset_name` and has one element per string.
	"""
	vlen_bytes = h5py.special_dtype(vlen=bytes)
	shape = (len(string_list),)

	target = grp.create_dataset(dset_name, shape, dtype=vlen_bytes)
	target[:] = string_list

def save_seqlet_coords(seqlets, dset_name, grp):
	"""Store the `str()` form of every seqlet as a string dataset in `grp`.

	The dataset is named `dset_name` and written by `save_string_list`.
	"""
	as_text = list(map(str, seqlets))
	save_string_list(as_text, dset_name, grp)

def motif_discovery_main(sequences, attributions, output, max_seqlets):
	"""Run TF-MoDISco and write its results to an HDF5 file.

	Parameters
	----------
	sequences:
		Passed to `TFMoDISco` as `one_hot`. The command-line code in this
		file supplies arrays whose second axis is the position axis.
	attributions:
		Passed to `TFMoDISco` as `hypothetical_contribs`; supplied by the
		command-line code with the same axis layout as `sequences`.
	output: str
		Path of the HDF5 file to create. It is opened in 'w' mode, so an
		existing file at that path is overwritten.
	max_seqlets: int
		Passed to `TFMoDISco` as `max_seqlets_per_metacluster`.

	Notes
	-----
	Several groups and attributes below are written with fixed values
	(class names, `verbose`, `min_cluster_size`, ...), presumably so that
	the file matches the layout expected by readers of the original
	TF-MoDISco output -- TODO confirm.
	"""
	(multitask_seqlet_creation_results, metaclustering_results, 
		metacluster_idx_to_submetacluster_results) = modiscolite.tfmodisco.TFMoDISco(
		hypothetical_contribs=attributions, one_hot=sequences,
		max_seqlets_per_metacluster=max_seqlets,
		sliding_window_size=20,
		flank_size=5,
		target_seqlet_fdr=0.05,
		n_leiden_runs=2)

	########
	### Saving code
	########

	# The context manager guarantees the file is flushed and closed, even if
	# one of the writes below raises.
	with h5py.File(output, 'w') as grp:
		multitask_group = grp.create_group("multitask_seqlet_creation_results")
		save_string_list(
			  string_list=list(multitask_seqlet_creation_results['task_name_to_coord_producer_results'].keys()),
			  dset_name="task_names", grp=multitask_group)

		# Empty placeholder groups; nothing is stored inside them.
		multitask2_group = multitask_group.create_group("multitask_seqlet_creator")
		multitask2_group.create_group("coord_producer")
		multitask2_group.create_group("overlap_resolver")
		multitask2_group.attrs['verbose'] = True

		save_seqlet_coords(seqlets=multitask_seqlet_creation_results['final_seqlets'], dset_name="final_seqlets", grp=multitask_group)

		tntcpg = multitask_group.create_group("task_name_to_coord_producer_results")

		# Only the single task named 'task0' is written.
		coord_producer_results = multitask_seqlet_creation_results['task_name_to_coord_producer_results']['task0']
		save_string_list(string_list=[str(x) for x in coord_producer_results['seqlets']], dset_name="coords", grp=tntcpg)

		tnt_group = tntcpg.create_group("tnt_results")
		tnt_group.attrs['class'] = 'FWACTransformAndThresholdResults'

		metaclustering_group = grp.create_group("metaclustering_results")
		metaclustering_group.create_dataset("metacluster_indices", data=metaclustering_results['metacluster_indices'])

		metaclustering_group.attrs["class"] = "SignBasedPatternClustering"
		save_string_list(['task0'], dset_name="task_names", grp=metaclustering_group)

		task_name_to_value_provider_grp = metaclustering_group.create_group("task_name_to_value_provider")

		task0_group = task_name_to_value_provider_grp.create_group("task0")
		task0_group.attrs["class"] = "TransformCentralWindowValueProvider"
		task0_group.attrs["track_name"] = "task0_contrib_scores"
		task0_group.attrs["central_window"] = 20

		val_transformer_group = task0_group.create_group("val_transformer")
		val_transformer_group.attrs["class"] = "AbsPercentileValTransformer"

		metaclustering_group.attrs["min_cluster_size"] = 100
		metaclustering_group.attrs["verbose"] = True
		metaclustering_group.attrs["fit_called"] = True

		activity_pattern_to_cluster_idx_grp = metaclustering_group.create_group("activity_pattern_to_cluster_idx")
		for activity_pattern, cluster_idx in metaclustering_results['pattern_to_cluster_idx'].items():
			activity_pattern_to_cluster_idx_grp.attrs[activity_pattern] = cluster_idx

		group = grp.create_group("metacluster_idx_to_submetacluster_results")
		for idx, results in metacluster_idx_to_submetacluster_results.items():
			smc_group = group.create_group("metacluster_{}".format(idx))
			smc_group.attrs['size'] = results['metacluster_size']
			save_seqlet_coords(seqlets=results['seqlets'], dset_name="seqlets", grp=smc_group)

			patterns = results['seqlets_to_patterns_result']
			pattern_group = smc_group.create_group("seqlets_to_patterns_result")
			pattern_group.attrs["success"] = patterns['success']
			pattern_group.attrs["other_config"] = json.dumps(patterns['other_config'], indent=4, separators=(',', ': '))

			# The remaining entries are only written for successful runs.
			if patterns['success']:
				if patterns['each_round_initcluster_motifs'] is not None:
					rounds_motifs = pattern_group.create_group("each_round_initcluster_motifs")
					all_round_names = []

					for round_idx, initcluster_motifs in enumerate(patterns['each_round_initcluster_motifs']):
						round_name = "round_" + str(round_idx)
						# Record the name so that the "all_round_names"
						# index actually lists the groups created here.
						all_round_names.append(round_name)
						save_patterns(patterns=initcluster_motifs, grp=rounds_motifs.create_group(round_name))

					save_string_list(string_list=all_round_names, dset_name="all_round_names", grp=rounds_motifs)

				save_patterns(patterns['patterns'], pattern_group.create_group("patterns"))
				save_patterns(patterns['remaining_patterns'], pattern_group.create_group("remaining_patterns"))

				cluster_group = pattern_group.create_group("cluster_results")
				cluster_group.attrs['class'] = 'ClusterResults'
				cluster_group.create_dataset("cluster_indices", data=patterns['cluster_results']['cluster_indices'])

				save_hierarchy(patterns['pattern_merge_hierarchy'], pattern_group.create_group("pattern_merge_hierarchy"))

	print("Saved modisco results to file {}".format(output))
	####

# Program description handed to argparse and shown in the --help output.
desc = """TF-MoDISco is a motif detection algorithm that takes in nucleotide
	sequence and the attributions from a neural network model and returns motifs
	that are repeatedly enriched for attribution score across the examples.
	This tool will take in one-hot encoded sequence, the corresponding
	attribution scores, and a few other parameters, and return the motifs."""

# Command-line interface: one top-level parser with two sub-commands, of
# which exactly one must be chosen. The choice is stored in `args.cmd`.
parser = argparse.ArgumentParser(description=desc)
subparsers = parser.add_subparsers(
	dest='cmd',
	required=True,
	help="Must be either 'motifs' or 'report'.",
)

# `motifs` sub-command: options for running motif discovery.
motifs_parser = subparsers.add_parser(
	"motifs",
	help="Run TF-MoDISco and extract the motifs.",
)
motifs_parser.add_argument(
	"-s", "--sequences",
	type=str,
	help="A .npy or .npz file containing the one-hot encoded sequences.",
)
motifs_parser.add_argument(
	"-a", "--attributions",
	type=str,
	help="A .npy or .npz file containing the hypothetical attributions, i.e., the attributions for all nucleotides at all positions.",
)
motifs_parser.add_argument(
	"-i", "--h5py",
	type=str,
	help="A legacy h5py file containing the one-hot encoded sequences and shap scores.",
)
motifs_parser.add_argument(
	"-n", "--max_seqlets",
	type=int,
	required=True,
	help="The maximum number of seqlets per metacluster.",
)
motifs_parser.add_argument(
	"-w", "--window",
	type=int,
	default=400,
	help="The window surrounding the peak center that will be considered for motif discovery.",
)
motifs_parser.add_argument(
	"-o", "--output",
	type=str,
	default="modisco_results.h5",
	help="The path to the output file.",
)
motifs_parser.add_argument(
	"-v", "--verbose",
	action="store_true",
	default=False,
	help="Controls the amount of output from the code.",
)

# `report` sub-command: options for building the HTML report.
report_parser = subparsers.add_parser(
	"report",
	help="Create a HTML report of the results.",
)
report_parser.add_argument(
	"-i", "--h5py",
	type=str,
	required=True,
	help="An HDF5 file containing the output from modiscolite.",
)
report_parser.add_argument(
	"-o", "--output",
	type=str,
	required=True,
	help="A directory to put the output results including the html report.",
)
report_parser.add_argument(
	"-s", "--suffix",
	type=str,
	default="./",
	help="The suffix to add to the beginning of images. Should be equal to the output if using a Jupyter notebook.",
)
report_parser.add_argument(
	"-m", "--meme_db",
	type=str,
	required=True,
	help="A MEME file containing motifs.",
)
report_parser.add_argument(
	"-l", "--meme_motifs",
	type=str,
	required=True,
	help="A directory including extracted MEME motif pfms.",
)
report_parser.add_argument(
	"-n", "--n_matches",
	type=int,
	default=3,
	help="The number of top TOMTOM matches to include in the report.",
)

# Parse the command line; argparse exits with a usage message on bad input.
args = parser.parse_args()

def _central_window(length, window):
	"""Return the (start, end) slice bounds of the central `window` positions.

	The window is centred on `length // 2`. The start is clamped at zero:
	a negative start would be read by slicing as an offset from the END of
	the axis and silently select the wrong region whenever `window` is
	larger than `length`. An end beyond `length` needs no clamping because
	slicing already truncates it.
	"""
	center = length // 2
	half = window // 2
	return max(0, center - half), center + half


def _load_array(path, flag):
	"""Load the array stored in the .npy or .npz file at `path`.

	For .npz archives the array stored under the key 'arr_0' is returned.
	`flag` is the command-line option name, used only in the error message.
	Any other file extension ends the program with a usage error.
	"""
	if path.endswith('npy'):
		return np.load(path)
	if path.endswith('npz'):
		# Close the archive once the array has been read out of it.
		with np.load(path) as archive:
			return archive['arr_0']
	parser.error("{} must be a .npy or .npz file, got '{}'.".format(flag, path))


if args.cmd == "motifs":
	if args.h5py is not None:
		# Legacy HDF5 input. The file is closed even if reading fails.
		with h5py.File(args.h5py, 'r') as scores:
			try:
				# Layout with the position axis second.
				start, end = _central_window(scores['hyp_scores'].shape[1], args.window)

				attributions = scores['hyp_scores'][:, start:end, :]
				sequences = scores['input_seqs'][:, start:end, :]
			except KeyError:
				# Fallback layout with the position axis last; transpose
				# so that positions end up on the second axis.
				start, end = _central_window(scores['shap']['seq'].shape[2], args.window)

				attributions = scores['shap']['seq'][:, :, start:end].transpose(0, 2, 1)
				sequences = scores['raw']['seq'][:, :, start:end].transpose(0, 2, 1)
	else:
		# Without a legacy file both array files are mandatory.
		if args.sequences is None or args.attributions is None:
			parser.error("Either --h5py or both --sequences and --attributions must be provided.")

		sequences = _load_array(args.sequences, "--sequences")
		attributions = _load_array(args.attributions, "--attributions")

		start, end = _central_window(sequences.shape[2], args.window)

		# Positions are on the last axis on disk; move them to the second.
		sequences = sequences[:, :, start:end].transpose(0, 2, 1)
		attributions = attributions[:, :, start:end].transpose(0, 2, 1)

	motif_discovery_main(sequences, attributions, args.output, 
		args.max_seqlets)

elif args.cmd == 'report':
	modiscolite.report.report_motifs(args.h5py, args.output, suffix=args.suffix, 
		top_n_matches=args.n_matches, meme_motif_db=args.meme_db,
		meme_motif_dir=args.meme_motifs)
