#!/usr/bin/env python3

# Copyright (C) 2015-2021 Jeff Vierstra (jvierstra@altius.org)

import logging
import math
import multiprocessing as mp
import os
import random
import sys
import time
from argparse import Action, ArgumentError, ArgumentParser, ArgumentTypeError

import numpy as np
import pysam
from genome_tools import bed, genomic_interval
from tqdm import tqdm

import footprint_tools
from footprint_tools import cutcounts
from footprint_tools.modeling import bias, predict, dispersion
from footprint_tools.stats import fdr, windowing

# Configure the root logger at import time: timestamped messages, INFO level and above.
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)

class bias_model_action(Action):
    """
    argparse action that loads a k-mer bias model from the supplied file path

    The first element of the parsed values is passed to `bias.kmer_model` and
    the resulting object is stored on the namespace under this action's `dest`.
    An IOError raised while loading is reported as an argparse ArgumentError.
    """

    def __call__(self, parser, namespace, values, option_string = None):
        model_path = values[0]
        try:
            model = bias.kmer_model(model_path)
        except IOError as err:
            raise ArgumentError(self, str(err))
        setattr(namespace, self.dest, model)

class dispersion_model_action(Action):
    """
    argparse action that loads a dispersion model from the supplied file path

    The first element of the parsed values is passed to
    `dispersion.read_dispersion_model` and the resulting object is stored on
    the namespace under this action's `dest`. An IOError raised while loading
    is reported as an argparse ArgumentError.
    """

    def __call__(self, parser, namespace, values, option_string = None):
        model_path = values[0]
        try:
            model = dispersion.read_dispersion_model(model_path)
        except IOError as err:
            raise ArgumentError(self, str(err))
        setattr(namespace, self.dest, model)

def paired_ints(arg):
    """
    Parse a pair of integers delimited by a comma (used by argparser)

    Parameters
    ----------
    arg : str
        Text of the form "int,int" (e.g. "0,-1").

    Returns
    -------
    tuple of (int, int)
        The two parsed integers, in the order given.

    Raises
    ------
    ArgumentTypeError
        If the text does not contain exactly two comma-separated integers.
    """

    try:
        # ValueError covers both a non-integer field and a wrong number of fields
        fw, rev = map(int, arg.split(','))
    except ValueError as err:
        raise ArgumentTypeError("Offset argument must be in the format of int,int") from err

    return (fw, rev)

def parse_options(args):
    """
    Build the command-line parser and parse the supplied argument list

    Parameters
    ----------
    args : list of str
        Command-line arguments (without the program name).

    Returns
    -------
    argparse.Namespace
        Parsed options. `bias_model` holds a bias model object (uniform model
        unless --bm is given) and `dispersion_model` holds a loaded dispersion
        model or None when --dm is not given.
    """

    parser = ArgumentParser(description = "Compute the per-nucleotide cleavage deviation statistics (version {})".format(footprint_tools.__version__),
                epilog="Written by Jeff Vierstra (jvierstra@altius.org). See http://github.com/jvierstra/footprint-tools for extended documentation.")

    parser.add_argument("bam_file", metavar = "bam_file", type = str,
                        help = "Path to BAM-format tag sequence file")

    parser.add_argument("fasta_file", metavar = "fasta_file", type = str,
                        help = "Path to genome FASTA file (requires associated"
                        " FASTA index in same folder; see documentation on how"
                        " to create an index)")

    parser.add_argument("interval_file", metavar = "interval_file", type = str,
                        help = "File path to BED file")

    grp_bm = parser.add_argument_group("Bias modeling options")

    grp_bm.add_argument("--bm", metavar = "MODEL_FILE", dest = "bias_model",
                        nargs = 1, action = bias_model_action, default = bias.uniform_model(),
                        help = "Use a k-mer model for local bias (supplied by file). If"
                        " argument is not provided the model defaults to uniform sequence"
                        " bias.")

    grp_bm.add_argument("--half-win-width", metavar = "N", type = int, default = 5,
                        help = "Half window width to apply bias model."
                        " (default: %(default)s)")

    grp_sm = parser.add_argument_group("Smoothing options")

    grp_sm.add_argument("--smooth-half-win-width", metavar = "N", type = int, default = 50,
                        help = "Half window width to apply smoothing model. When set to"
                        " zero no smoothing is applied. (default: %(default)s)")

    grp_sm.add_argument("--smooth-clip", metavar = "N", type = float, default = 0.01,
                        help = "Fraction of signal to clip when computing trimmed mean."
                        " (default: %(default)s)")

    grp_st = parser.add_argument_group("Statistics options")

    grp_st.add_argument("--dm", nargs = 1, metavar = "MODEL_FILE",
                        dest = "dispersion_model", action = dispersion_model_action, default = None,
                        help = "Dispersion model for negative binomial tests. If argument"
                        " is not provided then no statistical output is provided. File is in"
                        " JSON format and generated using the 'ftd-learn-dispersion-model'"
                        " script included in the software package.")

    # Not yet implemented
    # grp_st.add_argument("--fdr_cutoffs", nargs = 1, metavar = "[N, ...]", dest = "fdr_cutoffs",
    #                     action = list_parse_action, default = [],
    #                     help = "FDR cutoff at which to report footprints.")

    grp_st.add_argument("--fdr-shuffle-n", metavar = "N", type = int,
                        dest = "fdr_shuffle_n", default = 50,
                        help = "Number of times to shuffle data for FDR calculation."
                        " (default: %(default)s)")

    grp_rf = parser.add_argument_group("Read filtering options")

    grp_rf.add_argument("--min-qual", metavar = "N", type = int,
                        dest = "min_qual", default = 1,
                        help = "Filter reads with mapping quality lower than this threshold."
                        " (default: %(default)s)")

    # SAM flag bits: 0x400 (1024) marks PCR/optical duplicates,
    # 0x200 (512) marks reads failing platform/vendor quality checks
    grp_rf.add_argument("--remove-dups", action = "store_true",
                        dest = "remove_dups", help = "Remove duplicate reads from analysis"
                        " (SAM flag -- 1024)")
    grp_rf.add_argument("--remove-qcfail", action = "store_true",
                        dest = "remove_qcfail", help = "Remove QC-failed reads from analysis"
                        " (SAM flag -- 512)")

    grp_ot = parser.add_argument_group("Other options")

    grp_ot.add_argument("--bam-offset", metavar = "N", type = paired_ints,
                        dest = "bam_offset", default = (0, -1),
                        help = "BAM file offset (support for legacy BAM/SAM format)"
                        " (default: %(default)s)")

    grp_ot.add_argument("--seed", metavar = "N", type = int, dest = "seed",
                        default = None, help = "Seed for random number generation"
                        " (default: no seed)")

    # The default leaves one processor free (but never drops below one)
    grp_ot.add_argument("--processors", metavar = "N", type = int,
                        dest = "processors", default = max(1, mp.cpu_count()-1),
                        help = "Number of processors to use."
                        " (default: all but one of the available processors)")

    return parser.parse_args(args)

def read_func(bam_file, fasta_file, bm, dm, intervals, q, **kwargs):
    """
    Reads BAM file, computes expected cleavages and associated statistics and outputs to a multiprocessing pool queue

    Parameters
    ----------
    bam_file : str
        Path handed to `cutcounts.bamfile`.
    fasta_file : str
        Path handed to `pysam.FastaFile`.
    bm
        Bias model passed through to `predict.prediction`.
    dm
        Dispersion model, or a falsy value (e.g. None) to skip the statistics.
    intervals : iterable
        Regions to process; each is passed to the predictor's `compute`.
    q
        Queue receiving one `(interval, stats)` tuple per region.
    **kwargs
        Must contain "min_qual", "remove_dups", "remove_qcfail", "offset"
        (forwarded to the BAM reader), "half_window_width",
        "smoothing_half_window_width", "smoothing_clip" (forwarded to the
        predictor), "seed" and "fdr_shuffle_n".

    Notes
    -----
    `stats` has columns (expected, observed) without a dispersion model, and
    (expected, observed, -log p-value, -log windowed p-value, empirical FDR)
    with one. If the statistics fail for a region, a warning is logged and
    the p-value and FDR columns are computed from arrays of ones.
    """

    bam_kwargs = { k:kwargs.pop(k) for k in ["min_qual", "remove_dups", "remove_qcfail", "offset"] }
    bam_reader = cutcounts.bamfile(bam_file, **bam_kwargs)
    fasta_reader = pysam.FastaFile(fasta_file)

    predict_kwargs = { k:kwargs.pop(k) for k in ["half_window_width", "smoothing_half_window_width", "smoothing_clip"] }
    predictor = predict.prediction(bam_reader, fasta_reader, bm, **predict_kwargs)

    # args used for FDR sampling procedure
    seed = kwargs.pop("seed")
    if seed is not None:
        # compare against None so that a seed of 0 is honoured
        random.seed(seed)
        np.random.seed(seed)
    fdr_shuffle_n = kwargs.pop("fdr_shuffle_n")

    def win_pvals_func(z):
        return windowing.stouffers_z(np.ascontiguousarray(z), 3)

    # Back-pressure settings: pause while more than this many results are pending
    max_pending = 100
    poll_interval = 0.01

    # Read and process each region
    for interval in intervals:

        obs, exp, _ = predictor.compute(interval)

        obs = obs['+'][1:] + obs['-'][:-1]
        exp = exp['+'][1:] + exp['-'][:-1]

        n = len(obs)

        if dm:

            try:

                pvals = dm.p_values(exp, obs)
                counts_null, pvals_null = dm.resample_p_values(exp, fdr_shuffle_n)

                win_pvals = win_pvals_func(pvals)
                win_pvals_null = np.apply_along_axis(win_pvals_func, 0, pvals_null)

                efdr = fdr.emperical_fdr(win_pvals_null, win_pvals)

            except Exception as err:

                # Best effort: keep the region in the output, but leave a trace of the failure
                logger.warning("Could not compute statistics for region %s:%s (%s); reporting non-significant values",
                               interval.chrom, interval.start, err)
                pvals = win_pvals = efdr = np.ones(n)

            # Built after (not in a `finally` of) the try block so that a
            # KeyboardInterrupt/SystemExit is not masked by a NameError here
            stats = np.column_stack((exp, obs, -np.log(pvals), -np.log(win_pvals), efdr))

        else:

            stats = np.column_stack((exp, obs))

        q.put( (interval, stats) )

        # Throttle this producer while the writer catches up; sleep rather than spin
        try:
            while q.qsize() > max_pending:
                time.sleep(poll_interval)
        except NotImplementedError:
            # qsize() is unavailable on some platforms (e.g. macOS); skip throttling there
            pass

def write_func(q, total):
    """
    Function to write output from multiple threads to a single file

    Pulls `(interval, stats)` tuples off the queue until a `None` sentinel is
    received. For every row `i` of `stats` one tab-delimited line is written
    to standard output: chromosome, `interval.start + i`,
    `interval.start + i + 1`, followed by the row's values formatted to four
    decimal places. Each item (including the sentinel) is acknowledged with
    `q.task_done()`.

    Parameters
    ----------
    q
        Queue providing `get()` and `task_done()`.
    total : int
        Number of regions expected; used as the progress bar total.
    """

    f = sys.stdout

    progress_desc="Regions processed"
    with tqdm(total=total, desc=progress_desc, ncols=80) as progress_bar:

        while True:

            data = q.get()
            if data is None:
                # Sentinel from the parent process: acknowledge and stop
                q.task_done()
                break

            interval, stats = data

            # Format the whole region first and emit it with a single write
            lines = []
            for i, row in enumerate(stats):
                pos = interval.start + i
                val_string = "\t".join("{:0.4f}".format(val) for val in row)
                lines.append("{}\t{:d}\t{:d}\t{}\n".format(interval.chrom, pos, pos + 1, val_string))
            f.write("".join(lines))

            q.task_done()

            progress_bar.update(1)

def chunkify(l, nchunks):
    """
    Deal the elements of a sliceable sequence round-robin into nchunks chunks

    Chunk i holds elements i, i + nchunks, i + 2*nchunks, ... of `l`. Chunks
    are slices of `l`, so they may be empty when `l` has fewer than nchunks
    elements.
    """

    chunks = []
    for offset in range(nchunks):
        chunks.append(l[offset::nchunks])
    return chunks

def main(argv = sys.argv[1:]):
    """
    Script entry point: parse options, fan regions out to reader processes and stream results through a single writer process

    Parameters
    ----------
    argv : list of str
        Command-line arguments without the program name. Note that the
        default is captured from `sys.argv` when this function is defined.

    Returns
    -------
    int
        Always 0 (also after a keyboard interrupt, in which case the child
        processes are terminated).
    """

    args = parse_options(argv)

    # Read all regions up front and close the BED file as soon as they are loaded
    with open(args.interval_file) as interval_fh:
        intervals = list(bed.bed3_iterator(interval_fh))

    logger.info("BED file contains {:,} regions".format(len(intervals)))

    proc_kwargs = {
        "min_qual": args.min_qual,
        "remove_dups": args.remove_dups,
        "remove_qcfail": args.remove_qcfail,
        "offset": args.bam_offset,
        "half_window_width": args.half_win_width,
        "smoothing_half_window_width": args.smooth_half_win_width,
        "smoothing_clip": args.smooth_clip,
        "fdr_shuffle_n": args.fdr_shuffle_n,
        "seed": args.seed
        }

    q = mp.JoinableQueue()

    # One processor is reserved for the writer; the rest (at least one) read and compute
    n_readers = max(1, args.processors-1)

    read_procs = []
    for chunk in chunkify(intervals, n_readers):
        p = mp.Process(target=read_func, args=(args.bam_file, args.fasta_file, args.bias_model, args.dispersion_model, chunk, q), kwargs=proc_kwargs)
        read_procs.append(p)

    write_proc = mp.Process(target=write_func, args=(q, len(intervals)))

    logger.info("Using {} threads to compute footprint statistics".format(len(read_procs)))

    for p in read_procs:
        p.start()
    write_proc.start()

    try:
        for p in read_procs:
            p.join()
        q.join() # block until queue is empty after processing is done
        q.put(None) # send kill signal
        write_proc.join() # wait for writer proc to return
        logger.info("Finished computing and writing footprint statistics!")
    except KeyboardInterrupt:
        for p in read_procs:
            p.terminate()
        write_proc.terminate()

    return 0
    
# When executed as a script, run main() and use its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
