#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

from miner2 import preprocess, biclusters, subtypes, survival, util
from miner2 import GIT_SHA, __version__ as pkg_version


# Banner shown in the argparse help output: the tool name followed by the
# package version and the Git SHA with its '$Id: ' / ' $' markers stripped.
DESCRIPTION = "miner-survival - MINER survival analysis\nMINER Version {} (Git SHA {})".format(
    pkg_version,
    GIT_SHA.replace('$Id: ', '').replace(' $', ''),
)

def build_arg_parser():
    """Build the argparse parser for the miner-survival command line."""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('regulons', help="regulons.json file from miner-mechinf")
    parser.add_argument('survfile', help="survival data CSV file")
    parser.add_argument('outdir', help="output directory")
    parser.add_argument('--skip_tpm', action="store_true",
                        help="skip the TPM preprocessing of the input matrix")
    parser.add_argument('--num_cores', type=int, default=5,
                        help="number of cores for the program/state survival analyses (default: 5)")
    return parser


def compute_states(exp_data, regulon_modules, outdir):
    """Recompute the subtype steps whose results the survival analysis needs.

    The states and the program-vs-sample matrix are not taken as inputs (to
    keep the command line simple), so they are derived again here.  The
    subtype plots are written into `outdir` as a side effect.

    Returns a tuple (states, states_df) as produced by subtypes.mosaic() and
    subtypes.reduce_modules().
    """
    bkgd = preprocess.background_df(exp_data)

    overexpressed_members = biclusters.make_membership_dictionary(regulon_modules, bkgd, label=2)
    overexpressed_members_matrix = biclusters.membership_to_incidence(overexpressed_members,
                                                                      exp_data)
    underexpressed_members = biclusters.make_membership_dictionary(regulon_modules, bkgd, label=0)
    underexpressed_members_matrix = biclusters.membership_to_incidence(underexpressed_members,
                                                                       exp_data)

    similarity_clusters = subtypes.f1_decomposition(overexpressed_members, thresholdSFM=0.1)
    initial_classes = [cluster for cluster in similarity_clusters if len(cluster) > 4]

    centroid_clusters, centroid_matrix = subtypes.centroid_expansion(initial_classes,
                                                                     overexpressed_members_matrix,
                                                                     f1Threshold=0.1,
                                                                     returnCentroids=True)

    mapped_clusters = subtypes.map_expression_to_network(centroid_matrix,
                                                         overexpressed_members_matrix,
                                                         threshold=0.05)

    ordered_overexpressed_members = subtypes.order_membership(centroid_matrix,
                                                              overexpressed_members_matrix,
                                                              mapped_clusters,
                                                              ylabel="Modules",
                                                              resultsDirectory=outdir)

    # Only the returned matrix is needed here; the plot is an unavoidable
    # by-product of the current subtypes API.
    ordered_dm = subtypes.plot_differential_matrix(overexpressed_members_matrix,
                                                   underexpressed_members_matrix,
                                                   ordered_overexpressed_members,
                                                   cmap="bwr", aspect="auto",
                                                   saveFile=os.path.join(outdir,
                                                                         "centroid_clusters_heatmap.pdf"))

    programs, states = subtypes.mosaic(dfr=ordered_dm,
                                       clusterList=centroid_clusters,
                                       minClusterSize_x=9,
                                       minClusterSize_y=5,
                                       allow_singletons=False,
                                       max_groups=50,
                                       saveFile=os.path.join(outdir, "regulon_activity_heatmap.pdf"),
                                       random_state=12)

    _, program_regulons = subtypes.transcriptional_programs(programs, regulon_modules)
    program_list = [program_regulons["TP%d" % i] for i in range(len(program_regulons))]

    states_df = subtypes.reduce_modules(df=ordered_dm, programs=program_list,
                                        states=states, stateThreshold=0.65,
                                        saveFile=os.path.join(outdir,
                                                              "transcriptional_programs_vs_samples.pdf"))
    return states, states_df


def risk_scores_by_state(srv, states):
    """Tag every sample's GuanScore with its state and collect state medians.

    Returns (guan_srv_df, xmedians).  guan_srv_df has the columns "value"
    (the "GuanScore" column of `srv`) and "group" (position of the sample's
    state in `states`, -1 for samples that are in no state).  xmedians holds
    one median score per state, 0 for states without any sample in `srv`.
    """
    guan_srv_df = pd.DataFrame(srv.loc[:, "GuanScore"])
    guan_srv_df.columns = ["value"]
    guan_srv_df["group"] = -1.0

    srv_samples = set(srv.index)
    xmedians = []
    for label, members in enumerate(states):
        group = list(srv_samples.intersection(members))
        if group:
            xmedians.append(np.median(guan_srv_df.loc[group, "value"]))
            guan_srv_df.loc[group, "group"] = label
        else:
            xmedians.append(0)
    return guan_srv_df, xmedians


def plot_risk_boxplot(guan_srv_df, ranked_states, filename):
    """Save a box plot with overlaid swarm plot of risk score per state."""
    # Start from a clean slate: figures left open by the subtype plotting
    # otherwise interfere with this plot.
    plt.close('all')
    plt.figure(figsize=(12, 8))
    ax = sns.boxplot(x='group', y='value', data=guan_srv_df, order=ranked_states)

    # Older matplotlib keeps the boxes in ax.artists, newer releases keep
    # them in ax.patches; walk both so the styling applies either way.
    for patch in list(ax.artists) + list(ax.patches):
        patch.set_edgecolor('black')
        r, g, b, _ = patch.get_facecolor()
        patch.set_facecolor((r, g, b, 0.8))

    sns.swarmplot(x='group', y='value', data=guan_srv_df, order=ranked_states,
                  size=5, color=[0.15, 0.15, 0.15], edgecolor="black")

    # Property names must be lower case: current matplotlib rejects "FontSize".
    plt.ylabel("Risk score", fontsize=20)
    plt.xlabel("Subtype", fontsize=20)
    plt.savefig(filename, bbox_inches="tight")


def cox_results_to_df(cox_results, index=None):
    """Convert a Cox result mapping into a table sorted by hazard ratio.

    cox_results maps a key to a sequence whose first element is used as "HR"
    and whose second element is used as "p-value".  Rows are labelled with
    the mapping's keys unless `index` supplies replacement labels, which must
    be in the same order as the keys.  Rows are sorted by descending HR.
    """
    keys = list(cox_results.keys())
    cox_hr = [cox_results[key][0] for key in keys]
    cox_p = [cox_results[key][1] for key in keys]
    table = pd.DataFrame(np.vstack([cox_hr, cox_p]).T)
    table.index = keys if index is None else list(index)
    table.columns = ["HR", "p-value"]
    table.sort_values(by="HR", ascending=False, inplace=True)
    return table


def state_membership_matrix(state_members, row_labels, samples):
    """Build a 0/1 membership matrix with one row per state.

    Row i is labelled row_labels[i] and holds 1 in the column of every sample
    of state_members[i] that also occurs in `samples`; all other cells are 0.
    """
    row_labels = list(row_labels)
    matrix = pd.DataFrame(np.zeros((len(state_members), len(samples))))
    matrix.index = row_labels
    matrix.columns = samples

    known_samples = set(samples)
    for label, members in zip(row_labels, state_members):
        present = list(known_samples.intersection(members))
        if present:
            matrix.loc[label, present] = 1
    return matrix


def main():
    """Run the MINER survival analysis from the command line arguments."""
    args = build_arg_parser().parse_args()

    if not os.path.exists(args.regulons):
        sys.exit("regulons file not found")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    with open(os.path.join(args.outdir, 'run_info.txt'), 'w') as outfile:
        util.write_dependency_infos(outfile)

    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile,
                                           do_preprocess_tpm=(not args.skip_tpm))

    with open(args.regulons) as infile:
        regulon_modules = json.load(infile)

    states, states_df = compute_states(exp_data, regulon_modules, args.outdir)

    #########################################################################
    # THE ACTUAL SURVIVAL COMPUTATION
    #########################################################################
    survival_mmrf = pd.read_csv(args.survfile, index_col=0, header=0)
    if survival_mmrf.shape[1] < 2:
        sys.exit("survival file needs at least two data columns (duration, observed)")
    # Copy the slice so that renaming the columns never touches the frame
    # that was read from disk.
    survival_df_mmrf = survival_mmrf.iloc[:, 0:2].copy()
    survival_df_mmrf.columns = ["duration", "observed"]

    # generate Kaplan-Meier estimates
    km_df = survival.km_analysis(survivalDf=survival_df_mmrf, durationCol="duration",
                                 statusCol="observed")
    # generate GuanRank scores
    srv = survival.guan_rank(kmSurvival=km_df).copy()

    guan_srv_df, xmedians = risk_scores_by_state(srv, states)
    ranked_states = np.argsort(xmedians)

    survival_tag = "Risk_groups"
    boxplot_filename = ("_").join([survival_tag, "boxplot_swarm.pdf"])
    plot_risk_boxplot(guan_srv_df, ranked_states,
                      os.path.join(args.outdir, boxplot_filename))

    # survival analysis of the transcriptional programs
    # NOTE(review): this table and the per-state table below are computed but
    # not written to disk, as before - confirm whether they should be saved.
    cox_programs = survival.parallel_member_survival_analysis(membershipDf=states_df,
                                                              numCores=args.num_cores,
                                                              survivalPath="",
                                                              survivalData=srv)
    cox_programs_df = cox_results_to_df(cox_programs)

    # survival analysis of the individual states with at least 9 members
    sufficient_states = [i for i, members in enumerate(states) if len(members) >= 9]
    state_members = [states[i] for i in sufficient_states]
    state_survival = state_membership_matrix(state_members, sufficient_states, srv.index)

    cox_states = survival.parallel_member_survival_analysis(membershipDf=state_survival,
                                                            numCores=args.num_cores,
                                                            survivalPath="",
                                                            survivalData=srv)
    cox_states_df = cox_results_to_df(cox_states)

    # combinatorial survival analysis
    combined_states, combined_indices = survival.combined_states(states, ranked_states,
                                                                 srv, minSamples=4, maxStates=10)
    state_survival = state_membership_matrix(combined_states, range(len(combined_states)),
                                             srv.index)

    cox_combined_states = survival.parallel_member_survival_analysis(membershipDf=state_survival,
                                                                     numCores=1, survivalPath="",
                                                                     survivalData=srv)

    # NOTE(review): the labels assume the result keys come back in the same
    # order as combined_indices - confirm against the survival module.
    cox_combined_states_df = cox_results_to_df(cox_combined_states, index=combined_indices)
    cox_combined_states_df.to_csv(os.path.join(args.outdir,
                                               "Hazards_regression_of_combined_transcriptional_states.csv"))

    if cox_combined_states_df.empty:
        sys.exit("no combined states available, cannot create the Kaplan-Meier plot")

    # The table is sorted by descending HR, so the first row is the
    # combination with the highest hazard ratio; its label is "i&j&...".
    highest_risk_combination = np.array(cox_combined_states_df.index[0].split("&")).astype(int)

    kmTag = "states"
    # Yields "Risk_groups_states_.pdf"; the stray underscore is kept so the
    # output file name stays the same as in earlier runs.
    kmFilename = ("_").join([survival_tag, kmTag, ".pdf"])
    groups = [np.hstack([states[i] for i in highest_risk_combination]),
              np.hstack([states[i] for i in range(len(states))
                         if i not in highest_risk_combination])]
    labels = ["High-risk states", "Other states"]
    plotName = os.path.join(args.outdir, kmFilename)
    survival.kmplot(srv=srv, groups=groups, labels=labels, xlim_=(-100, 1750),
                    filename=plotName)


if __name__ == '__main__':
    main()
