#!/usr/bin/env python3

import os
import sys
from collections import namedtuple

import Bio.Seq
import argparse
import pysam

from smallgenomeutilities.__mapper_impl__ import convert_from_intervals_to_list, find_interval_on_dest

# Module metadata: authorship, licence, maintainer contact and release status.
__author__ = "David Seifert"
__copyright__ = "Copyright 2016-2017"
__credits__ = "David Seifert"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"

# Command-line interface.
#
# Numeric options declare their type so that argparse rejects malformed values
# with a usage message at parse time, instead of the script failing later with
# an uncaught ValueError when the value is converted.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-t", dest="TO", help="Name of target contig, e.g. HXB2:2253-2256", metavar="dest", required=True)
# Disabled option, kept for reference:
# parser.add_argument("-1", dest="ONE_BASED", help="Whether coordinates should be treated 1-based", default=False, action='store_true')
parser.add_argument("-v", dest="VERBOSE", help="Print more information (such as subsequences in references)",
                    default=False, action='store_true')
parser.add_argument("-i", dest="INPUT",
                    help="Input SAM/BAM file", metavar="input", required=True)
# An empty OUTPUT means "derive the name from the input file".
parser.add_argument("-o", dest="OUTPUT",
                    help="Output FASTA file", default="", metavar="output")
parser.add_argument("--mf", dest="MIN_FREQ", type=float,
                    help="Minimum frequency to output", default=0, metavar="FLOAT")
parser.add_argument("--mc", dest="MIN_COUNT", type=int,
                    help="Minimum count to output", default=0, metavar="INT")
parser.add_argument("--prefix", dest="PREFIX",
                    help="Prefix to add to header", default="seq", metavar="STR")
parser.add_argument("--rg", dest="MIN_GAP_FREQ", type=float,
                    help="Minimum frequency for gap containing sequences", default=0,
                    metavar="FLOAT")
parser.add_argument("--rog", dest="REMOVEONLYGAPS", help="Remove sequences consisting only of gaps and stop codons",
                    default=False, action='store_true')
parser.add_argument("-p", dest="PROTEINS", help="Output sequences as translated proteins", default=False,
                    action='store_true')
parser.add_argument("-T", dest="TRAIT", help="Output sequences in trait format (for SeTesT)", default=False,
                    action='store_true')
# nargs=1 makes args.FILES a one-element list.
parser.add_argument("FILES", nargs=1, metavar="MSA_file",
                    help="file containing MSA")
args = parser.parse_args()

# Derive the run settings from the parsed command line.

# "-t" has the form "<contig>:<coordinates>"; split at the first colon.
# A value without a colon is reported as a usage error rather than surfacing
# as a bare IndexError.
TO_CONTIG, _to_separator, COORDINATES = args.TO.partition(':')
if not _to_separator:
    parser.error(
        "-t must have the form CONTIG:COORDINATES, e.g. HXB2:2253-2256")
# TO_CONTIG = FROM_CONTIG if args.TO == "" else args.TO
VERBOSE = args.VERBOSE
MINIMUM_FREQ = float(args.MIN_FREQ)
MINIMUM_COUNT = int(args.MIN_COUNT)
MSA_FILE = args.FILES[0]
PROTEINS = args.PROTEINS
PREFIX = args.PREFIX
REMOVEONLYGAPS = args.REMOVEONLYGAPS
# OFFSET = -1 if args.ONE_BASED else 0
TRAIT_OUTPUT = args.TRAIT
MIN_GAP_FREQ = float(args.MIN_GAP_FREQ)

INPUT = args.INPUT
input_filestem, input_extension = os.path.splitext(INPUT)
# Without -o, write next to the input, swapping its extension for ".fasta".
OUTPUT = (input_filestem + ".fasta") if args.OUTPUT == "" else args.OUTPUT

if INPUT == OUTPUT:
    # Refuse to overwrite the input. sys.exit() with a string writes it to
    # stderr and exits with status 1, so callers can detect the failure.
    sys.exit("Input and output filenames cannot be identical!")

# Open the alignment file given with -i.
samfile = pysam.AlignmentFile(INPUT, "rb")

# Per-contig record: the list of loci on that contig plus its first and last
# entry.
contig_tuple = namedtuple("contig_map", "map locus_start locus_end")

# For every reference contig in the alignment header, look up the loci that
# correspond to the requested target coordinates.
# NOTE(review): presumably the helpers project COORDINATES on TO_CONTIG onto
# `contig` through the MSA, and the literal 0 is a coordinate offset - confirm
# against smallgenomeutilities.__mapper_impl__.
contig_loci_map = {}
for contig in samfile.references:
    intervals = find_interval_on_dest(
        contig, TO_CONTIG, COORDINATES, 0, MSA_FILE, VERBOSE)
    result = convert_from_intervals_to_list(intervals)
    first_locus, last_locus = result[0], result[-1]
    contig_loci_map[contig] = contig_tuple(result, first_locus, last_locus)

# Group alignment records by read name. Each value is the list of records
# sharing that name; its first two entries are kept ordered by ascending
# reference start position.
reads = {}

for record in samfile.fetch():
    read_id = str(record.query_name)
    alignments = reads.get(read_id)

    if alignments is None:
        # first record seen under this name
        reads[read_id] = [record]
        continue

    alignments.append(record)
    if alignments[1].reference_start < alignments[0].reference_start:
        alignments[0], alignments[1] = alignments[1], alignments[0]

# print(reads[0:5])
# sys.exit(0)

def seq_builder(have_already_loci, subsequence, read, contig_details):
    """Extend a partial subsequence with the target loci covered by one read.

    The read's CIGAR is walked from ``read.reference_start``. Every reference
    position that is listed in ``contig_details.map`` and not yet recorded in
    ``have_already_loci`` contributes one character: the read base for
    aligned positions (M, =, X) or ``'-'`` for deleted positions (D).

    A read is only used when both ``contig_details.locus_start`` and
    ``contig_details.locus_end`` lie within
    ``[read.reference_start, read.reference_end)``; otherwise the inputs are
    returned unchanged.

    Args:
        have_already_loci: list of reference positions already collected
            (e.g. from the first mate). It is extended in place.
        subsequence: characters collected so far.
        read: alignment record exposing ``reference_start``,
            ``reference_end``, ``cigartuples`` and ``query_sequence``.
        contig_details: object exposing ``map`` (the loci of interest),
            ``locus_start`` and ``locus_end``.

    Returns:
        Tuple ``(sequence, loci)``: the extended subsequence and the same
        list object that was passed as ``have_already_loci``.

    Raises:
        ValueError: if the CIGAR contains an operation code outside 0-8.
    """
    seq = subsequence
    found_loci = have_already_loci

    # pysam documents None for these attributes on records that carry no
    # alignment; such a record cannot contribute anything.
    if read.reference_end is None or read.cigartuples is None:
        return seq, found_loci

    covered = range(read.reference_start, read.reference_end)
    if contig_details.locus_start not in covered or contig_details.locus_end not in covered:
        return seq, found_loci

    # Sets give constant-time membership tests inside the per-base loops;
    # found_loci itself stays a list so callers get the same object back.
    wanted_loci = set(contig_details.map)
    seen_loci = set(found_loci)
    collected = []

    pos_on_read = 0
    pos_on_genome = read.reference_start

    for op, op_len in read.cigartuples:
        # M, =, X: consumes both read and reference
        if op in (0, 7, 8):
            for offset in range(op_len):
                locus = pos_on_genome + offset
                if locus in wanted_loci and locus not in seen_loci:
                    seen_loci.add(locus)
                    found_loci.append(locus)
                    collected.append(read.query_sequence[pos_on_read + offset])

            pos_on_read += op_len
            pos_on_genome += op_len

        # I, S: consumes the read only
        elif op in (1, 4):
            pos_on_read += op_len

        # D: consumes the reference only; wanted loci are reported as gaps
        elif op == 2:
            for offset in range(op_len):
                locus = pos_on_genome + offset
                if locus in wanted_loci and locus not in seen_loci:
                    seen_loci.add(locus)
                    found_loci.append(locus)
                    collected.append('-')

            pos_on_genome += op_len

        # N: skipped reference region, nothing is emitted for it
        elif op == 3:
            pos_on_genome += op_len

        # H, P: consume neither read nor reference
        elif op in (5, 6):
            pass

        else:
            raise ValueError("Unrecognized CIGAR operation {}".format(op))

    return seq + "".join(collected), found_loci


# Count how often each full-length subsequence occurs.
#
# total                  - number of read groups that covered every locus
# extracted_subsequences - subsequence -> number of read groups yielding it
total = 0
extracted_subsequences = {}
for mates in reads.values():
    # the loci are looked up for the contig of the first record in the group
    locus_info = contig_loci_map[mates[0].reference_name]
    wanted_length = len(locus_info.map)

    subsequence, loci_covered = seq_builder([], "", mates[0], locus_info)

    if len(mates) > 1 and len(subsequence) < wanted_length:
        # first record left loci uncovered: let the second one fill them in
        subsequence, loci_covered = seq_builder(
            loci_covered, subsequence, mates[1], locus_info)

    if len(subsequence) != wanted_length:
        # some loci are still missing, so this read group is not counted
        continue

    total += 1
    extracted_subsequences[subsequence] = extracted_subsequences.get(
        subsequence, 0) + 1

# print(extracted_subsequences)

# Build the final table of sequences to report.
#
# final_seq_dict - reported sequence (DNA or peptide) -> read count
# retained_total - number of reads represented in final_seq_dict
# sorted_seq     - (sequence, count) pairs, most abundant first
retained_total = 0
final_seq_dict = {}
if PROTEINS:
    # amino acid sequence
    # Visit the nucleotide sequences most-abundant-first, so that peptides
    # enter final_seq_dict in that order and ties in the sort below keep it.
    by_abundance = sorted(
        extracted_subsequences, key=extracted_subsequences.get, reverse=True)
    for sequence in by_abundance:
        count = extracted_subsequences[sequence]
        try:
            peptide = Bio.Seq.translate(sequence, gap='-')
        except ValueError:
            # Biopython reports untranslatable codons with TranslationError,
            # a ValueError subclass; such sequences are dropped. Anything
            # else (e.g. KeyboardInterrupt) is no longer swallowed.
            continue

        # several nucleotide sequences may encode the same peptide
        final_seq_dict[peptide] = final_seq_dict.get(peptide, 0) + count
        retained_total += count
else:
    # DNA sequences
    final_seq_dict = extracted_subsequences
    retained_total = total

# sorted() is stable, so equal counts keep their insertion order
sorted_seq = sorted(
    final_seq_dict.items(), key=lambda item: item[1], reverse=True)

# Write the retained sequences and print a summary.
#
# i    - running index appended to PREFIX in each record name
# used - number of reads represented by the records actually written
i = 0
used = 0
# The context manager closes the file even if writing fails part-way.
with open(OUTPUT, "wt") as out_file:
    for sequence, count in sorted_seq:
        frequency = count / retained_total

        # gap-containing sequences have their own frequency threshold
        if '-' in sequence and frequency < MIN_GAP_FREQ:
            continue

        if REMOVEONLYGAPS:
            # drop sequences that are all '-' or all '*'
            if sequence.count('-') == len(sequence) or sequence.count('*') == len(sequence):
                continue

        if frequency >= MINIMUM_FREQ and count >= MINIMUM_COUNT:
            if TRAIT_OUTPUT:
                # tab-separated: name, sequence, count
                out_file.write("{}{}\t{}\t{}\n".format(PREFIX, i, sequence, count))
            else:
                # FASTA: the count is part of the header
                out_file.write(">{}{}_{}\n{}\n".format(PREFIX, i, count, sequence))
            i += 1
            used += count

print("Total reads:    {}".format(total))
# With no extracted reads there is nothing to divide by; report 0.0% rather
# than silently dropping the line.
retained_percent = used / total * 100 if total else 0.0
print("Retained reads: {} ({:.1f}%)".format(used, retained_percent))
