#!/usr/bin/env python3

# Copyright (C) 2015-2021 Jeff Vierstra (jvierstra@altius.org)

import sys, logging, os, os.path
from argparse import ArgumentParser
import multiprocessing as mp

import numpy as np
import pandas as pd

import scipy.stats
import pysam

from genome_tools import bed, genomic_interval, genomic_interval_set

import footprint_tools
from footprint_tools.modeling import dispersion
from footprint_tools.stats import bayesian, segment

# Configure the root logger once at import time: timestamped messages at INFO
# level and above. `logger` is the module-level logger used below.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_options(args):
    """
    Build the command-line interface and parse ``args``.

    The parser takes two positional arguments (the sample metadata file and
    a BED interval file) plus optional statistical, output and processing
    settings.

    :param args: list of command-line tokens (without the program name)
    :return: the namespace returned by ``ArgumentParser.parse_args``
    """

    description = "Compute the posterior probability of cleavage data (version {})".format(
        footprint_tools.__version__)
    epilog = ("Written by Jeff Vierstra (jvierstra@altius.org). "
              "See http://github.com/jvierstra/footprint-tools for extended documentation.")

    # Suffix shared by the options whose help text reports the default value
    default_note = " (default: %(default)s)"

    parser = ArgumentParser(description=description, epilog=epilog)

    # Positional arguments
    parser.add_argument(
        "metadata_file",
        metavar="metadata_file",
        type=str,
        help="Path to configuration file containing metadata for samples",
    )
    parser.add_argument(
        "interval_file",
        metavar="interval_file",
        type=str,
        help="File path to BED file",
    )

    # Statistical options
    stats_group = parser.add_argument_group("Statistical options")
    stats_group.add_argument(
        "--fdr-cutoff",
        metavar="N",
        type=float,
        dest="fdr_cutoff",
        default=0.05,
        help="Only consider nucleotides with a minimum FDR <= this value." + default_note,
    )

    # Output options
    output_group = parser.add_argument_group("Output options")
    output_group.add_argument(
        "--post-cutoff",
        metavar="N",
        type=float,
        dest="post_cutoff",
        default=0.2,
        help="Only output nucleotides with posterior probability <= this value in 1 or more samples"
             + default_note,
    )

    # Other options
    other_group = parser.add_argument_group("Other options")
    default_processors = max(1, mp.cpu_count() - 3)
    other_group.add_argument(
        "--processors",
        metavar="N",
        type=int,
        dest="processors",
        default=default_processors,
        help="Number of processors to use. Note that value excludes the minimum 3 threads"
             " that are dedicated to data I/O (default: all available processors)",
    )

    return parser.parse_args(args)

def _wait_for_capacity(queue, max_pending, poll_interval):
    """
    Sleep until ``queue`` holds at most ``max_pending`` items.

    ``qsize()`` is not implemented on every platform; when it raises
    ``NotImplementedError`` no throttling is applied.
    """
    import time

    try:
        while queue.qsize() > max_pending:
            # Sleep instead of spinning so the reader does not burn a core
            time.sleep(poll_interval)
    except NotImplementedError:
        return


def read_func(tabix_files, intervals, queue, max_pending=100, poll_interval=0.01):
    """
    Reads TABIX files and outputs to a multiprocessing pool queue

    For every interval one ``(interval, exp, obs, fdr, w)`` tuple is put on
    ``queue``. Each array has shape ``(len(tabix_files), len(interval))``;
    row ``i`` holds the values read from ``tabix_files[i]``. Positions without
    a record keep their initial values (``exp``/``obs``/``w`` = 0, ``fdr`` = 1);
    ``w`` is set to 1 wherever a record was read.

    :param tabix_files: paths of the TABIX-indexed files, one per dataset
    :param intervals: iterable of intervals exposing ``chrom``, ``start``,
        ``end`` and ``len()``
    :param queue: queue receiving the per-interval tuples
    :param max_pending: the reader pauses while more than this many items
        are waiting on ``queue``
    :param poll_interval: seconds to sleep between checks of the queue size
    """

    tabix_handles = []

    try:
        for path in tabix_files:
            tabix_handles.append(pysam.TabixFile(path))

        n_datasets = len(tabix_handles)

        # Write to input queue
        for interval in intervals:

            l = len(interval)

            obs = np.zeros((n_datasets, l), dtype = np.float64)
            exp = np.zeros((n_datasets, l), dtype = np.float64)
            fdr = np.ones((n_datasets, l), dtype = np.float64)
            w = np.zeros((n_datasets, l), dtype = np.float64)

            for i, tabix in enumerate(tabix_handles):

                try:
                    for row in tabix.fetch(interval.chrom, interval.start, interval.end, parser = pysam.asTuple()):
                        j = int(row[1]) - interval.start

                        # Skip records that start outside the interval: a negative
                        # offset would silently wrap around to the wrong column
                        if j < 0 or j >= l:
                            continue

                        exp[i, j] = np.float64(row[3])
                        obs[i, j] = np.float64(row[4])
                        fdr[i, j] = np.float64(row[7])
                        w[i, j] = 1
                except Exception as err:
                    # Best effort: a dataset that cannot be read for this region
                    # keeps its default values. KeyboardInterrupt and SystemExit
                    # are deliberately not caught here.
                    logging.getLogger(__name__).debug(
                        "Skipping dataset %d for region %s:%s-%s (%s)",
                        i, interval.chrom, interval.start, interval.end, err)

            queue.put( (interval, exp, obs, fdr, w) )

            # Stop memory from getting out of control in the processing
            _wait_for_capacity(queue, max_pending, poll_interval)

    finally:
        # Close whatever was opened, even if reading failed part-way
        for handle in tabix_handles:
            handle.close()

def process_func(disp_models, beta_priors, read_q, write_q, fdr_cutoff=0):
    """
    Worker loop that converts per-region statistics into posterior scores.

    Items are taken from ``read_q`` until a ``None`` sentinel arrives. Each
    item is an ``(interval, exp, obs, fdr, w)`` tuple; for each one the
    prior, the "on" and "off" log-likelihoods and the posterior are computed
    with the ``bayesian`` module, and ``(interval, post)`` is put on
    ``write_q``.

    :param disp_models: dispersion models, passed to ``bayesian.log_likelihood``
    :param beta_priors: beta prior parameters, passed to
        ``bayesian.compute_delta_prior``
    :param read_q: joinable queue supplying the input tuples
    :param write_q: queue receiving ``(interval, post)`` results
    :param fdr_cutoff: forwarded as ``cutoff`` to the prior computations
    """

    while True:

        data = read_q.get()

        # None is the shutdown sentinel; compare by identity, not equality
        if data is None:
            read_q.task_done()
            break

        (interval, exp, obs, fdr, w) = data

        prior = bayesian.compute_prior_weighted(fdr, w, cutoff = fdr_cutoff)
        scale = bayesian.compute_delta_prior(obs, exp, fdr, beta_priors, cutoff = fdr_cutoff)

        # NOTE(review): w = 3 is a fixed parameter of bayesian.log_likelihood,
        # presumably a window half-width -- confirm against that function.
        ll_on = bayesian.log_likelihood(obs, exp, disp_models, delta = scale, w = 3)
        ll_off = bayesian.log_likelihood(obs, exp, disp_models, w = 3)

        # Compute posterior; the sign is flipped and non-positive values are
        # clamped to zero
        post = -bayesian.posterior(prior, ll_on, ll_off)
        post[post <= 0] = 0.0

        read_q.task_done()

        write_q.put( (interval, post) )

def write_func(q, total, log_post_cutoff=0):
    """
    Writer loop that prints posterior scores to standard output.

    Items are taken from ``q`` until a ``None`` sentinel arrives. Each item
    is an ``(interval, post)`` pair. For every column of ``post`` whose
    largest non-NaN value exceeds ``log_post_cutoff``, one tab-separated line
    is written: chromosome, start, start + 1, then that column's values (one
    per row of ``post``).

    :param q: joinable queue supplying ``(interval, post)`` pairs
    :param total: number of regions expected, used for the progress bar
    :param log_post_cutoff: threshold a column's maximum must exceed to be
        written
    """

    # The progress bar comes from tqdm, which the module never imports at the
    # top level. Import it here and carry on without a bar if it is missing.
    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None

    handle = sys.stdout

    progress_desc = "Regions processed"
    progress_bar = None
    if tqdm is not None:
        progress_bar = tqdm(total=total, desc=progress_desc, ncols=80)

    try:
        while True:

            data = q.get()

            # None is the shutdown sentinel; compare by identity, not equality
            if data is None:
                q.task_done()
                break

            interval, post = data

            # Write output; filter out positions by posterior cutoff
            for i in np.where(np.nanmax(post, axis = 0) > log_post_cutoff)[0]:
                start = interval.start + int(i)
                print("{}\t{:d}\t{:d}\t".format(interval.chrom, start, start + 1) + '\t'.join(map(str, post[:, i])), file = handle)

            q.task_done()

            if progress_bar is not None:
                progress_bar.update(1)
    finally:
        if progress_bar is not None:
            progress_bar.close()
        handle.flush()

def chunkify(l, nchunks):
    """
    Deal the items of a sliceable sequence round-robin into ``nchunks`` parts.

    Chunk ``k`` holds the items at positions ``k``, ``k + nchunks``,
    ``k + 2 * nchunks`` and so on; the result always has ``nchunks`` entries.
    """

    chunks = []
    for offset in range(nchunks):
        # Extended slice: every nchunks-th item, starting at this offset
        chunks.append(l[offset::nchunks])
    return chunks

def _load_sample_inputs(metadata):
    """
    Check each sample's input files and load its dispersion model and beta prior.

    Logs a critical message and exits with status 1 if a required metadata
    column or any referenced file is missing.

    :param metadata: DataFrame with ``tabix_file``, ``dispersion_model`` and
        ``beta_prior_file`` columns, one row per sample
    :return: ``(disp_models, beta_priors)`` -- a list of dispersion models and
        an ``(n_samples, 2)`` array of beta prior parameters, in row order
    """

    required_columns = ("tabix_file", "dispersion_model", "beta_prior_file")
    missing_columns = [c for c in required_columns if c not in metadata.columns]
    if missing_columns:
        logger.critical("Fatal error: Metadata file is missing column(s): {}".format(", ".join(missing_columns)))
        sys.exit(1)

    disp_models = []
    beta_priors = np.ones((len(metadata), 2))

    # iterrows() yields Series that support sample["column"] lookups;
    # enumerate supplies the row position used to fill beta_priors
    for i, (_, sample) in enumerate(metadata.iterrows()):
        if not os.path.exists(sample["tabix_file"]):
            logger.critical("Fatal error: TABIX-file ({}) does not exists!".format(sample["tabix_file"]))
            sys.exit(1)

        if not os.path.exists(sample["dispersion_model"]):
            logger.critical("Fatal error: Dispersion model ({}) does not exists!".format(sample["dispersion_model"]))
            sys.exit(1)
        disp_models.append(dispersion.read_dispersion_model(sample["dispersion_model"]))

        if not os.path.exists(sample["beta_prior_file"]):
            logger.critical("Fatal error: Beta-prior file ({}) does not exists!".format(sample["beta_prior_file"]))
            sys.exit(1)

        # The first line of the file holds the tab-separated prior parameters
        with open(sample["beta_prior_file"], 'r') as f:
            params = f.readline().strip().split('\t')
        beta_priors[i, :] = np.array(params, dtype = np.float64)

    return disp_models, beta_priors


def main(argv = sys.argv[1:]):
    """
    Entry point: compute posteriors for every region in a BED file.

    Starts two reader processes, ``--processors`` worker processes and one
    writer process connected by two joinable queues, then waits for them to
    drain and shuts them down in that order. On ``KeyboardInterrupt`` all
    child processes are terminated.

    :param argv: command-line arguments (without the program name)
    :return: 0
    """

    args = parse_options(argv)

    logger.info("Reading metadata file and verifying inputs")

    metadata = pd.read_table(args.metadata_file, header=0, comment='#')

    # Load and parse input files
    disp_models, beta_priors = _load_sample_inputs(metadata)

    # Load intervals file
    with open(args.interval_file) as bed_handle:
        intervals = list(bed.bed3_iterator(bed_handle))

    logger.info("BED file contains {:,} regions".format(len(intervals)))

    # Processing queues
    read_q = mp.JoinableQueue()
    write_q = mp.JoinableQueue()

    # Readers: the intervals are split between two reader processes
    tabix_files = metadata["tabix_file"].tolist()
    read_procs = []
    for chunk in chunkify(intervals, 2):
        p = mp.Process(target=read_func, args=(tabix_files, chunk, read_q))
        read_procs.append(p)

    logger.info("Using {} threads to read footprint statistics".format(len(read_procs)))
    for p in read_procs:
        p.start()

    # Workers
    process_procs = []
    process_kwargs = {
        "fdr_cutoff": args.fdr_cutoff,
    }

    for _ in range(args.processors):
        p = mp.Process(target=process_func, args=(disp_models, beta_priors, read_q, write_q), kwargs=process_kwargs)
        process_procs.append(p)

    logger.info("Using {} threads to compute posteriors".format(len(process_procs)))
    for p in process_procs:
        p.start()

    # Writer: the probability cutoff is converted to the -log scale that the
    # writer compares against
    write_kwargs = {
        "log_post_cutoff": -np.log(args.post_cutoff)
    }
    write_proc = mp.Process(target=write_func, args=(write_q, len(intervals)), kwargs=write_kwargs)
    write_proc.start()

    try:
        # Wait for readers to finish
        for p in read_procs:
            p.join()

        # Block until all remaining regions are processed
        read_q.join() # wait till read queue is empty

        # Send a message to kill processing threads and block until they exit
        for _ in process_procs:
            read_q.put(None)
        for p in process_procs:
            p.join()

        # Block until all remaining regions are written to output and thread exits
        write_q.join() # wait till write queue is empty
        write_q.put(None) # sends kill signal
        write_proc.join() # block until thread exits

        logger.info("Finished computing and writing footprint posteriors!")

    except KeyboardInterrupt:
        for p in read_procs:
            p.terminate()
        for p in process_procs:
            p.terminate()
        write_proc.terminate()

    return 0

if __name__ == "__main__":

    # Run as a script: main()'s return value becomes the process exit status
    sys.exit(main())

