#!/usr/bin/env python3

import gzip
import argparse
from itertools import islice
import numpy as np

# Module metadata: authorship, licence, maintainer contact and release status.
__author__ = "Susana Posada Cespedes"
__copyright__ = "Copyright 2017"
__credits__ = "Susana Posada Cespedes"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"


def parse_args():
    """
    Build the command-line interface and parse the process arguments.

    Returns the argparse namespace with the attributes ``sample_id``,
    ``sample_date``, ``window_len``, ``qual_thrd``, ``counts_thrd``,
    ``read_len``, ``outfile`` and ``FILES`` (list of FASTQ paths).
    """
    cli = argparse.ArgumentParser(
        description="Script for predicting number of read pairs after trimming",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Mandatory options get their own group so the help output lists them
    # under a dedicated heading.
    mandatory = cli.add_argument_group('required named arguments')
    mandatory.add_argument(
        "-n", dest='sample_id', metavar='STR', required=True,
        help="Patient/sample identifier")
    mandatory.add_argument(
        "-d", dest='sample_date', metavar='DATE', required=True,
        help="Sample date")

    # Optional integer settings: (flags, destination, default, help text)
    int_options = (
        (("-w",), 'window_len', 10,
         "Length of the sliding window for emulating read trimming"),
        (("-q", "--qual"), 'qual_thrd', 30,
         "Quality threshold for emulating read trimming"),
        (("-c",), 'counts_thrd', 25000,
         "Threshold on the read counts"),
        (("-l",), 'read_len', None,
         "Read length. If not specify estimated for every read independently"),
    )
    for flags, dest, default, help_text in int_options:
        cli.add_argument(*flags, required=False, default=default,
                         metavar='INT', dest=dest, type=int, help=help_text)

    cli.add_argument("-o", required=False, default="output.tsv",
                     metavar='output.tsv', dest='outfile',
                     help="Output file name")
    cli.add_argument("FILES", nargs='+', metavar="FASTQ",
                     help='FASTQ files for forward and reverse reads')

    return cli.parse_args()


def get_nth(iterable, n, after=1):
    """
    Yield every n-th item of an iterable, starting with item number `after`.

    Items are counted from 1: the first value produced is the `after`-th
    item, followed by every n-th item after it. With ``n=4, after=4`` over
    the lines of a file this produces lines 4, 8, 12, ...

    Parameters:
        iterable: iterator (or any iterable) to sample from.
        n: step between yielded items; must be a positive integer.
        after: 1-based position of the first item to yield. Values below 1
            behave like 1 (nothing is skipped).

    The generator finishes normally once the input is exhausted.
    """
    # A bare next() inside a generator turns the StopIteration raised at the
    # end of the input into RuntimeError (PEP 479, Python 3.7+). Delegating
    # to islice ends the generator cleanly instead.
    first_index = max(after - 1, 0)
    yield from islice(iterable, first_index, None, n)


def consume(iterator, n):
    """
    Advance the iterator n-steps ahead. If n is None, consume entirely.
    Adapted from itertools recipes.

    Parameters:
        iterator: the iterator to advance in place.
        n: number of items to discard, or None to exhaust the iterator.
            If fewer than n items remain, the iterator is simply exhausted.

    Returns None; the effect is the changed position of `iterator`.
    """
    # Use functions that consume iterators at C speed.
    if n is None:
        # Imported here because the module does not import `collections`
        # at file level; without it this branch raised NameError.
        from collections import deque

        # feed the entire iterator into a zero-length deque
        deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n
        next(islice(iterator, n, n), None)


def trim(qual, slen, window_len, qual_thrd):
    """
    Emulate sliding-window quality trimming and return the remaining length.

    The read is kept from the start of the first window (scanning from the
    5'-end) to the end of the last window (scanning from the 3'-end) in
    which every quality value is at least `qual_thrd`.

    Parameters:
        qual: indexable sequence of per-base quality values that can be
            compared with `qual_thrd` (main() passes the raw bytes of the
            FASTQ quality line together with an offset-adjusted threshold).
        slen: length of the read, i.e. the number of values in `qual`.
        window_len: number of consecutive bases per window; must be positive.
        qual_thrd: minimum quality every base of a window must reach.

    Returns:
        Length of the read after trimming; 0 if no window passes, which
        includes reads shorter than `window_len`.

    Raises:
        ValueError: if `window_len` is not positive.
    """
    if window_len <= 0:
        raise ValueError("window_len must be a positive integer")

    # Trimming read from the 5'-end: locate the first passing window.
    # The "+ 1" includes the final window [slen - window_len, slen); without
    # it a read whose only passing window is the last one (e.g. a read
    # exactly window_len long) was reported as fully trimmed.
    start = None
    for i in range(slen - window_len + 1):
        if min(qual[i:i + window_len]) >= qual_thrd:
            start = i
            break

    if start is None:
        return 0

    # Trimming read from the 3'-end: locate the end of the last passing
    # window. The window found above ends at start + window_len and is known
    # to pass, so only windows ending further right need to be tested.
    end = start + window_len
    for i in range(slen, end, -1):
        if min(qual[i - window_len:i]) >= qual_thrd:
            end = i
            break

    return end - start


def main():
    """
    Count the read pairs expected to survive trimming and report the sample.

    Reads the forward and reverse gzip-compressed FASTQ files in lockstep,
    emulates trimming on the quality line of each read, and counts the pairs
    in which both mates keep at least 80% of the reference length (the
    value given with -l, or each read's own length otherwise).

    The output file is always created. The sample identifier and date are
    written to it only if the count exceeds the read-count threshold;
    otherwise the file is left empty and a message is printed to stdout.

    Raises:
        SystemExit: if fewer than two FASTQ files were supplied.
    """
    args = parse_args()
    if len(args.FILES) < 2:
        # Fail with a clear message instead of an IndexError further down
        raise SystemExit(
            "error: two FASTQ files are required (forward and reverse reads)")
    R1, R2 = args.FILES[:2]

    # For every read pair increment counter if length after trimming is larger
    # than 0.8*read_len
    # Emulating read trimming: use sliding windows of length <window_len> and
    # stop when all bases have quality larger or equal to <qual_thrd>
    read_count = 0
    # Using Phred+33 encoding
    qual_thrd = args.qual_thrd + 33

    # With an explicit read length the cutoff is the same for every read,
    # so compute it once outside the loop
    fixed_min_len = None
    if args.read_len is not None:
        fixed_min_len = np.ceil(0.8 * args.read_len)

    # Binary mode: quality lines are bytes, whose items are integer ASCII
    # codes directly comparable with the offset-adjusted threshold
    with gzip.open(R1, "r") as f1, gzip.open(R2, "r") as f2:

        # Every 4th line starting at line 4: the quality line of each record
        every = (4, 4)

        # zip stops at the shorter input, so reads without a mate are ignored
        for line_f1, line_f2 in zip(get_nth(f1, *every), get_nth(f2, *every)):

            qual_f1 = line_f1.rstrip()
            qual_f2 = line_f2.rstrip()
            read_len_f1 = len(qual_f1)
            read_len_f2 = len(qual_f2)

            # Length would be zero if all bases are trimmed
            len_f1 = trim(qual_f1, read_len_f1, args.window_len, qual_thrd)
            len_f2 = trim(qual_f2, read_len_f2, args.window_len, qual_thrd)

            if fixed_min_len is None:
                min_len_f1 = np.ceil(0.8 * read_len_f1)
                min_len_f2 = np.ceil(0.8 * read_len_f2)
            else:
                min_len_f1 = min_len_f2 = fixed_min_len

            if len_f1 >= min_len_f1 and len_f2 >= min_len_f2:
                read_count += 1

    # Context manager guarantees the output is flushed and closed
    with open(args.outfile, "wt") as output:
        if read_count > args.counts_thrd:
            output.write("{patient}\t{date}\n".format(
                patient=args.sample_id, date=args.sample_date))
        else:
            print("Sample {patient} ({date}) reports a read count of {read_count}, discarding".format(
                patient=args.sample_id, date=args.sample_date, read_count=read_count))


# Run the analysis only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
