#!/usr/bin/env python3

# Copyright (C) 2015-2021 Jeff Vierstra (jvierstra@altius.org)

import sys
import os
import logging
import random
import math

from argparse import ArgumentParser, Action, ArgumentError, ArgumentTypeError

import multiprocessing as mp
from tqdm import tqdm

import numpy as np

import pysam

from genome_tools import bed, genomic_interval

import footprint_tools
from footprint_tools import cutcounts
from footprint_tools.modeling import bias, predict, dispersion

# Configure the root logger once for the whole script: INFO level, each
# message prefixed with a 12-hour-clock timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
)

# Module-level logger used by main() for progress messages.
logger = logging.getLogger(__name__)

class bias_model_action(Action):
    """argparse action that replaces the option value with a loaded k-mer model.

    The first element of ``values`` is taken as the model file path and handed
    to ``bias.kmer_model``; the resulting model is stored on the namespace
    under this action's ``dest``. An ``IOError`` raised while loading is
    re-raised as an ``ArgumentError`` so argparse reports it as a usage error.
    """

    def __call__(self, parser, namespace, values, option_string = None):
        model_path = values[0]
        try:
            loaded_model = bias.kmer_model(model_path)
        except IOError as load_error:
            raise ArgumentError(self, str(load_error))
        setattr(namespace, self.dest, loaded_model)

def paired_ints(arg):
    """Parse a ``"fw,rev"`` string into a tuple of two ints.

    Intended as an argparse ``type=`` converter (used by ``--bam-offset``).

    Parameters
    ----------
    arg : str
        Two comma-separated integers, e.g. ``"0,-1"``.

    Returns
    -------
    tuple of (int, int)
        ``(fw, rev)``.

    Raises
    ------
    ArgumentTypeError
        If ``arg`` is not a string of exactly two comma-separated integers.
    """
    try:
        fw, rev = map(int, arg.split(','))
    except (AttributeError, ValueError) as err:
        # AttributeError: arg is not a string. ValueError: a field is not an
        # integer, or there are not exactly two fields to unpack.
        raise ArgumentTypeError(
            "Offset argument must be in the format of int,int") from err
    return (fw, rev)

def parse_options(args):
    """Build the command-line parser for this script and parse ``args``.

    Parameters
    ----------
    args : list of str or None
        Argument strings to parse. ``None`` makes argparse fall back to
        ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        With attributes ``bam_file``, ``fasta_file``, ``interval_file``,
        ``bias_model``, ``half_win_width``, ``min_qual``, ``remove_dups``,
        ``remove_qcfail``, ``bam_offset`` and ``processors``.
    """
    parser = ArgumentParser(
        description="Learn a negative binomial dispersion model from data corrected for intrinsic sequence preference (version {})".format(footprint_tools.__version__),
        epilog="Written by Jeff Vierstra (jvierstra@altius.org). See http://github.com/jvierstra/footprint-tools for extended documentation.",
    )

    parser.add_argument("bam_file", metavar="bam_file", type=str,
                        help="Path to BAM-format tag sequence file")

    parser.add_argument("fasta_file", metavar="fasta_file", type=str,
                        help="Path to genome FASTA file (requires associated"
                        " FASTA index in same folder; see documentation on how"
                        " to create an index)")

    parser.add_argument("interval_file", metavar="interval_file", type=str,
                        help="File path to BED file")

    grp_bm = parser.add_argument_group("Bias modeling options")

    # bias_model_action loads the k-mer model from the given file at parse
    # time, replacing the uniform default.
    grp_bm.add_argument("--bm", metavar="MODEL_FILE", dest="bias_model",
                        nargs=1, action=bias_model_action, default=bias.uniform_model(),
                        help="Use a k-mer model for local bias (supplied by file). If"
                        " argument is not provided the model defaults to uniform sequence"
                        " bias.")

    grp_bm.add_argument("--half-win-width", metavar="N", type=int, default=5,
                        help="Half window width to apply bias model."
                        " (default: %(default)s)")

    grp_rf = parser.add_argument_group("Read filtering options")

    grp_rf.add_argument("--min-qual", metavar="N", type=int,
                        dest="min_qual", default=1,
                        help="Ignore reads with mapping quality lower than this threshold."
                        " (default: %(default)s)")

    # Per the SAM spec, FLAG bit 0x400 (1024) marks PCR/optical duplicates and
    # bit 0x200 (512) marks reads failing quality checks.
    # NOTE(review): confirm cutcounts.bamfile filters these same bits.
    grp_rf.add_argument("--remove-dups", action="store_true",
                        dest="remove_dups", help="Remove duplicate reads from analysis"
                        " (SAM flag -- 1024)")

    grp_rf.add_argument("--remove-qcfail", action="store_true",
                        dest="remove_qcfail", help="Remove QC-failed reads from analysis"
                        " (SAM flag -- 512)")

    grp_ot = parser.add_argument_group("Other options")

    # paired_ints parses "fw,rev" into a tuple of two ints.
    grp_ot.add_argument("--bam-offset", metavar="FW,REV", type=paired_ints,
                        dest="bam_offset", default=(0, -1),
                        help="BAM file offset (enables support for other datatypes -- e.g. Tn5/ATAC)"
                        " (default: %(default)s)")

    grp_ot.add_argument("--processors", metavar="N", type=int,
                        dest="processors", default=mp.cpu_count(),
                        help="Number of processors to use."
                        " (default: all available processors)")

    return parser.parse_args(args)



class process_callback:
    """Accumulator that sums the histograms produced by worker calls.

    Instances are used as the ``callback`` of ``Pool.apply_async``; every
    invocation adds the supplied array element-wise into ``x`` in place.
    """

    def __init__(self, hist_size):
        # Running total, zero-initialised with the requested histogram shape.
        self.x = np.zeros(hist_size)

    def __call__(self, other):
        # In-place element-wise sum into the accumulator (no new array bound).
        np.add(self.x, other, out=self.x)

def process_func(bam_file, fasta_file, bm, intervals, hist_size, proc_id, **kwargs):
    """Build an (expected, observed) cleavage-count histogram over ``intervals``.

    Designed to run in a worker process (``main`` dispatches one call per
    chunk); each call opens its own BAM and FASTA readers.

    Parameters
    ----------
    bam_file : str
        Path to the BAM file, opened with ``cutcounts.bamfile``.
    fasta_file : str
        Path to the genome FASTA, opened with ``pysam.FastaFile``.
    bm : object
        Bias model passed through to ``predict.prediction``.
    intervals : sequence
        Intervals handed one at a time to ``predictor.compute``.
    hist_size : tuple of (int, int)
        Histogram shape: rows are indexed by expected count, columns by
        observed count.
    proc_id : int
        Chunk number; labels the progress bar and sets its screen position.
    **kwargs
        ``min_qual``, ``remove_dups``, ``remove_qcfail`` and ``offset`` are
        required and forwarded to ``cutcounts.bamfile``; every remaining
        keyword goes to ``predict.prediction``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``hist_size``. Cell ``[e, o]`` counts positions
        whose expected count truncates to ``e`` and observed count to ``o``.
        Positions falling outside the histogram are ignored.
    """
    bam_kwargs = {k: kwargs.pop(k) for k in ["min_qual", "remove_dups", "remove_qcfail", "offset"]}
    # NOTE(review): bam_reader is not explicitly closed; cutcounts.bamfile's
    # close API is not visible from this file.
    bam_reader = cutcounts.bamfile(bam_file, **bam_kwargs)
    fasta_reader = pysam.FastaFile(fasta_file)

    try:
        predictor = predict.prediction(bam_reader, fasta_reader, bm, **kwargs)

        res_hist = np.zeros(hist_size)
        n_exp, n_obs = res_hist.shape

        progress_desc = "Chunk {}".format(proc_id)

        with tqdm(total=len(intervals), desc=progress_desc, position=proc_id + 1, ncols=80) as progress_bar:
            for interval in intervals:
                obs, exp, _ = predictor.compute(interval)

                # Combine strands: '+' values from position 1 onward with '-'
                # values up to the last position.
                # Assumes these are numpy arrays, so '+' adds element-wise
                # (TODO confirm against predict.prediction.compute).
                obs = obs['+'][1:] + obs['-'][:-1]
                exp = exp['+'][1:] + exp['-'][:-1]

                for o, e in zip(obs, exp):
                    ei, oi = int(e), int(o)
                    # Values outside the fixed-size histogram are dropped.
                    # The explicit check also keeps negative values from
                    # wrapping around to the far end of the array.
                    if 0 <= ei < n_exp and 0 <= oi < n_obs:
                        res_hist[ei, oi] += 1.0

                progress_bar.update(1)
    finally:
        # Release the FASTA handle even if prediction fails.
        fasta_reader.close()

    return res_hist

def chunkify(l, nchunks):
    """Split a sliceable sequence into ``nchunks`` interleaved slices.

    Element ``k`` of ``l`` lands in chunk ``k % nchunks``, so the chunks are
    round-robin strides rather than contiguous runs. Some chunks are empty
    when ``nchunks`` exceeds ``len(l)``. A non-positive ``nchunks`` yields
    an empty list.
    """
    chunks = []
    for offset in range(nchunks):
        chunks.append(l[offset::nchunks])
    return chunks

def main(argv=None):
    """Command-line entry point: learn a dispersion model and write ``dm.json``.

    Reads the BED intervals and splits them across worker processes. Each
    worker builds an (expected, observed) cleavage histogram, and the
    histograms are summed. The dispersion model is fitted on the total and
    written to ``dm.json`` in the current working directory.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse, excluding the program name. ``None`` (the
        default) means ``sys.argv[1:]``, read when ``main`` is called rather
        than when the module is imported.

    Returns
    -------
    int
        0 on success.

    Raises
    ------
    Exception
        Any exception raised inside a worker is re-raised here rather than
        being silently dropped.
    """
    if argv is None:
        argv = sys.argv[1:]
    args = parse_options(argv)

    hist_size = (200, 1000)  # hard coded histogram size -- for now...
    hist_agg = process_callback(hist_size)

    # Materialise all intervals (they are partitioned across workers) and
    # close the BED file right away.
    with open(args.interval_file) as interval_fh:
        intervals = list(bed.bed3_iterator(interval_fh))

    logger.info("BED file contains {:,} regions".format(len(intervals)))

    proc_kwargs = {
        "min_qual": args.min_qual,
        "remove_dups": args.remove_dups,
        "remove_qcfail": args.remove_qcfail,
        "offset": args.bam_offset,
        "half_window_width": args.half_win_width,
    }

    with mp.Pool(args.processors) as pool:
        logger.info("Using {} processes to compute expected cleavage counts".format(args.processors))

        jobs = [
            pool.apply_async(
                process_func,
                args=(args.bam_file, args.fasta_file, args.bias_model, chunk, hist_size, i),
                kwds=proc_kwargs,
                callback=hist_agg,
            )
            for i, chunk in enumerate(chunkify(intervals, args.processors))
        ]

        pool.close()
        pool.join()

        # hist_agg is only called for jobs that succeed. Fetching each result
        # re-raises any worker exception, so a failed chunk cannot silently
        # leave the histogram incomplete.
        for job in jobs:
            job.get()

    logger.info("Finished computing expected cleavage counts!")

    # Learn model from histogram
    logger.info("Learning dispersion model")
    model = dispersion.learn_dispersion_model(hist_agg.x)

    # Write model (relative path resolved against the current directory)
    model_file = os.path.abspath("dm.json")
    logger.info("Writing dispersion model to {}".format(model_file))

    with open(model_file, "w") as f:
        print(dispersion.write_dispersion_model(model), file=f)

    # Success!
    return 0

if __name__ == "__main__":
    # Run as a script: main()'s return value becomes the process exit status.
    raise SystemExit(main())
