#!/usr/bin/env python3

import argparse
import json
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot
import seaborn as sns
from graphviz import Source

import pickle

from miner2 import risk_predict, survival, util, preprocess, miner
from miner2 import GIT_SHA, __version__ as pkg_version

# Banner shown by argparse in --help: the tool name followed by the package
# version and the Git SHA with its '$Id: ' prefix and ' $' suffix stripped.
DESCRIPTION = (
    "miner-riskpredict - MINER compute risk prediction.\n"
    "MINER Version %s (Git SHA %s)"
) % (pkg_version, GIT_SHA.replace('$Id: ', '').replace(' $', ''))


# Positions (0-based, within the ordered state list) of the states that are
# merged and emphasised on the state Kaplan-Meier figure.
# NOTE(review): these indices look tuned to one specific dataset; the original
# comment labelled the second (blue) group "high-risk states" - confirm.
RED_HIGHLIGHT_STATES = (4, 5, 16, 23)
BLUE_HIGHLIGHT_STATES = (10, 14, 21)


def load_json(path):
    """Parse the JSON document stored at *path* and return the resulting object."""
    with open(path) as infile:
        return json.load(infile)


def ordered_groups(groups_by_index):
    """Return the values of a dict keyed "0".."N-1" as a list in numeric key order."""
    return [groups_by_index[str(position)]
            for position in range(len(groups_by_index))]


def add_unassigned_state(states_list, sample_ids):
    """Return a copy of *states_list* with the unassigned samples appended.

    Samples from *sample_ids* that occur in no state are collected (in the
    order they appear in *sample_ids*, without duplicates) and appended as one
    extra final state. Nothing is appended when every sample is assigned.
    """
    assigned = set()
    for state in states_list:
        assigned.update(state)

    leftover = []
    for sample in sample_ids:
        if sample not in assigned:
            assigned.add(sample)  # also guards against duplicated ids
            leftover.append(sample)

    completed = list(states_list)
    if leftover:
        completed.append(leftover)
    return completed


def binarize_top_percentile(values, percentile=85):
    """Return a float array holding 1.0 where *values* >= its percentile, else 0.0.

    The mask is computed once from the untouched values. Overwriting the array
    in two passes (>= threshold -> 1, then < threshold -> 0) would wipe out the
    freshly written ones whenever the threshold itself is larger than 1.
    """
    values = np.asarray(values, dtype=float)
    threshold = np.percentile(values, percentile)
    return (values >= threshold).astype(float)


def program_gene_sets(program_list, regulon_modules):
    """Map each program position to the de-duplicated genes of its regulons."""
    pr_genes = {}
    for position, regulon_ids in enumerate(program_list):
        members = [regulon_modules[regulon_id] for regulon_id in regulon_ids]
        pr_genes[position] = list(set(np.hstack(members)))
    return pr_genes


def combine_states(states_list, indices):
    """Concatenate the states at *indices*; return None if any index is out of range."""
    if any(index >= len(states_list) for index in indices):
        return None
    return np.hstack([states_list[index] for index in indices])


def build_state_label_frame(state_ids, state_samples, ranks):
    """Order per-state sample lists by ascending rank.

    Parameters are three parallel sequences: the label of each state, the
    samples of each state and the rank (risk score) of each state.

    Returns a tuple ``(labels_df, rank_order)``: ``labels_df`` has one row per
    sample (indexed by sample id) with its state label in column "label", with
    the states laid out by ascending rank; ``rank_order`` is the array of state
    labels in that same order.
    """
    order = np.argsort(ranks)
    # Index the label list directly: deriving the labels from a set would not
    # keep them aligned with *ranks*, and ragged lists cannot be turned into a
    # NumPy array for fancy indexing on recent NumPy releases.
    rank_order = np.array(state_ids)[order]

    sample_index = []
    sample_labels = []
    for position in order:
        samples = state_samples[position]
        sample_index.extend(samples)
        sample_labels.extend([state_ids[position]] * len(samples))

    labels_df = pd.DataFrame({"label": sample_labels}, index=sample_index)
    return labels_df, rank_order


def sort_by_numeric_index(frame):
    """Return *frame* with its rows ordered by the integer value of the index."""
    return frame.iloc[np.argsort(np.array(frame.index).astype(int))]


def print_risk_extremes(kind, cox_output, count=5):
    """Print the first and last *count* rows of an HR-sorted Cox result table."""
    print("\nHigh-risk %s:" % kind)
    print(cox_output.iloc[0:count, :])
    print("\nLow-risk %s" % kind)
    print(cox_output.iloc[-count:, :])


def use_whitegrid_style():
    """Activate matplotlib's seaborn white-grid style under its installed name.

    The style was renamed from 'seaborn-whitegrid' to 'seaborn-v0_8-whitegrid'
    in newer matplotlib releases, so the available name is looked up first.
    """
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            return
    print("WARNING: no seaborn whitegrid style available, using the current style")


def plot_membership_km(membership_df, column, srv, guan_df, color, alpha):
    """Add one Kaplan-Meier curve for the samples flagged 1 in *column*.

    The curve is labelled "<column>: <hazard ratio>" using the hazard ratio
    returned by miner.survivalMembershipAnalysisDirect.
    """
    cox_hr, _ = miner.survivalMembershipAnalysisDirect(membership_df, guan_df)
    groups = [membership_df.index[np.where(membership_df[column] == 1)[0]]]
    labels = ["{0}: {1:.2f}".format(str(column), cox_hr)]
    miner.kmplot(srv=srv, groups=groups, labels=labels,
                 filename=None, lw=2, color=color, alpha=alpha)


def plot_program_km(exp_data, genes, program_key, srv, guan_df, color, alpha):
    """Plot the KM curve of the samples with top-15% mean expression of *genes*."""
    mean_expression = np.mean(np.array(exp_data.loc[genes, :]), axis=0)
    membership_df = pd.DataFrame(binarize_top_percentile(mean_expression))
    membership_df.index = exp_data.columns
    membership_df.columns = [program_key]
    plot_membership_km(membership_df, program_key, srv, guan_df, color, alpha)


def plot_state_km(sample_ids, members, column, srv, guan_df, color, alpha):
    """Plot the KM curve of the samples listed in *members*."""
    membership_df = pd.DataFrame(np.zeros(len(sample_ids)))
    membership_df.index = sample_ids
    membership_df.columns = [column]
    membership_df.loc[members, column] = 1
    plot_membership_km(membership_df, column, srv, guan_df, color, alpha)


def main():
    """Command-line entry point: compute Cox regressions and risk figures.

    Reads the JSON input specification given on the command line and writes
    into the output directory: run_info.txt, regulon_activity_heatmap.pdf,
    CoxProportionalHazardsRegulons.csv, CoxProportionalHazardsPrograms.csv,
    kmplots_programs.pdf, kmplots_states.pdf, violin_states_risk.pdf,
    boxplot_states_risk.pdf and, when the specification has a
    'translocations' entry, barplot_states_translocations.pdf.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('input', help="input specification file")
    parser.add_argument('outdir', help="output directory")
    parser.add_argument('--skip_tpm', action="store_true",
                        help="skip the TPM preprocessing of the expression data")
    parser.add_argument('--num_cores', type=int, default=5,
                        help="number of cores used by the survival analyses (default: 5)")
    args = parser.parse_args()

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    with open(os.path.join(args.outdir, 'run_info.txt'), 'w') as outfile:
        util.write_dependency_infos(outfile)

    # ------------------------------------------------------------------
    # Inputs
    # NOTE: only the spec entries needed by the steps below are loaded.
    # ------------------------------------------------------------------
    input_spec = load_json(args.input)

    exp_data, _ = preprocess.main(input_spec['exp'], input_spec['idmap'],
                                  do_preprocess_tpm=(not args.skip_tpm))

    translocations = None
    if 'translocations' in input_spec:
        translocations = pd.read_csv(input_spec['translocations'], index_col=0, header=0)

    regulon_modules = load_json(input_spec['regulon_modules'])

    oem_matrix = pd.read_csv(input_spec['overexpressed_members'], index_col=0, header=0)
    oem_matrix.index = np.array(oem_matrix.index).astype(str)
    uem_matrix = pd.read_csv(input_spec['underexpressed_members'], index_col=0, header=0)
    uem_matrix.index = np.array(uem_matrix.index).astype(str)
    diff_matrix = oem_matrix - uem_matrix

    program_list = ordered_groups(load_json(input_spec['transcriptional_programs']))
    states_list = add_unassigned_state(
        ordered_groups(load_json(input_spec['transcriptional_states'])),
        exp_data.columns)

    # Only the first two columns of the survival table are used.
    survival_df = pd.read_csv(input_spec['primary_survival_data'],
                              index_col=0, header=0).iloc[:, 0:2].copy()
    survival_df.columns = ["duration", "observed"]
    km_df = miner.kmAnalysis(survivalDf=survival_df,
                             durationCol="duration",
                             statusCol="observed")
    guan_survival_df = miner.guanRank(kmSurvival=km_df)

    # ------------------------------------------------------------------
    # Regulon activity heatmap (programs x states)
    # ------------------------------------------------------------------
    plt.figure(figsize=(8, 8))
    plt.imshow(diff_matrix.loc[np.hstack(program_list), np.hstack(states_list)],
               cmap="bwr", vmin=-1, vmax=1, aspect=0.275)
    plt.grid(False)
    plt.savefig(os.path.join(args.outdir, "regulon_activity_heatmap.pdf"),
                bbox_inches="tight")

    # ------------------------------------------------------------------
    # Survival analysis of regulons
    # ------------------------------------------------------------------
    cox_regulons_output = miner.parallelMedianSurvivalAnalysis(regulon_modules,
                                                               exp_data,
                                                               guan_survival_df,
                                                               numCores=args.num_cores)
    cox_regulons_output = sort_by_numeric_index(cox_regulons_output)
    cox_regulons_output.to_csv(os.path.join(args.outdir,
                                            'CoxProportionalHazardsRegulons.csv'))
    cox_regulons_output = cox_regulons_output.sort_values(by="HR", ascending=False)
    print_risk_extremes("regulons", cox_regulons_output)

    # ------------------------------------------------------------------
    # Survival analysis of transcriptional programs
    # ------------------------------------------------------------------
    pr_genes = program_gene_sets(program_list, regulon_modules)

    cox_programs_output = miner.parallelMedianSurvivalAnalysis(pr_genes, exp_data,
                                                               guan_survival_df,
                                                               numCores=args.num_cores)
    cox_programs_output = sort_by_numeric_index(cox_programs_output)
    cox_programs_output.to_csv(os.path.join(args.outdir,
                                            'CoxProportionalHazardsPrograms.csv'))
    cox_programs_output = cox_programs_output.sort_values(by="HR", ascending=False)
    print_risk_extremes("programs", cox_programs_output)

    # Kaplan-Meier plot of all programs (mean expression, top 15% of samples)
    srv = guan_survival_df.copy()
    plt.figure(figsize=(8, 8))
    use_whitegrid_style()

    for program_key in pr_genes:
        plot_program_km(exp_data, pr_genes[program_key], program_key,
                        srv, guan_survival_df, color="gray", alpha=0.3)

    # The table is sorted by descending HR: first row = highest hazard ratio.
    # NOTE(review): the highest-HR program is drawn in blue and the lowest in
    # red, as in the original script - confirm this colour choice is intended.
    highest_hr_key = cox_programs_output.index[0]
    lowest_hr_key = cox_programs_output.index[-1]
    plot_program_km(exp_data, pr_genes[highest_hr_key], highest_hr_key,
                    srv, guan_survival_df, color="blue", alpha=1)
    plot_program_km(exp_data, pr_genes[lowest_hr_key], lowest_hr_key,
                    srv, guan_survival_df, color="red", alpha=1)

    plt.savefig(os.path.join(args.outdir, "kmplots_programs.pdf"), bbox_inches="tight")

    # ------------------------------------------------------------------
    # Survival analysis of transcriptional states
    # ------------------------------------------------------------------
    plt.figure(figsize=(8, 8))
    use_whitegrid_style()

    srv = guan_survival_df.copy()
    for state_key, state in enumerate(states_list):
        plot_state_km(exp_data.columns, state, state_key,
                      srv, guan_survival_df, color="gray", alpha=0.3)

    # Highlight the merged state groups; a group is skipped (with a message)
    # when the data set has fewer states than the hard-coded indices need.
    for indices, color in ((RED_HIGHLIGHT_STATES, "red"),
                           (BLUE_HIGHLIGHT_STATES, "blue")):
        combined = combine_states(states_list, indices)
        if combined is None:
            print("Skipping %s state highlight: states %s not all present (%d states)"
                  % (color, list(indices), len(states_list)))
            continue
        # Column/label 0 is kept from the original script.
        plot_state_km(exp_data.columns, combined, 0,
                      srv, guan_survival_df, color=color, alpha=1)

    # Always save, so the per-state curves are kept even without highlights.
    plt.savefig(os.path.join(args.outdir, "kmplots_states.pdf"), bbox_inches="tight")

    # ------------------------------------------------------------------
    # Per-state risk (GuanScore) data for violin / box plots
    # ------------------------------------------------------------------
    survival_patients = set(guan_survival_df.index)

    if translocations is not None:
        t414_patients = set(translocations.columns[
            np.where(translocations.loc["RNASeq_WHSC1_Call", :] == 1)[0]
        ])
        t1114_patients = set(translocations.columns[
            np.where(translocations.loc["RNASeq_CCND1_Call", :] == 1)[0]
        ])
    else:
        t414_patients = set()
        t1114_patients = set()

    min_patients = 5
    ranks = []
    state_ids = []
    state_samples = []
    percent_t414 = []
    percent_t1114 = []
    for state_key, state in enumerate(states_list):
        state_set = set(state)
        overlap_patients = list(survival_patients & state_set)
        if len(overlap_patients) < min_patients:
            continue
        guan_data = list(guan_survival_df.loc[overlap_patients, "GuanScore"])
        state_ids.append(1 + state_key)  # states are labelled 1-based on the plots
        state_samples.append(overlap_patients)
        ranks.append(np.median(guan_data))

        percent_t414.append(float(len(state_set & t414_patients)) / len(state))
        percent_t1114.append(float(len(state_set & t1114_patients)) / len(state))

    if not ranks:
        print("No state has at least %d patients with survival data; "
              "skipping state risk plots" % min_patients)
        return

    labels_df, rank_order = build_state_label_frame(state_ids, state_samples, ranks)
    plot_data = pd.concat([guan_survival_df.loc[labels_df.index, "GuanScore"], labels_df],
                          axis=1)

    order = np.argsort(ranks)
    ranked_t414 = np.array(percent_t414)[order]
    ranked_t1114 = np.array(percent_t1114)[order]

    # Violin plots by states
    _, ax = plt.subplots(figsize=(12, 2))

    # GuanScore distribution per state, states ordered by median score
    sns.violinplot(x="label", y="GuanScore", data=plot_data, fliersize=0,
                   palette="coolwarm", order=rank_order)

    # Add in points to show each observation
    sns.swarmplot(x="label", y="GuanScore", data=plot_data,
                  size=2, color=".3", linewidth=0, order=rank_order)

    ax.set(ylabel="")
    ax.set(xlabel="")

    plt.savefig(os.path.join(args.outdir, "violin_states_risk.pdf"), bbox_inches="tight")

    # Boxplots by states
    _, ax = plt.subplots(figsize=(12, 2))

    # GuanScore distribution per state, states ordered by median score
    sns.boxplot(x="label", y="GuanScore", data=plot_data, fliersize=0,
                palette="coolwarm", order=rank_order)

    # Add in points to show each observation
    sns.swarmplot(x="label", y="GuanScore", data=plot_data,
                  size=2, color=".3", linewidth=0, order=rank_order)

    ax.set(ylabel="")
    ax.set(xlabel="")
    ax.set(ylim=(-0.1, 1.1))

    plt.savefig(os.path.join(args.outdir, "boxplot_states_risk.pdf"), bbox_inches="tight")

    if translocations is not None:
        # t(4;14) and t(11;14) subtypes by states (stacked percentages)
        plt.figure(figsize=(12, 2))

        num_states = len(ranks)
        positions = np.arange(num_states)    # the x locations for the groups
        bar_width = 0.6
        p1 = plt.bar(positions, 100 * ranked_t1114, width=bar_width,
                     color='#0A6ECC', edgecolor="white", alpha=1)
        p2 = plt.bar(positions, 100 * ranked_t414, bottom=100 * ranked_t1114,
                     width=bar_width, color='#E53939', edgecolor="white", alpha=1)

        plt.xlim(-0.5, num_states - 0.5)
        plt.ylim(-5, 110)
        # Positional arguments: the keyword form is rejected by old pyplot versions.
        plt.xticks(range(len(rank_order)), list(rank_order))
        plt.legend((p1[0], p2[0]), ('t(11;14)', 't(4;14)'), loc="upper left")

        plt.savefig(os.path.join(args.outdir, "barplot_states_translocations.pdf"),
                    bbox_inches="tight")


if __name__ == '__main__':
    main()




