#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from miner2 import preprocess, biclusters, subtypes, util
from miner2 import GIT_SHA, __version__ as pkg_version
from miner2 import miner

import logging


# Banner shown by argparse in the --help output. The '$Id: ' prefix and ' $'
# suffix are stripped from GIT_SHA so only the remaining text is displayed.
DESCRIPTION = "\n".join([
    "miner-subtypes - MINER compute sample subtypes",
    "MINER Version %s (Git SHA %s)" % (
        pkg_version,
        GIT_SHA.replace('$Id: ', '').replace(' $', ''),
    ),
])

# NOTE(review): not referenced anywhere in the visible part of this script.
MIN_CORRELATION = 0.2

# Formats handed to logging.Formatter: timestamp (with a trailing tab) followed
# by the log message.
DATE_FORMAT = '%Y-%m-%d %H:%M:%S \t'
LOG_FORMAT = '%(asctime)s %(message)s'

if __name__ == '__main__':
    # Command-line driver: loads an expression matrix and a regulons JSON file,
    # infers transcriptional states (sample groups) and transcriptional
    # programs (regulon groups) through the miner API, and writes the JSON,
    # CSV and PDF results into the output directory.

    # Log INFO and above to stdout, prefixed with a timestamp.
    LOGGER = logging.getLogger()
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(logging.Formatter(LOG_FORMAT, DATE_FORMAT))
    LOGGER.addHandler(stream_handler)
    LOGGER.setLevel(logging.INFO)

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('regulons', help="regulons.json file from miner-mechinf")
    parser.add_argument('outdir', help="output directory")
    # The flag is forwarded (negated) as do_preprocess_tpm to preprocess.main.
    parser.add_argument('--skip_tpm', action="store_true",
                        help="skip the TPM preprocessing step when loading the input matrix")

    args = parser.parse_args()

    # Fail early, before any output is written, if the regulons input is missing.
    if not os.path.exists(args.regulons):
        sys.exit("regulons file not found")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Record the dependency information for this run next to the results.
    with open(os.path.join(args.outdir, 'run_info.txt'), 'w') as outfile:
        util.write_dependency_infos(outfile)

    LOGGER.info('load and setup data')
    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile,
                                           do_preprocess_tpm=(not args.skip_tpm))
    bkgd = preprocess.background_df(exp_data)

    with open(args.regulons) as infile:
        regulon_modules = json.load(infile)

    # Membership of samples in each regulon, once for label=2 (treated here as
    # overexpressed) and once for label=0 (treated here as underexpressed),
    # each also converted to an incidence matrix.
    overexpressed_members = biclusters.make_membership_dictionary(regulon_modules, bkgd,
                                                                  label=2, p=0.05)
    overexpressed_members_matrix = biclusters.membership_to_incidence(overexpressed_members,
                                                                      exp_data)
    underexpressed_members = biclusters.make_membership_dictionary(regulon_modules, bkgd,
                                                                   label=0, p=0.05)
    underexpressed_members_matrix = biclusters.membership_to_incidence(underexpressed_members,
                                                                       exp_data)

    # Matt's new way of subtype discovery (START)
    # Reference matrix is the over-expressed incidence minus the
    # under-expressed incidence.
    reference_matrix = overexpressed_members_matrix - underexpressed_members_matrix

    # Minimum cluster size: 1% of the second dimension of exp_data, rounded up.
    min_cluster_size = int(np.ceil(0.01 * exp_data.shape[1]))

    states, centroid_clusters = miner.inferSubtypes(reference_matrix,
                                                    overexpressed_members_matrix,
                                                    underexpressed_members_matrix,
                                                    overexpressed_members,
                                                    underexpressed_members,
                                                    minClusterSize=min_cluster_size,
                                                    restricted_index=None)
    # JSON object keys must be strings, so states are keyed by their position.
    states_dictionary = {str(i): state for i, state in enumerate(states)}
    with open(os.path.join(args.outdir, "transcriptional_states.json"), 'w') as outfile:
        json.dump(states_dictionary, outfile)

    eigengenes = miner.getEigengenes(regulon_modules, exp_data,
                                     regulon_dict=None, saveFolder=None)
    reference_df = eigengenes.copy()
    programs, _ = miner.mosaic(dfr=reference_df, clusterList=centroid_clusters,
                               minClusterSize_x=min_cluster_size,
                               minClusterSize_y=5,
                               allow_singletons=False,
                               max_groups=50,
                               saveFile=os.path.join(args.outdir, "regulon_activity_heatmap.pdf"),
                               random_state=12)
    transcriptional_programs, program_regulons = miner.transcriptionalPrograms(programs,
                                                                               regulon_modules)
    # program_regulons is looked up by the keys "TP0", "TP1", ...; collect the
    # entries in that numeric order.
    program_list = [program_regulons["TP" + str(i)] for i in range(len(program_regulons))]
    programs_dictionary = {str(i): program for i, program in enumerate(program_list)}

    with open(os.path.join(args.outdir, "transcriptional_programs.json"), 'w') as outfile:
        json.dump(programs_dictionary, outfile)

    # Flattened orderings (all programs / all states concatenated) used for
    # every row/column selection below.
    regulon_order = np.hstack(program_list)
    sample_order = np.hstack(states)

    mosaic_df = reference_df.loc[regulon_order, sample_order]
    mosaic_df.to_csv(os.path.join(args.outdir, "regulons_activity_heatmap.csv"))

    # The difference matrix is recomputed here, as in the original flow, instead
    # of reusing reference_matrix.
    # NOTE(review): equivalent to reference_matrix provided inferSubtypes does
    # not modify its inputs - confirm before folding the two together.
    dfr = overexpressed_members_matrix - underexpressed_members_matrix
    mtrx = dfr.loc[regulon_order, sample_order]
    plt.figure(figsize=(8, 8))
    # Aspect ratio columns/rows gives the image a square overall outline.
    plt.imshow(mtrx, cmap="bwr", vmin=-1, vmax=1,
               aspect=float(mtrx.shape[1]) / float(mtrx.shape[0]))
    plt.grid(False)
    plt.savefig(os.path.join(args.outdir, "mosaic_all.pdf"), bbox_inches="tight")

    # Determine activity of transcriptional programs in each sample
    states_df = miner.reduceModules(df=mtrx, programs=program_list,
                                    states=states, stateThreshold=0.50,
                                    saveFile=os.path.join(args.outdir,
                                                          "transcriptional_programs.pdf"))

    # Cluster patients into subtypes and give the activity of each program in each subtype
    programsVsStates = miner.programsVsStates(states_df, states,
                                              filename=os.path.join(args.outdir,
                                                                    "programs_vs_states.pdf"),
                                              showplot=True)
