#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from miner2 import preprocess, biclusters, subtypes, util
from miner2 import GIT_SHA, __version__ as pkg_version
from miner2 import miner

import logging


# Strip the "$Id: ... $" keyword wrapper from the SHA before it is displayed
_GIT_SHA_LABEL = GIT_SHA.replace('$Id: ', '').replace(' $', '')

# Program description passed to argparse (shown in the --help output)
DESCRIPTION = """miner-subtypes - MINER compute sample subtypes
MINER Version %s (Git SHA %s)""" % (pkg_version, _GIT_SHA_LABEL)

# Only used further down to derive part of the states/tSNE plot file name
MIN_CORRELATION = 0.2

# Message and timestamp layouts handed to logging.Formatter
LOG_FORMAT = '%(asctime)s %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S \t'

if __name__ == '__main__':
    # Send all log records of level INFO and above to stdout, prefixed with
    # a timestamp.
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(logging.Formatter(LOG_FORMAT, DATE_FORMAT))
    LOGGER = logging.getLogger()
    LOGGER.addHandler(stream_handler)
    LOGGER.setLevel(logging.INFO)

    # Command line: four positional paths plus one optional switch.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('regulons', help="regulons.json file from miner-mechinf")
    parser.add_argument('outdir', help="output directory")
    # --skip_tpm is a boolean switch: when given, the expression data is
    # loaded with TPM preprocessing turned off. The help text has to say so.
    parser.add_argument('--skip_tpm', action="store_true",
                        help="skip TPM preprocessing of the expression data")

    args = parser.parse_args()

    # Fail early, before any expensive work, if the regulons input is missing.
    if not os.path.exists(args.regulons):
        sys.exit("regulons file not found")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Record dependency information for this run next to the results.
    with open(os.path.join(args.outdir, 'run_info.txt'), 'w') as outfile:
        util.write_dependency_infos(outfile)

    LOGGER.info('load and setup data')
    # TPM preprocessing stays enabled unless --skip_tpm was passed
    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile,
                                           do_preprocess_tpm=(not args.skip_tpm))
    bkgd = preprocess.background_df(exp_data)

    with open(args.regulons) as infile:
        regulon_modules = json.load(infile)

    # Build a membership dictionary and its incidence matrix twice: with
    # label=2 (kept as "overexpressed") and with label=0 ("underexpressed").
    overexpressed_members = biclusters.make_membership_dictionary(
        regulon_modules, bkgd, label=2, p=0.05)
    overexpressed_members_matrix = biclusters.membership_to_incidence(
        overexpressed_members, exp_data)

    underexpressed_members = biclusters.make_membership_dictionary(
        regulon_modules, bkgd, label=0, p=0.05)
    underexpressed_members_matrix = biclusters.membership_to_incidence(
        underexpressed_members, exp_data)

    # The clustering steps that follow work on the overexpressed view
    sample_dictionary, sample_matrix = (overexpressed_members,
                                        overexpressed_members_matrix)


    # perform initial subtype clustering
    LOGGER.info('initial subtype clustering')
    similarity_clusters = subtypes.f1_decomposition(sample_dictionary, thresholdSFM=0.1)
    # keep only the clusters that have at least five members
    initial_classes = [cluster for cluster in similarity_clusters if len(cluster) >= 5]

    # visualize initial results
    LOGGER.info('visualize initial results')
    sample_freq_matrix = subtypes.sample_coincidence_matrix(
        sample_dictionary, freqThreshold=0.333, frequencies=True)
    similarity_matrix = sample_freq_matrix * sample_freq_matrix.T
    initial_plot_path = os.path.join(args.outdir, "similarityMatrix_regulons.pdf")
    subtypes.plot_similarity(similarity_matrix, np.hstack(initial_classes),
                             vmin=0, vmax=0.5,
                             title="Similarity matrix",
                             xlabel="Samples", ylabel="Samples",
                             fontsize=14, figsize=(7,7),
                             savefig=initial_plot_path)

    # expand initial subtype clusters
    LOGGER.info('expand initial subtype clusters')
    centroid_clusters, centroid_matrix = subtypes.centroid_expansion(
        initial_classes, sample_matrix, f1Threshold=0.1, returnCentroids=True)

    centroid_matrix.to_csv(os.path.join(args.outdir, "centroids.csv"))
    # columns of the incidence matrix whose entries sum to zero are "unmapped"
    empty_columns = np.where(sample_matrix.sum(axis=0) == 0)[0]
    unmapped = list(sample_matrix.columns[empty_columns])
    mapped_samples = [sample for sample in np.hstack(centroid_clusters)
                      if sample not in unmapped]
    mapped_clusters = subtypes.map_expression_to_network(
        centroid_matrix, sample_matrix, threshold=0.05)

    # visualize expanded subtype clusters
    LOGGER.info('visualize expanded subtype clusters')
    expanded_plot_path = os.path.join(args.outdir,"centroidClusters_regulons.pdf")
    subtypes.plot_similarity(similarity_matrix, mapped_samples,
                             vmin=0, vmax=0.333,
                             title="Similarity matrix",
                             xlabel="Samples", ylabel="Samples",
                             fontsize=14, figsize=(7,7),
                             savefig=expanded_plot_path)

    # Generate heatmaps of module activity
    LOGGER.info('generate heatmaps of module activity')
    ordered_overexpressed_members = subtypes.order_membership(centroid_matrix,
                                                              sample_matrix,
                                                              mapped_clusters,
                                                              ylabel="Modules",
                                                              resultsDirectory=args.outdir)

    ordered_dm = subtypes.plot_differential_matrix(sample_matrix,
                                                   underexpressed_members_matrix,
                                                   ordered_overexpressed_members,
                                                   cmap="bwr", aspect="auto",
                                                   saveFile=os.path.join(args.outdir,
                                                                         "centroid_clusters_heatmap.pdf"))
    # WW: Debug. Stored in the output directory like every other artifact so
    # a run neither litters nor depends on the current working directory.
    ordered_dm.to_csv(os.path.join(args.outdir, 'subtypes_ordered_dm.csv'))

    # Infer transcriptional programs and states
    LOGGER.info('infer transcriptional programs and states')
    programs, states = subtypes.mosaic(dfr=ordered_dm, clusterList=centroid_clusters,
                                       minClusterSize_x=9, minClusterSize_y=5,
                                       allow_singletons=False, max_groups=50,
                                       saveFile=os.path.join(args.outdir,
                                                             "regulon_activity_heatmap.pdf"),
                                       random_state=12)

    # transcriptionalPrograms uses a feature of pandas that does not always
    # exist (pandas.core.indexes), my solution was to patch the problem
    # with an exception handler
    transcriptional_programs, program_regulons = subtypes.transcriptional_programs(programs,
                                                                                   regulon_modules)
    # program_regulons is looked up by the keys "TP0", "TP1", ... in numeric order
    program_list = [program_regulons["TP" + str(i)]
                    for i in range(len(program_regulons))]

    # WW: DEBUG (also kept inside the output directory)
    with open(os.path.join(args.outdir, 'subtypes_program_list.json'), 'w') as outfile:
        json.dump(program_list, outfile)
    with open(os.path.join(args.outdir, 'subtypes_states.json'), 'w') as outfile:
        json.dump(list(np.hstack(states)), outfile)

    # Rows follow the flattened program order, columns the flattened state order
    mosaic_df = ordered_dm.loc[np.hstack(program_list), np.hstack(states)]
    mosaic_df.to_csv(os.path.join(args.outdir, "regulons_activity_heatmap.csv"))

    # Get eigengenes for all modules
    eigengenes = subtypes.get_eigengenes(regulon_modules, exp_data,
                                         regulon_dict=None, saveFolder=None)

    # write eigengenes to .csv
    eigengenes.to_csv(os.path.join(args.outdir, "eigengenes.csv"))

    # plot eigengenes
    # NOTE(review): this figure is never passed to savefig in this script and
    # the Agg backend selected at import time has no display, so it appears to
    # produce no visible output - confirm whether a savefig call is missing.
    plt.figure()
    plt.imshow(eigengenes.loc[np.hstack(program_list), np.hstack(states)],
               cmap="viridis", vmin=-0.05, vmax=0.05, aspect="auto")
    plt.grid(False)

    # calculate percent of samples that fall into a state with >= minimum acceptable
    # number of samples
    LOGGER.info('calculate percent of samples that fall into a state')
    # minimum acceptable size: 1% of the columns of exp_data, rounded up
    min_state_size = int(np.ceil(0.01 * exp_data.shape[1]))
    groups = [state for state in states if len(state) >= min_state_size]
    # Count the members directly instead of flattening with np.hstack():
    # hstack raises ValueError for an empty list, which would abort the run
    # whenever no state reaches the minimum size.
    num_covered = sum(len(group) for group in groups)

    print('Discovered {:d} transcriptional states and {:d} transcriptional programs'.format((len(states)),len(transcriptional_programs)))
    print('sample coverage within sufficiently large states: {:.1f}%'.format(100*float(num_covered) / exp_data.shape[1]))

    # write all transcriptional program genesets to text files for external analysis
    LOGGER.info('write output')
    programs_dir = os.path.join(args.outdir, "transcriptional_programs_coexpressionModules")
    if not os.path.isdir(programs_dir):
        os.mkdir(programs_dir)

    # one text file per program, named "<program key>.txt"
    for tp, members in transcriptional_programs.items():
        np.savetxt(os.path.join(programs_dir, tp + ".txt"), members, fmt="%1.50s")

    # Determine activity of transcriptional programs in each sample
    reduced_plot_path = os.path.join(args.outdir, "transcriptional_programs_vs_samples.pdf")
    states_df = subtypes.reduce_modules(df=ordered_dm, programs=program_list,
                                        states=states, stateThreshold=0.65,
                                        saveFile=reduced_plot_path)

    # Cluster patients into subtypes and give the activity of each program in each subtype
    programs_vs_states = subtypes.programs_vs_states(
        states_df, states,
        filename=os.path.join(args.outdir, "programs_vs_states.pdf"),
        showplot=True)

    # Visualize with tSNE
    plt.figure()
    subtypes.tsne(exp_data, perplexity=15, n_components=2, n_iter=1000,
                  plotOnly=True, plotColor="red", alpha=0.4)
    expression_tsne_path = os.path.join(args.outdir, "tsne_gene_expression.pdf")
    plt.savefig(expression_tsne_path, bbox_inches="tight")

    # tSNE applied to df_for_tsne. Consider changing the perplexity in the
    # range of 5 to 50
    df_for_tsne = mosaic_df.copy()
    plt.figure()
    x_embedded = subtypes.tsne(df_for_tsne, perplexity=30, n_components=2,
                               n_iter=1000, plotOnly=None, plotColor="blue", alpha=0.2)
    # wrap the embedding in a frame labelled by the columns of df_for_tsne
    tsne_df = pd.DataFrame(x_embedded)
    tsne_df.columns = ["tsne1", "tsne2"]
    tsne_df.index = df_for_tsne.columns
    activity_tsne_path = os.path.join(args.outdir, "tsne_regulon_activity.pdf")
    plt.savefig(activity_tsne_path, bbox_inches="tight")

    # The string literal below is never executed: it keeps a disabled example
    # that colours the tSNE points by the expression of one target gene.
    """
    nsd2, maf, ccnd1, csk1b, ikzf1, ikzf3, tp53, e2f1, ets1, phf19 = ["ENSG00000109685","ENSG00000178573","ENSG00000110092","ENSG00000173207","ENSG00000185811","ENSG00000161405","ENSG00000141510","ENSG00000101412","ENSG00000134954","ENSG00000119403"]

    target = tp53
    target_expression = np.array(exp_data.loc[target, tsne_df.index])

    plt.scatter(tsne_df.iloc[:,0], tsne_df.iloc[:,1], cmap="bwr", c=target_expression, alpha=0.65)
    plt.savefig(os.path.join(args.outdir,("").join(["labeled_tsne_", target,
                                                    "_overlay_o2.pdf"])),bbox_inches="tight")
    """

    # How many clusters do you expect ? Start with number of states
    num_clusters = len(states)

    # Are the clusters separated how you thought? If not, change the random_state
    # to a different number and retry
    random_state = 12

    clusters, labels, centroids = subtypes.kmeans(tsne_df, numClusters=num_clusters,
                                                  random_state=random_state)

    # overlay cluster labels. WW: Vega20 instead of tab20 because it does not exist
    # at my system
    # Only a failed colormap lookup should trigger the fallback; a bare
    # "except:" would also swallow KeyboardInterrupt and SystemExit.
    plt.figure()
    try:
        plt.scatter(tsne_df.iloc[:,0], tsne_df.iloc[:,1], cmap="tab20", c=labels, alpha=0.65)
    except (ValueError, KeyError):
        plt.scatter(tsne_df.iloc[:,0], tsne_df.iloc[:,1], cmap="Vega20", c=labels, alpha=0.65)

    plt.savefig(os.path.join(args.outdir, "labeled_tsne_kmeans.pdf"),
                bbox_inches="tight")

    # convert states to tsne labels
    state_labels = subtypes.tsne_state_labels(tsne_df, states)

    # overlay states cluster labels (same colormap fallback as above)
    plt.figure()
    try:
        plt.scatter(tsne_df.iloc[:,0], tsne_df.iloc[:,1], cmap="tab20",
                    c=state_labels,
                    alpha=0.65)
    except (ValueError, KeyError):
        plt.scatter(tsne_df.iloc[:,0], tsne_df.iloc[:,1], cmap="Vega20",
                    c=state_labels,
                    alpha=0.65)

    plt.savefig(os.path.join(args.outdir, "labeled_tsne_states.pdf"),
                bbox_inches="tight")

    # overlay activity of transcriptional programs
    # The file name embeds the digits after the decimal point of
    # MIN_CORRELATION, prefixed with "0o".
    correlation_tag = "0o" + str(MIN_CORRELATION).split(".")[1]
    states_plot_name = "_".join(["states_regulons", correlation_tag, "tsne.pdf"])
    subtypes.plot_states(states_df, tsne_df,
                         numCols=int(np.sqrt(states_df.shape[0])),
                         saveFile=os.path.join(args.outdir, states_plot_name),
                         aspect=1, size=10, scale=3)



    # Matt's new way of subtype discovery (START)
    # TODO: fold this into the code above
    primary_matrix = overexpressed_members_matrix
    primary_dictionary = overexpressed_members
    secondary_matrix = underexpressed_members_matrix
    secondary_dictionary = underexpressed_members
    # reference = overexpressed incidence minus underexpressed incidence
    reference_matrix = primary_matrix - secondary_matrix

    # smallest accepted cluster: 1% of the columns of exp_data, rounded up
    min_cluster_size = int(np.ceil(0.01 * exp_data.shape[1]))

    states, centroid_clusters = miner.inferSubtypes(reference_matrix, primary_matrix,
                                                    secondary_matrix,
                                                    primary_dictionary,
                                                    secondary_dictionary,
                                                    minClusterSize=min_cluster_size,
                                                    restricted_index=None)
    # JSON objects need string keys, so each state is stored under str(position)
    states_dictionary = {str(idx): states[idx] for idx in range(len(states))}
    states_path = os.path.join(args.outdir, "transcriptional_states.json")
    with open(states_path, 'w') as outfile:
        json.dump(states_dictionary, outfile)

    reference_df = eigengenes.copy()
    # NOTE(review): saveFile is the same regulon_activity_heatmap.pdf path that
    # the subtypes.mosaic call earlier in this script uses, so the earlier
    # plot is presumably replaced here - confirm that this is intended.
    programs, _ = miner.mosaic(dfr=reference_df, clusterList=centroid_clusters,
                               minClusterSize_x=min_cluster_size,
                               minClusterSize_y=5,
                               allow_singletons=False,
                               max_groups=50,
                               saveFile=os.path.join(args.outdir,"regulon_activity_heatmap.pdf"),
                               random_state=12)
    transcriptional_programs, program_regulons = miner.transcriptionalPrograms(
        programs, regulon_modules)
    # program_regulons is looked up by the keys "TP0", "TP1", ... in numeric order
    program_list = [program_regulons["TP" + str(idx)]
                    for idx in range(len(program_regulons))]
    programs_dictionary = {str(idx): regulons
                           for idx, regulons in enumerate(program_list)}

    programs_path = os.path.join(args.outdir, "transcriptional_programs.json")
    with open(programs_path, 'w') as outfile:
        json.dump(programs_dictionary, outfile)

    # Matt's new way of subtype discovery (END)
