#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import logging
from pkg_resources import Requirement, resource_filename

from miner2 import coexpression, preprocess, mechanistic_inference as mechinf, miner
from miner2 import GIT_SHA
from miner2 import __version__ as MINER_VERSION
from miner2 import util

# Passed as min_number_genes to mechinf.get_regulons() by the script body.
MIN_REGULON_GENES = 5

# Strip the package prefix and the "$Id: ... $" keyword wrapper so the
# --help banner shows only the bare version string and commit hash.
_version_label = str(MINER_VERSION).replace('miner2 ', '')
_git_sha_label = GIT_SHA.replace('$Id: ', '').replace(' $', '')

# Banner shown by argparse as the program description.
DESCRIPTION = """miner-causalinference - MINER causal inference.
MINER Version %s (Git SHA %s)""" % (_version_label, _git_sha_label)


def expand_regulons(exp_data, regulon_modules, regulon_df, eigengenes,
                    overexpressed_members_matrix, args):
    """Expand the regulons and write the expanded network into ``args.outdir``.

    TODO: This can't be done without the network mapping/causal inference
    step. Move it there, because we need that overexpressed_members_matrix.
    Note that in that case, tfbsdb_path must be defined.

    The function runs ``miner.parallelRegulonExpansion``, converts the result
    to a data frame, recomputes eigengenes for the expanded regulons, rescales
    them to the 95th percentile of the expression data and writes three files
    into ``args.outdir``:

    * ``expanded_eigengenes.csv``
    * ``expanded_regulonDf.csv``
    * ``expanded_regulons.json``

    Parameters
    ----------
    exp_data
        Expression matrix; forwarded to ``miner.parallelRegulonExpansion`` and
        ``miner.getEigengenes`` and used as the scaling reference.
    regulon_modules
        Regulon modules forwarded to ``miner.parallelRegulonExpansion``.
    regulon_df
        Regulon data frame; also used to build the regulon-to-regulator map.
    eigengenes
        Eigengenes of the unexpanded regulons.
    overexpressed_members_matrix
        Forwarded unchanged to ``miner.parallelRegulonExpansion``.
    args
        Object exposing an ``outdir`` attribute (the output directory).

    Returns
    -------
    None
        Results are written to disk and a summary is printed.
    """
    t1 = time.time()
    # The TF binding site database ships as package data of isb_miner2.
    tfbsdb_path = resource_filename(Requirement.parse("isb_miner2"),
                                    'miner2/data/{}'.format("tfbsdb_tf_to_genes.pkl"))
    expandedRegulons = miner.parallelRegulonExpansion(eigengenes, regulon_modules, regulon_df,
                                                      exp_data, tfbsdb_path,
                                                      overexpressed_members_matrix,
                                                      corrThreshold=0.25, auc_threshold=0.70,
                                                      numCores=5)

    regulonIDtoRegulator = miner.regulonIdToRegulator(regulon_df)
    expandedRegulonDf = miner.regulonDictToDf(expandedRegulons, regulonIDtoRegulator)
    t2 = time.time()
    print("Completed regulon expansion in {:.2f} minutes".format((t2-t1)/float(60.)))

    expandedEigengenes = miner.getEigengenes(expandedRegulons, exp_data,
                                             regulon_dict=None, saveFolder=None)
    # Rescale so the 95th percentile of the eigengenes matches that of the
    # expression data. The original code referenced the undefined name
    # `expressionData` here, which raised NameError; the parameter is exp_data.
    eigenScale = np.percentile(exp_data, 95) / np.percentile(expandedEigengenes, 95)
    expandedEigengenes = eigenScale * expandedEigengenes
    expandedEigengenes.index = np.array(expandedEigengenes.index).astype(str)

    expandedEigengenes.to_csv(os.path.join(args.outdir, "expanded_eigengenes.csv"))
    expandedRegulonDf.to_csv(os.path.join(args.outdir, "expanded_regulonDf.csv"))
    miner.write_json(expandedRegulons, os.path.join(args.outdir, "expanded_regulons.json"))
    print("Expanded network to {:d} regulons, {:d} regulators, and {:d} co-regulated genes".format(
        len(expandedRegulonDf.Regulon_ID.unique()),
        len(expandedRegulonDf.Regulator.unique()),
        len(expandedRegulonDf.Gene.unique())))


if __name__ == '__main__':
    # Command-line driver: load the expression data and co-regulation modules,
    # derive regulons and their eigengenes, run the causal network analysis
    # for each mutation matrix that was supplied, then write the combined and
    # the filtered causal results into the output directory.
    LOG_FORMAT = '%(asctime)s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG,
                        datefmt='%Y-%m-%d %H:%M:%S \t')

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('coreg', help="coregulationModules.json file from miner-mechinf")
    parser.add_argument('coher', help="coherentMembers.csv file from miner-bcmembers")

    parser.add_argument('--common_mutations', help="common mutations file")
    parser.add_argument('--translocations', help="translocations file")
    parser.add_argument('--cytogenetics', help="cytogenetics file")
    # The flag only toggles do_preprocess_tpm below; the previous help text
    # ("overexpression threshold") described a different option.
    parser.add_argument('--skip_tpm', action="store_true",
                        help="skip the TPM preprocessing of the expression data")
    parser.add_argument('outdir', help="output directory")

    args = parser.parse_args()

    # Fail fast on missing inputs so a bad path is reported before the
    # preprocessing and eigengene computation are started.
    required_inputs = [
        (args.expfile, "expression file not found"),
        (args.mapfile, "identifier mapping file not found"),
        (args.coreg, "coregulation modules file not found"),
        (args.coher, "coherent members file not found"),
    ]
    for input_path, error_message in required_inputs:
        if not os.path.exists(input_path):
            sys.exit(error_message)

    optional_inputs = [
        (args.common_mutations, "common mutations file not found"),
        (args.translocations, "translocations file not found"),
        (args.cytogenetics, "cytogenetics file not found"),
    ]
    for input_path, error_message in optional_inputs:
        if input_path is not None and not os.path.exists(input_path):
            sys.exit(error_message)

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile,
                                           do_preprocess_tpm=(not args.skip_tpm))
    with open(args.coreg) as infile:
        coregulation_modules = json.load(infile)

    regulons = mechinf.get_regulons(coregulation_modules,
                                    min_number_genes=MIN_REGULON_GENES,
                                    freq_threshold=0.333)
    regulon_modules, regulon_df = mechinf.get_regulon_dictionary(regulons)
    coherent_samples_matrix = pd.read_csv(args.coher, index_col=0, header=0)

    eigengenes = miner.getEigengenes(regulon_modules, exp_data,
                                     regulon_dict=None, saveFolder=None)
    # Rescale the eigengenes so their 95th percentile matches that of the
    # expression data.
    eigen_scale = np.percentile(exp_data, 95) / np.percentile(eigengenes, 95)
    eigengenes = eigen_scale * eigengenes
    eigengenes.index = np.array(eigengenes.index).astype(str)

    # Perform causal analysis for each mutation matrix
    result_dir = os.path.join(args.outdir, "causal_analysis")
    if not os.path.isdir(result_dir):
        os.mkdir(result_dir)

    # (label used in the progress message, input file, result sub-folder)
    mutation_inputs = [
        ('common mutations', args.common_mutations, "causal_results_common_mutations"),
        ('translocations', args.translocations, "causal_results_translocations"),
        ('cytogenetics', args.cytogenetics, "causal_results_cytogenetics"),
    ]
    for label, mutation_file, causal_folder in mutation_inputs:
        if mutation_file is None:
            continue
        print('causal network analysis for {}'.format(label))
        mutation_matrix = pd.read_csv(mutation_file, index_col=0, header=0)
        miner.causalNetworkAnalysis(regulon_matrix=regulon_df,
                                    expression_matrix=exp_data,
                                    reference_matrix=eigengenes,
                                    mutation_matrix=mutation_matrix,
                                    resultsDirectory=result_dir,
                                    minRegulons=1,
                                    significance_threshold=0.05,
                                    causalFolder=causal_folder)

    # compile all causal results
    causal_results = miner.readCausalFiles(result_dir)
    causal_results.to_csv(os.path.join(args.outdir, "completeCausalResults.csv"))

    # The wiring diagram is written to disk through `savefile`; the return
    # value is not needed here.
    wire_diagram_out = os.path.join(args.outdir, 'wiring_diagram.csv')
    miner.wiringDiagram(causal_results, regulon_modules,
                        coherent_samples_matrix,
                        include_genes=False,
                        savefile=wire_diagram_out)

    # Generate Filtered Causal Flows
    # p = 0.05 expressed on the -log10 scale used by the result columns.
    min_neg_log10_p = -np.log10(0.05)

    causal_results_regulon_filtered = causal_results[
        causal_results["-log10(p)_Regulon_stratification"] >= min_neg_log10_p]
    causal_results_aligned = causal_results_regulon_filtered[
        causal_results_regulon_filtered.Fraction_of_edges_correctly_aligned >= 0.5]
    causal_results_aligned_correlated = causal_results_aligned[
        causal_results_aligned["RegulatorRegulon_Spearman_p-value"] <= 0.05]
    causal_results_stratified_aligned_correlated = causal_results_aligned_correlated[
        causal_results_aligned_correlated["-log10(p)_MutationRegulatorEdge"] >= min_neg_log10_p]

    # For every causal flow that survives the filters above:
    #   - the regulon is differentially active w.r.t. the mutation,
    #   - the regulator is differentially active w.r.t. the mutation,
    #   - the regulator is significantly correlated with the regulon, and
    #   - the directionality of at least half of the differentially active
    #     targets downstream of the regulator is consistent with the
    #     perturbation from the mutation.
    causal_results_stratified_aligned_correlated.to_csv(os.path.join(args.outdir,
                                                                     "filteredCausalResults.csv"))
