#!/usr/bin/which python3


# Copyright (C) 2015-2019 Jeff Vierstra (jvierstra@altius.org)

from __future__ import print_function, division

from __future__ import absolute_import
import sys, logging, os, os.path, glob, tempfile, shutil
from argparse import ArgumentParser
import multiprocessing as mp

import numpy as np

import scipy.stats
import scipy.optimize

import pysam

from genome_tools import bed, genomic_interval, genomic_interval_set

from footprint_tools.modeling import dispersion
from footprint_tools.stats import bayesian, segment, differential, windowing

# Configure root logging once for the whole script: timestamped messages at
# INFO level and above, with a 12-hour month/day/year timestamp.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %I:%M:%S %p',
    format='%(asctime)s %(message)s',
)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

def parse_options(args):
    """Parse the command-line arguments of the differential test script.

    Args:
        args: List of argument strings (without the program name).

    Returns:
        The ``argparse`` namespace with the attributes ``metadata_file``,
        ``interval_file``, ``processors`` and ``tmpdir``.
    """

    parser = ArgumentParser(description = "Perform a cleavage rate differential test between two groups")

    parser.add_argument("metadata_file", metavar = "metadata_file", type = str,
                        help = "Path to configuration file containing metadata for samples")

    parser.add_argument("interval_file", metavar = "interval_file", type = str,
                        help = "File path to BED file")

    grp_ot = parser.add_argument_group("Other options")

    # cpu_count() - 2 is zero or negative on machines with one or two cores;
    # never default to fewer than one worker.
    default_processors = max(1, mp.cpu_count() - 2)

    grp_ot.add_argument("--processors", metavar = "N", type = int,
                        dest = "processors", default = default_processors,
                        help = "Number of processors to use. Note that value excludes the"
                        " minimum 2 threads that are dedicated to data I/O."
                        " (default: all available processors minus 2, but at least 1)")

    grp_ot.add_argument("--tmpdir", metavar = "", type = str,
                        dest = "tmpdir", default = None,
                        help = "Temporary directory to use"
                        " (default: uses unix 'mkdtemp')")

    return parser.parse_args(args)


def read_func(tabix_files, intervals, queue):
    """Read per-position values for each interval and put them on *queue*.

    For every interval two ``(n_datasets, len(interval))`` float arrays are
    built, one row per TABIX file. Column 1 of each fetched row is used as the
    position, column 3 as the expected value and column 4 as the observed
    value (0-based columns). Positions without a row stay at zero.

    Args:
        tabix_files: Paths of the TABIX files, one per dataset; row order of
            the produced arrays follows this order.
        intervals: Iterable of intervals exposing ``chrom``, ``start``,
            ``end`` and ``len()``.
        queue: Queue receiving one ``(interval, exp, obs)`` tuple per interval.
    """
    # Local import: `time` is not among the file's top-level imports.
    import time

    tabix_handles = [pysam.TabixFile(str(f)) for f in tabix_files]
    n_datasets = len(tabix_handles)

    try:
        # Write to input queue
        for interval in intervals:

            l = len(interval)

            obs = np.zeros((n_datasets, l), dtype = np.float64)
            exp = np.zeros((n_datasets, l), dtype = np.float64)

            for i, tabix in enumerate(tabix_handles):

                try:
                    for row in tabix.fetch(interval.chrom, interval.start, interval.end, parser = pysam.asTuple()):
                        j = int(row[1]) - interval.start
                        # Ignore rows starting outside the interval; a negative
                        # offset would otherwise wrap around to the array's end.
                        if 0 <= j < l:
                            exp[i, j] = np.float64(row[3])
                            obs[i, j] = np.float64(row[4])

                except Exception as err:
                    # Best effort, as before: a dataset that cannot supply this
                    # interval keeps whatever was read so far (zeros elsewhere).
                    # Unlike a bare `except`, this no longer swallows
                    # KeyboardInterrupt or SystemExit.
                    logger.debug("Skipping %s:%d-%d for dataset %d: %s",
                                 interval.chrom, interval.start, interval.end, i, err)

            queue.put( (interval, exp, obs) )

            # Stop memory from getting out of control: wait for the consumers
            # to catch up, sleeping instead of spinning a CPU core.
            while queue.qsize() > 100:
                time.sleep(0.01)

    finally:
        for handle in tabix_handles:
            handle.close()


def median_abs_dev(x):
    """Return ``(median, MAD)`` of *x*.

    MAD is the median of the absolute deviations of *x* from its median.
    """
    center = np.median(x)
    abs_deviation = np.abs(x - center)
    return center, np.median(abs_deviation)

def find_outliers(x):
    """Return a boolean mask marking the outliers of *x*.

    A value is flagged when ``|0.6745 * (value - median) / MAD|`` exceeds 3.5.
    If the MAD is not positive (including NaN), nothing is flagged.
    """
    center, spread = median_abs_dev(x)

    # Guard clause: without a positive spread the score is undefined.
    if not (spread > 0):
        return np.zeros(len(x), dtype=bool)

    score = 0.6745 * (x - center) / spread
    return np.abs(score) > 3.5

def ratio_diff(pra,prb):
    """Return the difference between the modes of two log-scale curves.

    Both inputs are evaluated on the fixed grid ``np.linspace(-6, 6, 100)``;
    the result is the grid value at the maximum of *prb* minus the grid value
    at the maximum of *pra*.

    Args:
        pra: Log-scale values on the 100-point grid for the first group.
        prb: Log-scale values on the 100-point grid for the second group.

    Returns:
        ``mode(prb) - mode(pra)`` in grid units.
    """
    x = np.linspace(-6, 6, 100)

    # exp() is monotonic, so the arg-max can be taken on the log values
    # directly. Exponentiating first underflows to 0.0 for very negative
    # values, which made every entry tie and the mode collapse to index 0.
    mua = x[np.argmax(pra)]
    mub = x[np.argmax(prb)]
    return mub - mua


def process_func(disp_models, l_group_a, l_group_b, queue, outfile):
    """Consume intervals from *queue*, test each position and write results.

    Runs until a ``None`` item is received. For every ``(interval, exp, obs)``
    item, one tab-separated line per position is written to *outfile*:
    chrom, start, end, ``-log(lrt)``, ``-log(stfz)`` and the value returned by
    ``ratio_diff`` for that position.

    Args:
        disp_models: Dispersion models, passed through to
            ``differential.compute_logpmf_values``.
        l_group_a: Number of leading rows of ``exp``/``obs`` belonging to the
            first group; the remaining rows form the second group.
        l_group_b: Size of the second group (not used in this function).
        queue: Joinable queue of work items; ``task_done()`` is called once
            for every item taken, including the ``None`` sentinel.
        outfile: Path of the output file (overwritten).
    """
    # scipy.misc.logsumexp has been removed from SciPy; scipy.special.logsumexp
    # is the supported location. Imported locally because the file's top-level
    # imports only cover scipy.stats and scipy.optimize.
    from scipy.special import logsumexp

    # Grid over which the scores are evaluated; ratio_diff() hard-codes the
    # same grid, so the two must stay in agreement.
    grid_min, grid_max, grid_n = -6, 6, 100

    with open(outfile, 'w') as handle:

        while True:

            data = queue.get()

            # None is the shutdown sentinel
            if data is None:
                queue.task_done()
                break

            (interval, exp, obs) = data

            log2_obs_over_exp = np.log2((obs+1)/(exp+1))

            # filter outliers (evaluated per position, across datasets)
            outliers = np.apply_along_axis(find_outliers, 0, log2_obs_over_exp)
            log2_obs_over_exp[outliers] = np.nan

            # Step 1: Fit prior
            # Positions containing a masked (NaN) value have a NaN variance
            # and are dropped here together with zero-variance positions.
            variance = np.var(log2_obs_over_exp, axis = 0)
            variance = variance[np.isfinite(variance)]
            variance = variance[variance>0]

            nu_0, sig2_0 = scipy.optimize.fmin(differential.invchi2_likelihood, [1, 1], args = (variance,), full_output = False, disp = False)

            logger.info("nu_0 = %f, sig2_0 =  %f", nu_0, sig2_0)

            # Calculate prior
            pr_a = differential.compute_log_prior_t(log2_obs_over_exp[:l_group_a,:], nu_0, sig2_0, grid_min, grid_max, grid_n)
            pr_b = differential.compute_log_prior_t(log2_obs_over_exp[l_group_a:,:], nu_0, sig2_0, grid_min, grid_max, grid_n)
            pr_ab = differential.compute_log_prior_t(log2_obs_over_exp, nu_0, sig2_0, grid_min, grid_max, grid_n)

            # Calculate negative binomial log pmf
            # (indexed below with the dataset on the last axis)
            nb = differential.compute_logpmf_values(disp_models, obs, exp, grid_min, grid_max, grid_n)

            # pseudo-integration over 'depletion' scores
            pa = pr_a[:,:, np.newaxis] + nb[:,:,:l_group_a]
            pb = pr_b[:,:, np.newaxis] + nb[:,:,l_group_a:]
            pab = pr_ab[:,:, np.newaxis] + nb

            # log-sum-exp over axis 0, then sum over the dataset axis
            La = np.sum(logsumexp(pa, axis = 0), axis = 1)
            Lb = np.sum(logsumexp(pb, axis = 0), axis = 1)
            Lab = np.sum(logsumexp(pab, axis = 0), axis = 1)

            llr = La + Lb - Lab
            lrt = scipy.stats.chi2.sf(2*llr, df = 3)
            # Keep values strictly below 1 before windowing
            lrt[lrt==1.0]=(1.0-1e-6)

            stfz = windowing.stouffers_z(lrt, 3)

            n_positions = len(interval)

            r = np.array([ratio_diff(pr_a[:,i], pr_b[:,i]) for i in range(n_positions)])

            for i in range(n_positions):
                print("{}\t{:d}\t{:d}\t{:f}\t{:f}\t{:f}".format(interval.chrom, interval.start+i, interval.start+i+1, -np.log(lrt[i]), -np.log(stfz[i]), r[i]), file = handle)

            queue.task_done()

def read_metadata_file(filename):
    """Parse the tab-delimited sample metadata file.

    Every data line must have at least four tab-separated fields, in this
    order: group label, sample id, dispersion model file, TABIX file.
    Lines starting with ``#`` and blank lines are skipped.

    Args:
        filename: Path of the metadata file.

    Returns:
        Tuple ``(groups, ids, dm_files, tabix_files)`` of numpy arrays, one
        entry per data line, in file order.

    Raises:
        IndexError: If a data line has fewer than four fields.
    """

    ids = []
    groups = []
    dm_files = []
    tabix_files = []

    with open(filename) as f:
        for line_number, line in enumerate(f, 1):

            if line.startswith("#"):
                continue

            stripped = line.strip()

            # Blank (or whitespace-only) lines carry no sample
            if not stripped:
                continue

            fields = stripped.split("\t")

            # IndexError is what indexing a short line raised before; the type
            # is kept and a message locating the offending line is added.
            if len(fields) < 4:
                raise IndexError("%s, line %d: expected at least 4 tab-separated fields, found %d"
                                 % (filename, line_number, len(fields)))

            groups.append(fields[0])
            ids.append(fields[1])
            dm_files.append(fields[2])
            tabix_files.append(fields[3])

    return (np.array(groups), np.array(ids), np.array(dm_files), np.array(tabix_files))

def main(argv = sys.argv[1:]):
    """Run the differential test from the command line.

    Reads the metadata and interval files, starts two reader processes (one
    per half of the intervals) and the requested number of worker processes,
    then concatenates the per-worker chunk files to standard output.

    Args:
        argv: Command-line arguments (without the program name).

    Returns:
        0 on success. The process exits with status 1 when an input file is
        missing or a sample's group label is neither "A" nor "B".
    """

    args = parse_options(argv)

    # Without at least one worker nothing drains the queue and the q.join()
    # below would block forever.
    n_processors = max(1, args.processors)

    (groups, ids, disp_model_files, tabix_files) = read_metadata_file(args.metadata_file)

    disp_models = []

    # Load and parse input files
    for tabix_file, disp_model_file in zip(tabix_files, disp_model_files):

        # Check to make sure that all input files exist
        if not os.path.exists(tabix_file):
            print("Fatal error: TABIX-file %s does not exists!" % tabix_file, file = sys.stderr)
            sys.exit(1)

        # Load dispersion model
        if not os.path.exists(disp_model_file):
            print("Fatal error: Dispersion model file %s does not exists!" % disp_model_file, file = sys.stderr)
            sys.exit(1)

        disp_models.append(dispersion.read_dispersion_model(disp_model_file))

    disp_models = np.array(disp_models)

    # order data for group analysis (group "A" rows first, then group "B")
    o = np.argsort(groups)

    l_group_a = np.sum(groups == "A")
    l_group_b = np.sum(groups == "B")

    # Explicit check instead of `assert`, which is stripped under `python -O`
    if len(groups) != l_group_a + l_group_b:
        print("Fatal error: every sample in %s must belong to group 'A' or 'B'!" % args.metadata_file, file = sys.stderr)
        sys.exit(1)

    # Load intervals file; length and halves are taken while the file is still
    # open so the handle can be closed deterministically afterwards.
    with open(args.interval_file) as interval_handle:
        intervals = genomic_interval_set(bed.bed3_iterator(interval_handle))
        n_intervals = len(intervals)
        h = n_intervals // 2
        interval_halves = (intervals[:h], intervals[h:])

    tmpdir = args.tmpdir if args.tmpdir else tempfile.mkdtemp()

    try:
        chunk_files = [os.path.join(tmpdir, "chunk%s" % i) for i in range(n_processors)]

        q = mp.JoinableQueue()
        readers = [ mp.Process(target = read_func, args = (tabix_files[o], half, q)) for half in interval_halves ]
        processors = [ mp.Process(target = process_func, args = (disp_models[o], l_group_a, l_group_b, q, f)) for f in chunk_files ]

        for reader in readers:
            reader.start()
        for processor in processors:
            processor.start()

        logger.info("Working (%d threads; chunked results in %s)", len(chunk_files), tmpdir)

        for reader in readers:
            reader.join()

        logger.info("Finishing reading data -- waiting for final processing...")

        q.join() # Wait for queue to unblock

        # One None per worker tells each processing loop to stop
        for _ in processors:
            q.put(None)

        # wait for workers to return
        for processor in processors:
            processor.join()

        logger.info("Merging data...")

        for chunk_file in chunk_files:
            with open(chunk_file, 'r') as handle:
                shutil.copyfileobj(handle, sys.stdout)

        logger.info("Cleaning up...")

    finally:
        # Only remove the directory if it was created here, and do so even
        # when something above failed.
        if not args.tmpdir:
            shutil.rmtree(tmpdir)

    return 0

# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())

