#!/usr/bin/env python3

import sys
from collections import namedtuple

import argparse
import pysam

from smallgenomeutilities.__mapper_impl__ import convert_from_intervals_to_list, find_interval_on_dest

# Module metadata (authorship, licensing and maintenance contact).
# These values are informational only; nothing in this script reads them.
__author__ = "David Seifert"
__copyright__ = "Copyright 2017"
__credits__ = "David Seifert"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"

# Command-line interface. Every option is mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-t", dest="TO", help="Name of target contig, e.g. HXB2:2253-2256", metavar="dest", required=True)
parser.add_argument("-i", dest="INPUT",
                    help="Input SAM/BAM file", metavar="input", required=True)
parser.add_argument("-o", dest="OUTPUT", help="Output TSV file",
                    metavar="output", required=True)
parser.add_argument("-m", dest="MSA", help="MSA of contigs",
                    metavar="msa_file", required=True)
parser.add_argument("--select", dest="SELECT_CONTIG", help="Name of contig that is of interest", metavar="contig",
                    required=True)
args = parser.parse_args()

# The target is given as "<contig>:<coordinates>". Split on the first ':'
# only, so that anything after it is passed on untouched as the coordinates.
TO_CONTIG, _separator, COORDINATES = args.TO.partition(':')
if not _separator:
    # Without this check a target lacking ':' surfaced as a bare IndexError;
    # parser.error() prints the usage line plus this message and exits.
    parser.error(
        "argument -t: expected '<contig>:<coordinates>' (e.g. HXB2:2253-2256), got '{}'".format(args.TO))
MSA_FILE = args.MSA

INPUT = args.INPUT
OUTPUT = args.OUTPUT
SELECT_CONTIG = args.SELECT_CONTIG

# 1. load SAM/BAM
samfile = pysam.AlignmentFile(INPUT, "rb")

# Validate the contig of interest up front, so that a mistyped --select fails
# immediately instead of after the per-contig mapping work below.
if SELECT_CONTIG not in samfile.references:
    sys.exit("{} is not a valid contig in {}".format(SELECT_CONTIG, INPUT))

# 2. Create source -> dest map
# contig_loci_map: contig name -> {locus on that contig: coverage counter}
# contig_min_max:  contig name -> (min, max) window handed to samfile.fetch()
contig_loci_map = {}
contig_min_max = {}
contig_start_end = namedtuple("contig_start_end", "min, max")
for contig in samfile.references:
    result = convert_from_intervals_to_list(find_interval_on_dest(
        contig, TO_CONTIG, COORDINATES, 0, MSA_FILE, False))
    if len(result) == 0:
        # An empty result would otherwise raise an IndexError on result[0].
        sys.exit("Contig '{}' has no loci corresponding to '{}:{}' in MSA file '{}'".format(
            contig, TO_CONTIG, COORDINATES, MSA_FILE))
    contig_loci_map[str(contig)] = dict.fromkeys(result, 0)
    # NOTE(review): first/last element are used as the window bounds, which
    # presumably relies on `result` being sorted ascending -- confirm against
    # convert_from_intervals_to_list. The +1 makes the upper bound exclusive.
    contig_min_max[str(contig)] = contig_start_end(
        min=result[0], max=result[-1] + 1)

# 3. iterate over reads, adding to coverage
# For every contig, fetch the reads overlapping its window and increment the
# counter of each tracked locus that a read spans on the reference.
AlignedReads = 0
UnalignedReads = 0

for contig in contig_loci_map:
    window = contig_min_max[contig]
    for read in samfile.fetch(reference=contig, start=window.min, end=window.max):
        if read.is_unmapped:
            UnalignedReads += 1
            continue

        AlignedReads += 1

        if read.reference_name not in contig_loci_map:
            sys.exit("Read '{}' has contig '{}', which is not in MSA file '{}'".format(read.query_name,
                                                                                       read.reference_name,
                                                                                       MSA_FILE))

        # pysam reports reference_end as None when a record carries no CIGAR
        # alignment; such a record spans no reference position, and passing
        # None to range() would raise a TypeError.
        if read.reference_end is None:
            continue

        # Look the per-contig counters up once per read, not once per position.
        loci = contig_loci_map[read.reference_name]
        for i in range(read.reference_start, read.reference_end):
            if i in loci:
                loci[i] += 1

# All reads have been consumed; release the underlying file handle.
samfile.close()

# 4. Tally up final stats
# Mean coverage over the tracked loci of each contig, truncated to an integer.
FinalStats = {
    name: int(sum(depths.values()) / len(depths))
    for name, depths in contig_loci_map.items()
}

# 5. Print statistics
# Two-line TSV: a header row followed by one data row. The selected contig
# fills the "Name"/"Target" columns; every other contig gets its own column,
# in sorted name order.
with open(OUTPUT, "w") as output:
    others = [name for name in sorted(contig_loci_map)
              if name != SELECT_CONTIG]

    # header
    output.write("\t".join(["Name", "Target"] + others) + "\n")

    # data
    row = [SELECT_CONTIG, FinalStats[SELECT_CONTIG]]
    row += [FinalStats[name] for name in others]
    output.write("\t".join(str(field) for field in row) + "\n")
