#!/usr/bin/which python3


# Copyright (C) 2015-2019 Jeff Vierstra (jvierstra@altius.org)
import sys, logging, os, os.path, glob, tempfile, shutil
from argparse import ArgumentParser
import multiprocessing as mp

import numpy as np
import scipy.stats
import pysam

from genome_tools import bed, genomic_interval, genomic_interval_set

from footprint_tools.modeling import dispersion
from footprint_tools.stats import bayesian, segment

# Configure root logging once at import time: INFO level, each record
# rendered as "<timestamp> <message>" with a 12-hour clock timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
)

# Module-level logger used for progress messages in this script.
logger = logging.getLogger(__name__)

def parse_options(args):
    """Parse the command-line arguments of this script.

    Parameters
    ----------
    args : list of str
        Argument vector, without the program name.

    Returns
    -------
    argparse.Namespace
        Attributes: ``metadata_file``, ``interval_file``, ``fdr_cutoff``,
        ``post_cutoff`` and ``processors``.

    Raises
    ------
    SystemExit
        Raised by argparse on invalid arguments, including a
        ``--processors`` value smaller than 1.
    """

    parser = ArgumentParser(description = "Compute the posterior probability of cleavage data")

    parser.add_argument("metadata_file", metavar = "metadata_file", type = str,
                        help = "Path to configuration file containing metadata for samples")

    parser.add_argument("interval_file", metavar = "interval_file", type = str, 
                        help = "File path to BED file")

    grp_st = parser.add_argument_group("Statistical options")

    grp_st.add_argument("--fdr-cutoff", metavar = "N", type = float,
                        dest = "fdr_cutoff", default = 0.05,
                        help = "Only consider nucleotides with FDR <= this value."
                        " (default: %(default)s)")

    grp_o = parser.add_argument_group("Output options")

    grp_o.add_argument("--post-cutoff", metavar = "N", type = float,
                        dest = "post_cutoff", default = 0.2,
                        help = "Only output nucleotides with posterior probability <= this value."
                        " (default: %(default)s)")

    grp_ot = parser.add_argument_group("Other options")

    # cpu_count() - 2 is zero or negative on machines with two or fewer
    # cores; with no worker processes nothing would ever drain the queue,
    # so always keep at least one.
    default_processors = max(1, mp.cpu_count() - 2)

    grp_ot.add_argument("--processors", metavar = "N", type = int,
                        dest = "processors", default = default_processors,
                        help = "Number of processors to use. Note that value excludes the"
                        " minimum 2 threads that are dedicated to data I/O."
                        " (default: all available processors)")

    parsed = parser.parse_args(args)

    # Reject explicit values that would start no worker at all.
    if parsed.processors < 1:
        parser.error("--processors must be >= 1")

    return parsed


def read_func(tabix_files, intervals, queue):
    """Read tabix rows for each interval and enqueue one work item per interval.

    For every interval, four ``(n_datasets, len(interval))`` float64 arrays
    are built, one row per tabix file, and the tuple
    ``(interval, exp, obs, fdr, w)`` is put on *queue*.

    Tabix columns used (0-based): column 1 is the position, column 3 fills
    ``exp``, column 4 fills ``obs`` and column 7 fills ``fdr``. ``w`` is set
    to 1 wherever a row was read. Positions without a row keep their
    defaults (``exp`` = ``obs`` = ``w`` = 0, ``fdr`` = 1).

    Parameters
    ----------
    tabix_files : list of str
        Paths opened with ``pysam.TabixFile``.
    intervals : iterable
        Objects providing ``chrom``, ``start``, ``end`` and ``len()``.
    queue : queue-like
        Must provide ``put`` and ``qsize``.
    """
    import time  # local import: only needed for the back-off below

    tabix_handles = [pysam.TabixFile(f) for f in tabix_files]
    n_datasets = len(tabix_handles)

    try:
        # Write to input queue
        for interval in intervals:

            l = len(interval)

            obs = np.zeros((n_datasets, l), dtype = np.float64)
            exp = np.zeros((n_datasets, l), dtype = np.float64)
            fdr = np.ones((n_datasets, l), dtype = np.float64)
            w = np.zeros((n_datasets, l), dtype = np.float64)

            for i, tabix in enumerate(tabix_handles):

                try:
                    for row in tabix.fetch(interval.chrom, interval.start, interval.end, parser = pysam.asTuple()):
                        j = int(row[1]) - interval.start

                        # Skip rows that start outside the interval; a
                        # negative offset would otherwise silently wrap
                        # around to the end of the array.
                        if not 0 <= j < l:
                            continue

                        exp[i, j] = np.float64(row[3])
                        obs[i, j] = np.float64(row[4])
                        fdr[i, j] = np.float64(row[7])
                        w[i, j] = 1
                except (ValueError, IndexError):
                    # Best effort: a dataset that cannot be fetched for this
                    # region, or has an unparsable/short row, keeps default
                    # values for the remaining positions. Unlike a bare
                    # ``except``, this lets KeyboardInterrupt/SystemExit
                    # propagate.
                    pass

            queue.put( (interval, exp, obs, fdr, w) )

            # Stop memory from getting out of control: wait for consumers to
            # catch up, sleeping instead of spinning at 100% CPU.
            while queue.qsize() > 100:
                time.sleep(0.01)
    finally:
        # Close the handles even if reading failed part-way through.
        for handle in tabix_handles:
            handle.close()


def process_func(disp_models, beta_priors, queue, outfile, fdr_cutoff, post_cutoff):
    """Consume work items from *queue* and write per-nucleotide scores to *outfile*.

    Runs until a ``None`` item is received. Each other item is a tuple
    ``(interval, exp, obs, fdr, w)``. The negated result of
    ``bayesian.posterior`` is clamped at 0, and a tab-separated line
    ``chrom, start, start+1, <one value per dataset>`` is written for every
    column whose NaN-ignoring maximum across datasets exceeds *post_cutoff*.

    Parameters
    ----------
    disp_models : list
        Dispersion models passed to ``bayesian.log_likelihood``.
    beta_priors : numpy.ndarray
        Passed to ``bayesian.compute_delta_prior``.
    queue : queue-like
        Must provide ``get`` and ``task_done``.
    outfile : str
        Path of the file to (over)write.
    fdr_cutoff : float
        Passed as ``cutoff`` to the prior computations.
    post_cutoff : float
        Threshold on the negated posterior; ``main`` in this file passes
        ``-log`` of the user-supplied probability.
    """

    # The context manager closes the output file even if processing raises.
    with open(outfile, 'w') as handle:

        while True:

            data = queue.get()

            try:
                # ``None`` is the shutdown message.
                if data is None:
                    break

                (interval, exp, obs, fdr, w) = data

                prior = bayesian.compute_prior_weighted(fdr, w, cutoff = fdr_cutoff)
                scale = bayesian.compute_delta_prior(obs, exp, fdr, beta_priors, cutoff = fdr_cutoff)

                ll_on = bayesian.log_likelihood(obs, exp, disp_models, delta = scale, w = 3)
                ll_off = bayesian.log_likelihood(obs, exp, disp_models, w = 3)

                # Compute posterior
                post = -bayesian.posterior(prior, ll_on, ll_off)
                post[post <= 0] = 0.0

                # Write output
                ii = np.where(np.nanmax(post, axis = 0) > post_cutoff)[0]

                for i in ii:
                    print("{}\t{:d}\t{:d}\t".format(interval.chrom, interval.start+i, interval.start+i+1) + '\t'.join(map(str, post[:,i])), file = handle)
            finally:
                # Always acknowledge the item (including the shutdown
                # message) so the queue's unfinished-task count stays
                # correct even when processing an item raises.
                queue.task_done()

def read_metadata_file(filename):
    """Read the tab-delimited sample metadata file.

    Lines starting with ``#`` and blank lines are skipped. Every other line
    must contain exactly four tab-separated fields: sample id, dispersion
    model file, tabix file and beta-prior file. A beta-prior field of ``!``
    is stored as ``None``.

    Parameters
    ----------
    filename : str
        Path to the metadata file.

    Returns
    -------
    tuple of lists
        ``(ids, dm_files, tabix_files, bp_files)``, in file order.

    Raises
    ------
    ValueError
        If a data line does not have exactly four fields.
    """

    ids = []
    dm_files = []
    tabix_files = []
    bp_files = []

    with open(filename) as f:
        for lineno, line in enumerate(f, start = 1):

            if line.startswith("#"):
                continue

            fields = line.strip().split("\t")

            # A blank (or whitespace-only) line splits into a single empty
            # field; skip it instead of failing.
            if fields == [""]:
                continue

            if len(fields) != 4:
                raise ValueError(
                    "%s, line %d: expected 4 tab-separated fields, found %d"
                    % (filename, lineno, len(fields)))

            (i, d, t, b) = fields

            ids.append(i)
            dm_files.append(d)
            tabix_files.append(t)
            bp_files.append(b if b != "!" else None)

    return (ids, dm_files, tabix_files, bp_files)

def main(argv = sys.argv[1:]):
    """Run the pipeline: load inputs, fan work out to processes, merge results.

    Steps:
      1. Parse options and the metadata file; exit with status 1 if a
         listed input file is missing.
      2. Load one dispersion model per dataset and, where given, the two
         beta-prior parameters from the first line of the prior file
         (datasets without a prior file keep ``[1, 1]``).
      3. Split the intervals in half between two reader processes, which
         feed a shared queue consumed by the worker processes; each worker
         writes to its own chunk file in a temporary directory.
      4. Concatenate the chunk files to stdout and remove the directory.

    Parameters
    ----------
    argv : list of str
        Command-line arguments without the program name.

    Returns
    -------
    int
        0 on success.
    """

    args = parse_options(argv)

    (ids, disp_model_files, tabix_files, beta_prior_files) = read_metadata_file(args.metadata_file)

    n_datasets = len(ids)

    disp_models = []
    beta_priors = np.ones((n_datasets, 2))

    # Load and parse input files
    for i in range(n_datasets):

        # Check to make sure that all input files exist
        if not os.path.exists(tabix_files[i]):
            print("Fatal error: TABIX-file %s does not exists!" % tabix_files[i], file = sys.stderr)
            sys.exit(1)

        # Load dispersion model
        if not os.path.exists(disp_model_files[i]):
            print("Fatal error: Dispersion model file %s does not exists!" % disp_model_files[i], file = sys.stderr)
            sys.exit(1)
        disp_models.append(dispersion.read_dispersion_model(disp_model_files[i]))

        # Load the priors file
        if beta_prior_files[i]:
            if not os.path.exists(beta_prior_files[i]):
                print("Fatal error: Beta-prior file %s does not exists!" % beta_prior_files[i], file = sys.stderr)
                sys.exit(1)
            with open(beta_prior_files[i], 'r') as f:
                params = f.readline().strip().split('\t')
            beta_priors[i,:] = np.array(params, dtype = np.float64)

    # Load intervals file (closed as soon as the set has been built)
    # NOTE(review): assumes genomic_interval_set consumes the iterator
    # eagerly in its constructor -- confirm against genome_tools.
    with open(args.interval_file) as interval_handle:
        intervals = genomic_interval_set(bed.bed3_iterator(interval_handle))
    n_intervals = len(intervals)
    h = n_intervals // 2

    # With no worker process nothing would drain the queue and the readers
    # would wait forever, so always run at least one.
    n_workers = max(1, args.processors)

    tmpdir = tempfile.mkdtemp()
    readers = []
    processors = []

    try:
        chunk_files = [os.path.join(tmpdir, "chunk%s" % i) for i in range(n_workers)]

        q = mp.JoinableQueue()
        readers = [ mp.Process(target = read_func, args = (tabix_files, intervals[:h], q)), mp.Process(target = read_func, args = (tabix_files, intervals[h:], q)) ]
        processors = [ mp.Process(target = process_func, args = (disp_models, beta_priors, q, f, args.fdr_cutoff, -np.log(args.post_cutoff))) for f in chunk_files ]

        for reader in readers:
            reader.start()
        for processor in processors:
            processor.start()

        logger.info("Working (%d threads; chunked results in %s)", len(chunk_files), tmpdir)

        for reader in readers:
            reader.join()

        logger.info("Finishing reading data -- waiting for final processing...")

        q.join() # Wait until every queued item has been acknowledged
        for _ in chunk_files:
            q.put(None) # one shutdown message per worker
        for processor in processors:
            processor.join() # wait for workers to return

        logger.info("Merging data...")

        for chunk_file in chunk_files:
            with open(chunk_file, 'r') as handle:
                for line in handle:
                    sys.stdout.write(line)

    finally:
        # On the normal path every child has already been joined. After an
        # error, stop any child that is still running so the program can
        # exit, then remove the temporary directory in all cases.
        for proc in readers + processors:
            if proc.is_alive():
                proc.terminate()

        logger.info("Cleaning up...")

        shutil.rmtree(tmpdir)

    return 0

# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())

