#!/usr/bin/env python3

import argparse
import pysam
import os
import numpy as np
from multiprocessing import Pool

# Module metadata (authorship, licensing and maintenance status).
__author__ = "Susana Posada Cespedes"
__copyright__ = "Copyright 2017"
__credits__ = "Susana Posada Cespedes"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"


def parse_args():
    """Set up and run the parsing of command-line arguments.

    Returns:
        argparse.Namespace with the attributes ``region``, ``coverage_file``,
        ``min_coverage`` (int), ``window_overlap`` (float), ``window_len`` and
        ``window_shift`` (comma-separated strings, split by the caller),
        ``patientIDs``, ``right_offset`` (bool), ``thrds`` (int), ``outfile``
        and ``input_files`` (list of BAM paths).
    """

    parser = argparse.ArgumentParser(description="Script to extract coverage windows for ShoRAH",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-r", required=False, default=None, metavar='BED', dest='region',
                        help="Region of interested in BED format, e.g. HXB2:2253-3869. Loci are interpreted using 0-based indexing, and a half-open interval is used, i.e, [start:end)")
    parser.add_argument("-cf", required=False, default=None, metavar='TSV', dest='coverage_file',
                        help="File containing coverage per locus per sample. Samples are expected as columns and loci as rows. This option is not compatible with the read-window overlap thresholding")
    parser.add_argument("-c", required=False, default=100, metavar='INT', dest='min_coverage',
                        type=int, help="Minimum read depth per window")
    # type=float is required: without it a user-supplied value stays a string
    # and the later multiplication with the window length misbehaves.
    parser.add_argument("-f", required=False, default=0.85, metavar='FLOAT', type=float,
                        dest='window_overlap', help="Threshold on the overlap between each read and the window")
    parser.add_argument("-w", required=False, default='201', metavar='len1,len2,...', dest='window_len',
                        help="Window length used by ShoRAH")
    parser.add_argument("-s", required=False, default='67', metavar='shift1, shift2, ...',
                        dest='window_shift', help="Window shifts used by ShoRAH")
    parser.add_argument("-N", required=False, default=None, metavar='name1,name2,...', dest="patientIDs",
                        help="Patient/sample identifiers as comma separated strings")
    parser.add_argument("-e", required=False, action='store_true', default=False, dest='right_offset',
                        help="Indicate whether to apply a more liberal shift on intervals' right-endpoint")
    parser.add_argument("-t", required=False, default=1, metavar='INT', dest='thrds', type=int,
                        help="Number of threads")
    parser.add_argument("-o", required=False, default="coverage_intervals.tsv", metavar='OUTPUT', dest='outfile',
                        help="Output file name")
    parser.add_argument("input_files", nargs='+',
                        metavar='BAM', help="Input BAM file(s)")

    return parser.parse_args()


def get_intervals_wrapper(args):
    """Compute coverage intervals for one sample, on one or all references.

    Args:
        args: Sequence ``(bamfile, cov_thrd, win_thrd, start, end, window_len,
            window_shift, ref_id, right_offset)``. Packed into a single
            argument so the function can be used with ``Pool.map``.

    Returns:
        Comma-separated string of ``ref:start-end`` intervals. When ``ref_id``
        is None, every reference listed in the BAM header is processed and the
        non-empty results are concatenated.
    """

    bamfile, cov_thrd, win_thrd, start, end, window_len, window_shift, ref_id, right_offset = args

    if ref_id is not None:
        return get_intervals(args)

    # Only the reference names are needed here; close the file right away
    # instead of leaking the handle (get_intervals opens the file itself).
    with pysam.AlignmentFile(bamfile) as aln_reads:
        references = list(aln_reads.references)

    per_reference = (
        get_intervals((bamfile, cov_thrd, win_thrd, start, end, window_len,
                       window_shift, ref, right_offset))
        for ref in references
    )
    # Drop empty strings (references without any interval) before joining
    return ','.join(filter(None, per_reference))


def get_intervals(args):
    """Find intervals of one reference that are covered by enough reads.

    Windows of length ``window_len`` are slid along the reference. A window
    counts as covered when at least ``cov_thrd`` reads each overlap it by more
    than ``int(win_thrd * window_len)`` bases.

    Args:
        args: Sequence ``(bamfile, cov_thrd, win_thrd, start, end, window_len,
            window_shift, ref_id, right_offset)``. ``start``/``end`` delimit
            the region to scan; ``end`` may be None, in which case the scan
            runs up to the length of the reference.

    Returns:
        Comma-separated string of ``ref_id:left-right`` intervals (empty string
        when no interval is found).

    Raises:
        AssertionError: If ``ref_id`` is not a reference of the BAM file.
        ValueError: If ``window_len // window_shift`` is not a positive step.
    """

    bamfile, cov_thrd, win_thrd, start, end, window_len, window_shift, ref_id, right_offset = args

    def left_limit(bamfile, ref_len, cov_thrd, start, min_coverage_window, window_len, shift, ref_idx=0):
        """Return the start of the first window (at or after `start`) with
        enough coverage, or `ref_len` if the scan runs past the limit."""
        cov_window = 0
        while cov_window < cov_thrd:
            if start > ref_len:
                return ref_len
            end = start + window_len
            # Count reads overlapping the window by more than
            # min_coverage_window bases.
            # count: start, stop denote 0-based, right half-open interval, i.e. [start, stop)
            cov_window = bamfile.count(bamfile.references[ref_idx], start=start, stop=end,
                                       read_callback=lambda read: read.get_overlap(start, end) > min_coverage_window)
            start += shift
        return start - shift

    def right_limit(bamfile, ref_len, cov_thrd, start, min_coverage_window, window_len, shift, ref_idx=0):
        """Return the end of the last consecutive window, starting from
        `start`, that still has enough coverage."""
        cov_window = cov_thrd
        # Initialised so the early return below can never hit an unbound name
        end = start + window_len
        while cov_window >= cov_thrd:
            if start > ref_len:
                return end
            end = start + window_len
            # Count reads overlapping the window by more than
            # min_coverage_window bases.
            cov_window = bamfile.count(bamfile.references[ref_idx], start=start, stop=end,
                                       read_callback=lambda read: read.get_overlap(start, end) > min_coverage_window)
            start += shift
        # The last window counted failed the threshold: step back to the
        # end of the previous one.
        return end - shift

    # Threshold on the overlap btw each read and the window. float() guards
    # against the threshold arriving as a string from the command line.
    min_coverage_window = int(float(win_thrd) * window_len)

    # Load bam file; the context manager closes it on every exit path.
    with pysam.AlignmentFile(bamfile) as aln_reads:
        # Validate the reference first, so that an unknown name is reported
        # by this check rather than by the length lookup below.
        ref_idx = aln_reads.get_tid(ref_id)
        assert aln_reads.is_valid_tid(
            ref_idx), "unknown reference {}".format(ref_id)

        if end is not None:
            region_len = end - start + 1
            # The scan compares absolute positions against its limit, so the
            # limit has to be the region end and not the region length.
            scan_limit = end
        else:
            region_len = scan_limit = aln_reads.get_reference_length(ref_id)

        # If coverage threshold is set to zero, then return the full region
        if cov_thrd == 0:
            # NOTE(review): with an explicit region this reports 1..length of
            # the region rather than start..end; kept as in the original
            # implementation - confirm which is intended.
            return "{}:{}-{}".format(ref_id, 1, region_len)

        if window_shift <= 0 or window_len < window_shift:
            # A zero step would make the scans below loop forever.
            raise ValueError(
                "window shift ({}) must be positive and not larger than the "
                "window length ({})".format(window_shift, window_len))
        step = window_len // window_shift

        left = []
        right = []
        while start < scan_limit:
            start = left_limit(aln_reads, scan_limit, cov_thrd, start, min_coverage_window,
                               window_len, step * 4, ref_idx)
            if start < scan_limit:
                left.append(start)
                stop = right_limit(aln_reads, scan_limit, cov_thrd, start,
                                   min_coverage_window, window_len, step, ref_idx)
                right.append(stop)
                # Resume the search beyond the start of the last covered window
                start = stop - window_len + step * 4

    # NOTE: ShoRAH (shotgun mode) adds and subtracts 3 * (window-length /
    #       window-shift) to the limits of the target region, respectively.
    #       This is done to ensure that every locus in the target region is
    #       covered by 3 window. The window-shift is set, by default, to 3.
    #       Meaning that, overlapping windows are shifted by window-length /
    #       3. Below, we account for such offset.
    left = np.array(left, dtype=int)
    right = np.array(right, dtype=int)

    # Add offset to starting position.
    left += window_len

    # Subtract offset to ending position.
    if right_offset:
        # Often the last position with high coverage is not covered by any
        # of the windows
        right -= int(window_shift * 2)
    else:
        right -= window_len

    # Check that after offsets the interval remains valid, i.e., the right end
    # is not smaller than the left end
    valid = (right - left) >= 0

    return ','.join("{}:{}-{}".format(ref_id, x, y)
                    for x, y in zip(left[valid], right[valid]))


def nonzero_intervals(x, offset, start=None, right_offset=False, window_shift=67):
    """Locate runs of non-zero values in `x` and shrink them by an offset.

    Args:
        x: 1-D sequence or array of per-locus values (e.g. coverage).
        offset: Amount added to the start of every run and, unless
            `right_offset` is set, subtracted from its end.
        start: Optional position of the first element of `x`; when given it
            is added to all boundaries.
        right_offset: If True, subtract ``int(window_shift * 2)`` from the run
            ends instead of `offset`.
        window_shift: Shift used for the right-end adjustment (see above).

    Returns:
        Integer array of shape (n, 2) with one ``[start, end)`` row per run
        that remains valid (end >= start) after the adjustment. Empty input
        yields an empty (0, 2) array, so callers can always iterate the result.
    """

    x = np.asarray(x)
    if x.size == 0:
        # Same type as the regular result, rather than the integer 0
        return np.empty((0, 2), dtype=int)

    is_zero = x == 0

    # Indices where the signal switches between zero and non-zero
    edges, = np.nonzero(np.diff(is_zero))
    boundaries = [edges + 1]

    # A run touching either end of `x` has no switch there; add it explicitly
    if not is_zero[0]:
        boundaries.insert(0, [0])
    if not is_zero[-1]:
        boundaries.append([x.size])
    boundaries = np.concatenate(boundaries)

    if start is not None:
        boundaries = boundaries + start

    # NOTE: ShoRAH (shotgun mode) adds and subtracts 3 * (window-length /
    #       window-shift) to the limits of the target region, respectively.
    #       This is done to ensure that every locus in the target region is
    #       covered by 3 windows. The window-shift is set, by default, to 3.
    #       Meaning that, overlapping windows are shifted by window-length /
    #       3. In this setting the offset is equivalent to the window length.
    # Add offset to starting position.
    intervals_start = boundaries[::2] + offset

    # Subtract offset to ending position.
    if right_offset:
        # Often the last position with high coverage is not covered by any
        # of the windows
        intervals_end = boundaries[1::2] - int(window_shift * 2)
    else:
        intervals_end = boundaries[1::2] - offset

    # Check that after offsets the interval remains valid, i.e., the right end
    # is not smaller than the left end
    valid = (intervals_end - intervals_start) >= 0

    return np.vstack((intervals_start[valid], intervals_end[valid])).T


def main():
    """Extract coverage intervals for every input BAM file and write them out.

    Two modes are supported:
      * with ``-cf``: intervals are derived from a precomputed per-locus
        coverage table (one column per sample);
      * otherwise: intervals are computed from the BAM files directly, one
        worker task per sample.

    The output file contains one line per sample: identifier, a tab, and the
    comma-separated intervals.
    """
    args = parse_args()

    # Get name of the reference
    reference_name = None
    if args.region is not None:
        aux = args.region.split(":")
        reference_name = aux[0]
        aux = aux[1].split('-')
        start = int(aux[0])
        end = int(aux[1])
    else:
        start = 0
        end = None

    num_samples = len(args.input_files)
    window_len = args.window_len.split(",")
    assert len(
        window_len) == num_samples, 'Number of input values do not match number of input files.'

    window_shift = args.window_shift.split(",")
    assert len(
        window_shift) == num_samples, 'Number of input values do not match number of input files.'

    if args.coverage_file is not None:
        # Load input file. ndmin=2 keeps the table two-dimensional even when
        # it contains a single locus (row).
        coverage = np.loadtxt(args.coverage_file, dtype=int,
                              delimiter='\t', skiprows=1, ndmin=2)
        loci = coverage[:, 0]
        coverage = coverage[:, 1:]
        assert coverage.shape[1] == num_samples, 'Number of columns in the coverage file do not match number of input files.'

        if args.patientIDs is None:
            # Read patient identifiers from the input file
            with open(args.coverage_file, 'r') as fin:
                first_line = fin.readline()

            patientIDs = first_line.rstrip().split()
            # Comments are output with a hashtag. Remove hashtag and identifier of first
            # column, which corresponds to positions
            patientIDs = patientIDs[2:]
            if len(patientIDs) != num_samples:
                # Header not usable: fall back to numeric identifiers
                patientIDs = np.arange(num_samples)
        else:
            patientIDs = args.patientIDs.split(",")

        assert len(
            patientIDs) == num_samples, 'Number of patient/sample identifiers do not match number of input files.'

        reference_len = None
        with open(args.outfile, "wt") as outfile:
            for idx in range(num_samples):
                # The BAM file is consulted only for information that is still
                # missing: the reference name (no region given) and/or the
                # reference length (needed when the threshold is zero). Once
                # resolved, the values are reused for the remaining samples.
                if reference_name is None or (args.min_coverage == 0 and reference_len is None):
                    with pysam.AlignmentFile(args.input_files[idx]) as aln_reads:
                        if reference_name is None:
                            reference_name = aln_reads.references[0]
                        reference_len = aln_reads.get_reference_length(
                            reference_name)

                # Identify samples with coverage below threshold and discard those read
                # counts
                if args.min_coverage == 0:
                    outfile.write("{}\t{}\n".format(patientIDs[idx],
                                                    "{}:{}-{}".format(reference_name, 1, reference_len)))
                else:
                    mask = coverage[:, idx] < args.min_coverage
                    coverage[mask, idx] = 0
                    intervals = nonzero_intervals(
                        coverage[:, idx], int(window_len[idx]), loci[0], right_offset=args.right_offset, window_shift=int(window_shift[idx]))

                    outfile.write("{}\t{}\n".format(patientIDs[idx],
                                                    ','.join("{}:{}-{}".format(reference_name, x[0], x[1]) for x in intervals)))
    else:
        if args.patientIDs is None:
            patientIDs = np.arange(num_samples)
        else:
            patientIDs = args.patientIDs.split(",")

        assert len(
            patientIDs) == num_samples, 'Number of patient/sample identifiers do not match number of input files.'

        args_list = [(args.input_files[idx], args.min_coverage, args.window_overlap, start, end, int(
            window_len[idx]), int(window_shift[idx]), reference_name, args.right_offset) for idx in range(num_samples)]

        pool = Pool(processes=args.thrds)
        try:
            res = pool.map(get_intervals_wrapper, args_list)
        finally:
            # Release the worker processes even if a task raised
            pool.close()
            pool.join()

        with open(args.outfile, "wt") as outfile:
            for idx in range(num_samples):
                outfile.write("{}\t{}\n".format(patientIDs[idx], res[idx]))


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
