#!/usr/bin/env python

import argparse
import pandas as pd
import numpy as np
import json
import sys
import os
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import logging

from miner2 import coexpression
from miner2 import preprocess
from miner2 import GIT_SHA
from miner2 import __version__ as MINER_VERSION
from miner2 import util

# Help text handed to argparse: a one-line tool summary followed by the MINER
# version and Git SHA. The 'miner2 ' text is removed from the version string
# and the '$Id: ' / ' $' keyword markers are removed from the SHA string.
_VERSION_TEXT = str(MINER_VERSION).replace('miner2 ', '')
_GIT_SHA_TEXT = GIT_SHA.replace('$Id: ', '').replace(' $', '')
DESCRIPTION = """miner-coexpr - MINER cluster expression data.
MINER Version %s (Git SHA %s)""" % (_VERSION_TEXT, _GIT_SHA_TEXT)


def plot_expression_stats(exp_data, outdir, max_samples=50):
    """Save three diagnostic plots of the expression matrix as PDF files.

    The following files are written into ``outdir``:

    * ``patient_expression_profiles.pdf`` - box plots of the first
      ``max_samples`` columns of ``exp_data``
    * ``expression_single_gene.pdf`` - histogram of the first row
    * ``expression_single_patient.pdf`` - histogram of the first column

    :param exp_data: expression matrix (pandas DataFrame, accessed through
        ``.iloc``); the plot labels treat rows as genes and columns as
        patient samples
    :param outdir: directory the PDF files are written to
    :param max_samples: upper bound on the number of columns shown in the
        box plot; defaults to 50, the previously hard-coded value
    """
    # Clamp to the columns that actually exist: indexing 50 columns
    # unconditionally raises IndexError on matrices with fewer samples.
    num_samples = min(max_samples, exp_data.shape[1])

    # NOTE: the keyword must be spelled ``fontsize``. The mixed-case
    # ``FontSize`` spelling is rejected by current matplotlib releases.
    fig = plt.figure()
    ind_exp_data = [exp_data.iloc[:, i] for i in range(num_samples)]
    plt.boxplot(ind_exp_data)
    plt.title("Patient expression profiles", fontsize=14)
    plt.ylabel("Relative expression", fontsize=14)
    plt.xlabel("Sample ID", fontsize=14)
    plt.savefig(os.path.join(outdir, "patient_expression_profiles.pdf"),
                bbox_inches="tight")
    # close each figure once it is saved so pyplot does not keep it alive
    plt.close(fig)

    fig = plt.figure()
    plt.hist(exp_data.iloc[0, :], bins=100, alpha=0.75)
    plt.title("Expression of single gene", fontsize=14)
    plt.ylabel("Frequency", fontsize=14)
    plt.xlabel("Relative expression", fontsize=14)
    plt.savefig(os.path.join(outdir, "expression_single_gene.pdf"),
                bbox_inches="tight")
    plt.close(fig)

    fig = plt.figure()
    plt.hist(exp_data.iloc[:, 0], bins=200, color=[0, 0.4, 0.8], alpha=0.75)
    plt.ylim(0, 350)
    plt.title("Expression of single patient sample", fontsize=14)
    plt.ylabel("Frequency", fontsize=14)
    plt.xlabel("Relative expression", fontsize=14)
    plt.savefig(os.path.join(outdir, "expression_single_patient.pdf"),
                bbox_inches="tight")
    plt.close(fig)


if __name__ == '__main__':
    # Command line entry point. Steps:
    #   1. parse arguments and validate the input files
    #   2. preprocess the expression matrix and plot basic statistics
    #   3. compute and revise the coexpression clusters
    #   4. write the clusters as JSON and save overview heat maps
    LOG_FORMAT = '%(asctime)s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG,
                        datefmt='%Y-%m-%d %H:%M:%S \t')

    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=DESCRIPTION)
    parser.add_argument('expfile', help="input matrix")
    parser.add_argument('mapfile', help="identifier mapping file")
    parser.add_argument('outdir', help="output directory")
    parser.add_argument('-mg', '--mingenes', type=int, default=6, help="min number genes")
    parser.add_argument('-moxs', '--minoverexpsamp', type=int, default=4,
                        help="minimum overexpression samples")
    parser.add_argument('-mx', '--maxexclusion', type=float, default=0.5,
                        help="maximum samples excluded")
    # The seed is parsed as int so that a user supplied value has the same
    # type as the integer default (it used to be parsed as float).
    parser.add_argument('-rs', '--randstate', type=int, default=12,
                        help="random state")
    parser.add_argument('-oxt', '--overexpthresh', type=int, default=80,
                        help="overexpression threshold")
    parser.add_argument('--skip_tpm', action="store_true",
                        help="skip the TPM preprocessing step")

    args = parser.parse_args()
    if not os.path.exists(args.expfile):
        sys.exit("expression file not found")
    if not os.path.exists(args.mapfile):
        sys.exit("identifier mapping file not found")

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # record the dependency information of this run next to the results
    with open(os.path.join(args.outdir, 'run_info.txt'), 'w') as outfile:
        util.write_dependency_infos(outfile)

    exp_data, conv_table = preprocess.main(args.expfile, args.mapfile,
                                           do_preprocess_tpm=(not args.skip_tpm))
    plot_expression_stats(exp_data, args.outdir)

    # the reported run time covers clustering and plotting, not preprocessing
    t1 = time.time()
    init_clusters = coexpression.cluster(exp_data,
                                         min_number_genes=args.mingenes,
                                         min_number_overexp_samples=args.minoverexpsamp,
                                         max_samples_excluded=args.maxexclusion,
                                         random_state=args.randstate,
                                         overexpression_threshold=args.overexpthresh)

    revised_clusters = coexpression.revise_initial_clusters(init_clusters, exp_data)
    with open(os.path.join(args.outdir, "coexpressionDictionary.json"), 'w') as out:
        json.dump(revised_clusters, out)

    # Retrieve up to the first three clusters for visual inspection. The
    # clusters are looked up by the string keys "0", "1", "2"; keys that do
    # not exist (fewer than three clusters) are skipped instead of raising
    # a KeyError.
    first_keys = [str(i) for i in range(3) if str(i) in revised_clusters]

    if first_keys:
        first_clusters = np.hstack([revised_clusters[key] for key in first_keys])

        # visualize background expression: a random selection of as many
        # genes as there are in the selected clusters (limited to the number
        # of available genes, since sampling is done without replacement)
        num_background = min(len(first_clusters), len(exp_data.index))
        background_genes = np.random.choice(exp_data.index, num_background,
                                            replace=False)
        plt.figure(figsize=(8, 4))
        plt.imshow(exp_data.loc[background_genes, :],
                   aspect="auto", cmap="viridis", vmin=-1, vmax=1)
        plt.grid(False)
        # NOTE: the keyword must be spelled ``fontsize``. The mixed-case
        # ``FontSize`` spelling is rejected by current matplotlib releases.
        plt.ylabel("Genes", fontsize=20)
        plt.xlabel("Samples", fontsize=20)
        plt.title("Random selection of genes", fontsize=20)
        plt.savefig(os.path.join(args.outdir, "background_expression.pdf"),
                    bbox_inches="tight")

        # visualize the selected first clusters
        plt.figure(figsize=(8, 4))
        plt.imshow(exp_data.loc[first_clusters, :], aspect="auto", cmap="viridis",
                   vmin=-1, vmax=1)
        plt.grid(False)
        plt.ylabel("Genes", fontsize=20)
        plt.xlabel("Samples", fontsize=20)
        plt.title("First {:d} clusters".format(len(first_keys)), fontsize=20)
        plt.savefig(os.path.join(args.outdir, "first_clusters.pdf"),
                    bbox_inches="tight")
    else:
        logging.warning("No clusters available, skipping cluster heat maps")

    # report coverage; the set is built by iteration so that an empty
    # clustering result is reported as 0 instead of failing in np.hstack
    clustered_genes = set()
    for cluster in init_clusters:
        clustered_genes.update(cluster)
    logging.info("Number of genes clustered: {:d}".format(len(clustered_genes)))
    logging.info("Number of unique clusters: {:d}".format(len(revised_clusters)))

    t2 = time.time()
    logging.info("Completed clustering module in {:.2f} minutes".format((t2-t1)/60.))
