#!/usr/bin/which python3

# Copyright (C) 2015-2019 Jeff Vierstra (jvierstra@altius.org)

from __future__ import print_function, division

from __future__ import absolute_import
import sys, os, logging
from argparse import ArgumentParser
import multiprocessing as mp
from functools import partial

import numpy as np
import scipy.stats

import pysam


# Configure the root logger once at import time: timestamped messages at
# INFO level and above. The module-level logger is named after this module.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %I:%M:%S %p',
    format='%(asctime)s %(message)s',
)
logger = logging.getLogger(__name__)

def parse_options(args):
    """Parse the command-line arguments of the beta-prior learning tool.

    Args:
        args: Sequence of argument strings handed to
            ``ArgumentParser.parse_args``.

    Returns:
        The parsed namespace with attributes ``tabix_file`` (str),
        ``fdr_cutoff`` (float, default 0.05) and ``exp_cutoff``
        (int, default 10).
    """
    cli = ArgumentParser(
        description="Learn the genome-wide beta model prior for nucleotide protection"
    )

    # Required positional input.
    cli.add_argument(
        "tabix_file",
        metavar="tabix_file",
        type=str,
        help="Path to TABIX-format file (e.g., ouput from 'ftd-compute-deviation')",
    )

    # Filtering thresholds, listed together under "Model options" in --help.
    # Each entry: (flag, destination attribute, value type, default, help text).
    model_group = cli.add_argument_group("Model options")
    model_flags = (
        (
            "--fdr-cutoff", "fdr_cutoff", float, 0.05,
            "Only consider nucleotides with FDR <= this value."
            " (default: %(default)s)",
        ),
        (
            "--exp-cutoff", "exp_cutoff", int, 10,
            "Only consider nucleotides with expected cleavages >= this value."
            " (default: %(default)s)",
        ),
    )
    for flag, attr, value_type, fallback, help_text in model_flags:
        model_group.add_argument(
            flag,
            metavar="N",
            type=value_type,
            dest=attr,
            default=fallback,
            help=help_text,
        )

    return cli.parse_args(args)

##

def compute_beta_prior(filename, fdr_cutoff, exp_cutoff):
    """Fit beta-distribution shape parameters to observed/expected ratios.

    Every record of the TABIX file is scanned. Column index 3 is read as
    ``exp``, index 4 as ``obs`` and index 7 as ``fdr``. Records with
    ``fdr <= fdr_cutoff`` and ``exp >= exp_cutoff`` contribute the ratio
    ``(obs + 1) / (exp + 1)``. Only finite ratios strictly inside (0, 1)
    are passed to the fit, which is run with the location fixed at 0 and
    the scale fixed at 1.

    Args:
        filename: Path to the TABIX-indexed input file.
        fdr_cutoff: Upper bound (inclusive) on a record's FDR.
        exp_cutoff: Lower bound (inclusive) on a record's expected value.

    Returns:
        A two-element list ``[a, b]`` with the fitted shape parameters.

    Raises:
        ValueError: If a column cannot be converted to a float, or if
            fewer than two ratios remain after filtering.
    """
    ratios = []

    handle = pysam.TabixFile(filename)
    try:
        for row in handle.fetch(parser=pysam.asTuple()):
            exp = np.float64(row[3])
            obs = np.float64(row[4])
            fdr = np.float64(row[7])

            if fdr <= fdr_cutoff and exp >= exp_cutoff:
                ratios.append((obs + 1) / (exp + 1))
    finally:
        # Release the file handle even when a record fails to parse.
        handle.close()

    ratios = np.array(ratios)

    # The beta distribution is supported on the open interval (0, 1);
    # discard anything outside it (and NaN/inf) before fitting.
    usable = ratios[np.isfinite(ratios) & (ratios > 0) & (ratios < 1)]

    # Two shape parameters cannot be estimated from fewer than two points;
    # fail with a clear message instead of an obscure solver error.
    if usable.size < 2:
        raise ValueError(
            "Cannot fit beta prior: only %d usable nucleotide(s) in '%s' "
            "(FDR <= %s, expected >= %s, ratio within (0, 1))"
            % (usable.size, filename, fdr_cutoff, exp_cutoff)
        )

    (a, b) = scipy.stats.beta.fit(usable, floc=0, fscale=1)[0:2]

    return [a, b]

def main(argv=None):
    """Command-line entry point: learn the beta prior and print it.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``sys.argv[1:]`` is read at call time, so changes made to
            ``sys.argv`` after this module was imported are honoured.

    Returns:
        0 on success.
    """
    if argv is None:
        argv = sys.argv[1:]

    args = parse_options(argv)

    prior = compute_beta_prior(
        args.tabix_file,
        fdr_cutoff=args.fdr_cutoff,
        exp_cutoff=args.exp_cutoff,
    )

    # Emit the two shape parameters as one tab-separated line.
    print("%0.4f\t%0.4f" % (prior[0], prior[1]), file=sys.stdout)

    return 0

# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
