#!/usr/bin/env python3

# Copyright (C) 2015-2019 Jeff Vierstra (jvierstra@altius.org)

import sys, logging
from argparse import ArgumentParser, Action, ArgumentError
import multiprocessing as mp

import pyfaidx
import numpy as np
import random

from genome_tools import bed, genomic_interval, genomic_interval_set

from footprint_tools import cutcounts
from footprint_tools.modeling import bias, predict, dispersion
from footprint_tools.stats import fdr, windowing
from six.moves import map
from six.moves import range


# Configure the root logger: INFO level, each record prefixed with a
# 12-hour-clock timestamp (month/day/year hour:minute:second AM/PM).
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %I:%M:%S %p',
    format='%(asctime)s %(message)s',
)

# Logger named after this module.
logger = logging.getLogger(__name__)

class kmer_action(Action):
    """argparse action that turns the option's first value into a k-mer bias model.

    The first element of ``values`` is handed to ``bias.kmer_model`` and the
    resulting object is stored on the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string = None):
        model_source = values[0]
        try:
            model = bias.kmer_model(model_source)
        except IOError as err:
            # Surface I/O failures as a normal argparse usage error
            raise ArgumentError(self, str(err))
        setattr(namespace, self.dest, model)

class dispersion_model_action(Action):
    """argparse action that loads a dispersion model from the option's first value.

    The first element of ``values`` is handed to
    ``dispersion.read_dispersion_model`` and the resulting object is stored on
    the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string = None):
        model_source = values[0]
        try:
            model = dispersion.read_dispersion_model(model_source)
        except IOError as err:
            # Surface I/O failures as a normal argparse usage error
            raise ArgumentError(self, str(err))
        setattr(namespace, self.dest, model)

def paired_ints(arg):
    """Parse a command-line value of the form ``"int,int"`` into a 2-tuple.

    Intended for use as an argparse ``type=`` callable.

    Args:
        arg: String holding exactly two comma-separated integers,
            e.g. ``"0,-1"``.

    Returns:
        ``(fw, rev)`` as a tuple of two ints.

    Raises:
        argparse.ArgumentTypeError: If ``arg`` does not contain exactly two
            comma-separated integers.
    """
    # Imported here because the module-level argparse import does not
    # include ArgumentTypeError.
    from argparse import ArgumentTypeError

    try:
        # Unpacking enforces that exactly two fields are present
        fw, rev = [int(field) for field in arg.split(',')]
    except (ValueError, AttributeError):
        raise ArgumentTypeError("Offset argument must be in the format of int,int")

    return (fw, rev)

def parse_options(args):
    """Build the command-line parser and parse ``args``.

    Args:
        args: List of command-line argument strings (without the program
            name).

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args``.
        ``bias_model`` holds either the default ``bias.uniform_model()`` or
        the object set by ``kmer_action``; ``dispersion_model`` is ``None``
        unless ``--dm`` is given.
    """

    parser = ArgumentParser(description = "Compute the per-nucleotide cleavage deviation statistics")

    parser.add_argument("bam_file", metavar = "bam_file", type = str,
                        help = "Path to BAM-format tag sequence file")

    parser.add_argument("fasta_file", metavar = "fasta_file", type = str, 
                        help = "Path to genome FASTA file (requires associated"
                        " FASTA index in same folder; see documentation on how"
                        " to create an index)")

    parser.add_argument("interval_file", metavar = "interval_file", type = str, 
                        help = "File path to BED file")

    grp_bm = parser.add_argument_group("Bias modeling options")

    grp_bm.add_argument("--bm", metavar = "MODEL_FILE", dest = "bias_model", 
                        nargs = 1, action = kmer_action, default = bias.uniform_model(),
                        help = "Use a k-mer model for local bias (supplied by file). If"
                        " argument is not provided the model defaults to uniform sequence"
                        " bias.")

    grp_bm.add_argument("--half-win-width", metavar = "N", type = int, default = 5,
                        help = "Half window width to apply bias model."
                        " (default: %(default)s)")

    grp_sm = parser.add_argument_group("Smoothing options")
    
    grp_sm.add_argument("--smooth-half-win-width", metavar = "N", type = int, default = 50,
                        help = "Half window width to apply smoothing model. When set to"
                        " zero no smoothing is applied. (default: %(default)s)")

    grp_sm.add_argument("--smooth-clip", metavar = "N", type = float, default = 0.01,
                        help = "Fraction of signal to clip when computing trimmed mean."
                        " (default: %(default)s)")

    grp_st = parser.add_argument_group("Statistics options")

    grp_st.add_argument("--dm", nargs = 1, metavar = "MODEL_FILE", 
                        dest = "dispersion_model", action = dispersion_model_action, default = None,
                        help = "Dispersion model for negative binomial tests. If argument"
                        " is not provided then no statistical output is provided. File is in"
                        " JSON format and generated using the 'ftd-learn-dispersion-model'"
                        " script included in the software package.")

    # TODO: implement
    # grp_st.add_argument("--fdr_cutoffs", nargs = 1, metavar = "[N, ...]", dest = "fdr_cutoffs",
    #                     action = list_parse_action, default = [],
    #                     help = "FDR cutoff at which to report footprints.")

    grp_st.add_argument("--fdr-shuffle-n", metavar = "N", type = int,
                        dest = "fdr_shuffle_n", default = 50,
                        help = "Number of times to shuffle data for FDR calculation."
                        " (default: %(default)s)")

    grp_ot = parser.add_argument_group("Other options")

    grp_ot.add_argument("--remove-dups", action = "store_true",
                        dest = "remove_dups", help = "Remove duplicate reads from analysis"
                        " (SAM flag -- 1024)")

    grp_ot.add_argument("--bam-offset", metavar = "N", type = paired_ints,
                        dest = "bam_offset", default = (0, -1),
                        help = "BAM file offset (support for legacy BAM/SAM format)"
                        " (default: %(default)s)")

    grp_ot.add_argument("--seed", metavar = "N", type = int, dest = "seed",
                        default = None, help = "Seed for random number generation"
                        " (default: no seed)")


    grp_ot.add_argument("--processors", metavar = "N", type = int,
                        dest = "processors", default = mp.cpu_count(),
                        help = "Number of processors to use."
                        " (default: all available processors)")

    return parser.parse_args(args)

def process_func(pred, dm, fdr_shuffle_n, seed=None):
    """Compute per-position observed/expected counts and optional statistics.

    Args:
        pred: Object exposing ``compute()`` (returning observed counts,
            expected counts and a third value that is unused here) and an
            ``interval`` attribute. The count containers are indexed by
            ``'+'`` and ``'-'``.
        dm: Dispersion model providing ``p_values`` and
            ``resample_p_values``; when falsy only counts are returned.
        fdr_shuffle_n: Passed to ``dm.resample_p_values`` as the number of
            resamplings.
        seed: Optional integer seed. When not ``None`` (0 included) both
            ``random`` and ``numpy.random`` are reseeded before any work.

    Returns:
        ``(pred.interval, data)`` where ``data`` is a 2-D array with one row
        per position. Columns are ``exp, obs`` without a dispersion model and
        ``exp, obs, -log(p), -log(windowed p), FDR`` with one. If the
        statistics step raises, the error is printed to stderr and the three
        statistic columns are filled with 0, 0 and 1 respectively.
    """

    # Reseed here, for multiprocessing reasons.
    # Compare against None so that a seed of 0 is honoured.
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    (obs_counts, exp_counts, win_counts) = pred.compute()

    # Combine strands: '+' drops its first element and '-' its last, so the
    # result is one element shorter than the per-strand arrays.
    obs = obs_counts['+'][1:] + obs_counts['-'][:-1]
    exp = exp_counts['+'][1:] + exp_counts['-'][:-1]

    if dm:

        try:

            # NOTE(review): the literal 3 is an opaque argument to
            # windowing.stouffers_z -- presumably a window size; confirm.
            win_pvals_func = lambda z: windowing.stouffers_z(np.ascontiguousarray(z), 3)

            pvals = dm.p_values(exp, obs)
            counts_null, pvals_null = dm.resample_p_values(exp, fdr_shuffle_n)

            win_pvals = win_pvals_func(pvals)

            # Apply the same windowing along axis 0 of the resampled p-values
            win_pvals_null = np.apply_along_axis(win_pvals_func, 0, pvals_null)
            FDR = fdr.emperical_fdr(win_pvals_null, win_pvals)

            data = np.column_stack((exp, obs, -np.log(pvals), -np.log(win_pvals), FDR))

        except Exception as e:

            # Best effort: report the problem and emit neutral statistics
            # so the interval still produces output rows.
            print(e, file = sys.stderr)

            data = np.column_stack((exp, obs, np.zeros(len(obs)), np.zeros(len(obs)), np.ones(len(obs))))

    else:

        data = np.column_stack((exp, obs))

    return (pred.interval, data)

class process_callback(object):
    """Callable that writes one result as tab-separated rows to a file handle.

    Each row of the result array becomes one line:
    ``chrom <TAB> start+i <TAB> start+i+1 <TAB> values...`` with every value
    formatted to four decimal places.
    """

    def __init__(self, filehandle = sys.stdout):
        # Destination for the formatted rows
        self.filehandle = filehandle

    def __call__(self, res):
        interval, data = res

        chrom = interval.chrom
        start = interval.start

        rows = []
        for i in range(data.shape[0]):
            values = '\t'.join('{:0.4f}'.format(val) for val in data[i, :])
            rows.append("%s\t%d\t%d\t%s" % (chrom, start + i, start + i + 1, values))

        # A single print call per result; an empty array yields one blank line
        print('\n'.join(rows), file = self.filehandle)

def main(argv = sys.argv[1:]):
    """Command-line entry point.

    Parses the arguments, then for every interval yielded by
    ``genomic_interval_set`` over the BED file builds a
    ``predict.prediction`` and submits it to a worker pool running
    ``process_func``. Results are written to standard output by a
    ``process_callback`` instance.

    Args:
        argv: List of command-line argument strings (without the program
            name).

    Returns:
        0 after all submitted tasks have finished.
    """
    # Local import: only needed for the back-pressure wait below
    import time

    args = parse_options(argv)
    random.seed(args.seed)
    np.random.seed(args.seed)

    reads = cutcounts.bamfile(args.bam_file, remove_dups = args.remove_dups, offset = args.bam_offset)
    fasta = pyfaidx.Fasta(args.fasta_file, one_based_attributes = False, sequence_always_upper = True)

    write_func = process_callback()

    def report_failure(exc):
        # Without an error callback, exceptions raised in a worker are
        # dropped silently because the AsyncResult is never inspected.
        logger.error("Failed to process interval: %s", exc)

    # Keep the interval file open for as long as it is being iterated and
    # make sure it is closed afterwards.
    with open(args.interval_file) as interval_fh:

        intervals = bed.bed3_iterator(interval_fh)

        pool = mp.Pool(args.processors)

        for interval in genomic_interval_set(intervals):

            region = predict.prediction(reads, fasta, interval, args.bias_model, args.half_win_width, args.smooth_half_win_width, args.smooth_clip)

            pool.apply_async(process_func, args = (region, args.dispersion_model, args.fdr_shuffle_n, args.seed,), callback = write_func, error_callback = report_failure)

            # Back-pressure: stop submitting while more than 1000 tasks are
            # queued. Sleep instead of spinning so the parent process does
            # not burn a CPU core while it waits.
            while pool._taskqueue.qsize() > 1000:
                time.sleep(0.01)

        pool.close()
        pool.join()

    return 0
    
# When executed as a script, run main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
