#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import time

import matplotlib
matplotlib.use('Agg')
import logging
from pkg_resources import Requirement, resource_filename

from miner2 import preprocess, util, mechanistic_inference, GIT_SHA, subtypes
from miner2 import __version__ as MINER_VERSION
from miner2 import miner


# Help/description banner shown by argparse. The version string as exported by
# miner2 carries a "miner2 " prefix and GIT_SHA is an RCS-style "$Id: ... $"
# keyword; both are stripped down to the bare version / SHA for display.
DESCRIPTION = """miner-mechinf - MINER compute mechanistic inference
MINER Version %s (Git SHA %s)""" % (str(MINER_VERSION).replace('miner2 ', ''),
                                    GIT_SHA.replace('$Id: ', '').replace(' $', ''))

# Number of worker processes used by the enrichment step.
NUM_CORES = 5
# Minimum number of genes a regulon must contain to be kept.
MIN_REGULON_GENES = 5

if __name__ == '__main__':
    # Command-line driver: reads an expression matrix, an identifier mapping
    # and the coexpression dictionary produced by miner-coexpr, then runs
    # mechanistic inference and writes the resulting regulons/modules (plus
    # identifier-annotated variants and eigengenes) into the output directory.
    LOG_FORMAT = '%(asctime)s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG,
                        datefmt='%Y-%m-%d %H:%M:%S \t')

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('coexprdict', help="coexpressionDictionary.json file from miner-coexpr")
    parser.add_argument('outdir', help="output directory")
    parser.add_argument('-mc', '--mincorr', type=float, default=0.2,
                        help="minimum correlation")
    # BUG FIX: the help text previously said "overexpression threshold",
    # which described a different option; this flag skips TPM preprocessing
    # (see the do_preprocess_tpm argument passed to preprocess.main below).
    parser.add_argument('--skip_tpm', action="store_true",
                        help="skip TPM preprocessing of the expression matrix")

    args = parser.parse_args()

    # Validate inputs up front so we fail before any expensive work.
    if not os.path.exists(args.expfile):
        sys.exit("expression file not found")
    if not os.path.exists(args.mapfile):
        sys.exit("identifier mapping file not found")
    if not os.path.exists(args.coexprdict):
        sys.exit("revised clusters file not found")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Record tool/dependency versions for reproducibility.
    with open(os.path.join(args.outdir, 'run_info.txt'), 'w') as outfile:
        util.write_dependency_infos(outfile)

    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile,
                                           do_preprocess_tpm=(not args.skip_tpm))

    with open(args.coexprdict) as infile:
        revised_clusters = json.load(infile)

    # get first principal component axes of clusters
    t1 = time.time()
    # NOTE(review): pkg_resources is deprecated upstream; kept here for
    # compatibility with the rest of the package.
    database_path = resource_filename(Requirement.parse("isb_miner2"),
                                      'miner2/data/{}'.format("tfbsdb_tf_to_genes.pkl"))

    axes = mechanistic_inference.get_principal_df(revised_clusters, exp_data,
                                                  subkey=None, min_number_genes=1)

    # analyze revised clusters for enrichment in relational database
    # (default: transcription factor binding site database)
    mechanistic_output = mechanistic_inference.enrichment(axes, revised_clusters, exp_data,
                                                          correlation_threshold=args.mincorr,
                                                          num_cores=NUM_CORES,
                                                          database_path=database_path)

    # write mechanistic output to .json file
    with open(os.path.join(args.outdir, "mechanisticOutput.json"), 'w') as outfile:
        json.dump(mechanistic_output, outfile)

    # order mechanisticOutput as {tf:{coexpressionModule:genes}}
    coregulation_modules = mechanistic_inference.get_coregulation_modules(mechanistic_output)

    # write coregulation modules to .json file
    with open(os.path.join(args.outdir, "coregulationModules.json"), 'w') as outfile:
        json.dump(coregulation_modules, outfile)

    # get final regulons by keeping genes that frequently appear coexpressed
    # and associated to a common regulator
    regulons = mechanistic_inference.get_regulons(coregulation_modules,
                                                  min_number_genes=MIN_REGULON_GENES,
                                                  freq_threshold=0.333)

    # reformat regulon dictionary for consistency with revisedClusters and coexpressionModules
    regulon_modules, regulon_df = mechanistic_inference.get_regulon_dictionary(regulons)

    # write regulons to json file
    with open(os.path.join(args.outdir, "regulons.json"), 'w') as outfile:
        json.dump(regulon_modules, outfile)
    regulon_df.to_csv(os.path.join(args.outdir, "regulonDf.csv"))

    # define coexpression modules as composite of coexpressed regulons
    coexpression_modules = mechanistic_inference.get_coexpression_modules(mechanistic_output)

    # write coexpression modules to .json file
    with open(os.path.join(args.outdir, "coexpressionModules.json"), 'w') as outfile:
        json.dump(coexpression_modules, outfile)

    # reconvert revised clusters to original gene annotations
    annotated_revised_clusters = mechanistic_inference.convert_dictionary(revised_clusters,
                                                                          conv_table)

    # write annotated coexpression clusters to .json file
    # BUG FIX: previously dumped the unannotated revised_clusters, so the
    # "_annotated" file never contained the converted gene identifiers.
    with open(os.path.join(args.outdir, "coexpressionDictionary_annotated.json"), 'w') as outfile:
        json.dump(annotated_revised_clusters, outfile)

    # reconvert results into original annotations
    regulon_annotated_df = mechanistic_inference.convert_regulons(regulon_df, conv_table)

    # write annotated regulon table to .csv
    regulon_annotated_df.to_csv(os.path.join(args.outdir, "regulons_annotated.csv"))

    # reconvert regulons
    annotated_regulons = mechanistic_inference.convert_dictionary(regulon_modules, conv_table)

    # write annotated regulons to .json file
    # BUG FIX: previously dumped the raw `regulons` object instead of the
    # converted `annotated_regulons`, leaving annotated_regulons unused.
    with open(os.path.join(args.outdir, "regulons_annotated.json"), 'w') as outfile:
        json.dump(annotated_regulons, outfile)

    # reconvert coexpression modules
    annotated_coexpression_modules = mechanistic_inference.convert_dictionary(coexpression_modules,
                                                                              conv_table)

    # write annotated coexpression modules to .json file
    with open(os.path.join(args.outdir, "coexpressionModules_annotated.json"), 'w') as outfile:
        json.dump(annotated_coexpression_modules, outfile)

    # Get eigengenes for all modules, rescaled so their 95th percentile
    # matches the expression data's 95th percentile.
    eigengenes = subtypes.get_eigengenes(regulon_modules, exp_data, regulon_dict=None,
                                         saveFolder=None)
    eigen_scale = np.percentile(exp_data, 95) / np.percentile(eigengenes, 95)
    eigengenes = eigen_scale * eigengenes
    eigengenes.index = np.array(eigengenes.index).astype(str)

    # write eigengenes to .csv
    eigengenes.to_csv(os.path.join(args.outdir, "eigengenes.csv"))

    t2 = time.time()
    logging.info("Completed mechanistic inference in {:.2f} minutes".format((t2 - t1) / 60.))
    logging.info("Inferred network with {:d} regulons, {:d} regulators, and {:d} co-regulated genes".format(
        len(regulon_df.Regulon_ID.unique()),
        len(regulon_df.Regulator.unique()),
        len(regulon_df.Gene.unique())))