#!/usr/bin/env python3

import sys

import Bio.SeqIO
import argparse

__author__ = "David Seifert"
__copyright__ = "Copyright 2016-2017"
__credits__ = "David Seifert"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"

# Command-line interface. The numeric options carry an explicit ``type`` so
# that a malformed value is rejected by argparse with a usage message instead
# of surfacing later as an uncaught ValueError.
parser = argparse.ArgumentParser()
parser.add_argument("--it", dest="INPUT_TRANS",
                    help="Input file from QuasiRecomb, usually 'quasispecies.fasta' for transmitter",
                    metavar="source", required=True)
parser.add_argument("--ir", dest="INPUT_RECIPIENT",
                    help="Input file from QuasiRecomb, usually 'quasispecies.fasta' for recipient",
                    metavar="source", required=True)
parser.add_argument("-o", dest="OUTPUT", help="Name of output file to write sequences to", metavar="dest",
                    required=True)
parser.add_argument("--prefix", dest="PREFIX",
                    help="Prefix to use in FASTA header (retains source prefixes if none provided)", default="")
parser.add_argument("--mf", dest="MIN_FREQ", type=float,
                    help="Minimum frequency required for keeping sequence", default=0)
parser.add_argument("-L", dest="LENGTH", type=int,
                    help="Length of sequences have to be EXACTLY L", default=0)
parser.add_argument("-p", dest="PROTEIN", help="Translate sequences into protein sequences", default=False,
                    action='store_true')
args = parser.parse_args()

# Paths of the two FASTA inputs (transmitter and recipient).
INPUT_TRANS = args.INPUT_TRANS
INPUT_RECIPIENT = args.INPUT_RECIPIENT

OUTPUT_FILE = args.OUTPUT
PREFIX = args.PREFIX
# The defaults are the int 0, which argparse does not pass through ``type``,
# so normalise explicitly. A LENGTH of 0 disables the exact-length filter.
MIN_FREQ = float(args.MIN_FREQ)
LENGTH = int(args.LENGTH)

TRANSLATE_INTO_PROTEIN = args.PROTEIN

# first round
# Codon start offsets (0, 3, 6, ...) at which at least one retained sequence
# contains a gap character. Filled by load_initial() and consulted by
# curate_seqs() to drop those codons from every sequence.
invalid_loci = []


def load_initial(filename):
    """Read a FASTA file and pool the frequencies of identical sequences.

    Each record id is split at its first underscore: the part before it is
    kept as the identifier and the part after it is parsed as a float
    frequency.

    A record is skipped when any of the following holds:
      * ``LENGTH`` is positive and the sequence length differs from it,
      * the sequence is empty,
      * more than 10% of its characters are gaps (``-``),
      * its frequency is below ``MIN_FREQ``.

    For every retained record, the start offset of each codon that contains
    a gap is appended to the module-level ``invalid_loci`` list.

    Args:
        filename: FASTA source handed to ``Bio.SeqIO.parse``.

    Returns:
        dict mapping each upper-cased sequence string to a two-item list
        ``[summed frequency, identifier of the first record seen]``.
    """
    first_round_sequences = {}
    for record in Bio.SeqIO.parse(filename, "fasta"):
        # Split the id once instead of once per field.
        id_parts = record.id.split('_', 1)
        identifier = str(id_parts[0])
        freq = float(id_parts[1])
        new_seq = str(record.seq.upper())

        if LENGTH > 0 and len(new_seq) != LENGTH:
            continue

        if len(new_seq) == 289:
            # Vpr hack, remove frameshift (drops the character at index 212)
            new_seq = new_seq[:212] + new_seq[213:]

        # An empty record would divide by zero in the gap-fraction check below.
        if not new_seq:
            continue

        if new_seq.count('-') / len(new_seq) > 0.10 or freq < MIN_FREQ:
            continue

        if new_seq in first_round_sequences:
            first_round_sequences[new_seq][0] += freq
        else:
            first_round_sequences[new_seq] = [freq, identifier]

        # Remember every codon of this sequence that contains a gap.
        invalid_loci.extend(
            i for i in range(0, len(new_seq), 3) if '-' in new_seq[i:i + 3])

    return first_round_sequences


# Load both inputs. Each call also appends gap-containing codon offsets to
# the shared module-level invalid_loci list.
seqs_trans = load_initial(INPUT_TRANS)
seqs_recip = load_initial(INPUT_RECIPIENT)

# Deduplicate and sort the offsets collected from both inputs.
invalid_loci = sorted(set(invalid_loci))


# second round
def curate_seqs(seqs):
    """Strip gap-affected codons from every sequence and re-pool duplicates.

    Every codon whose start offset is listed in the module-level
    ``invalid_loci`` is removed. When ``TRANSLATE_INTO_PROTEIN`` is set, the
    remaining nucleotides are translated with ``Bio.Seq.translate``;
    sequences that cannot be translated are dropped. Sequences that become
    identical after this step have their frequencies summed, and the
    identifier of the first one encountered is kept.

    Args:
        seqs: dict mapping sequence string to ``[frequency, identifier]``,
            as returned by ``load_initial``. It is not modified.

    Returns:
        dict with the same value layout, keyed by the curated sequence.
    """
    # Build the lookup once; membership tests on a list are linear.
    skipped_loci = set(invalid_loci)

    final_sequences = {}
    for seq, details in seqs.items():
        new_seq = "".join(
            seq[i:i + 3] for i in range(0, len(seq), 3)
            if i not in skipped_loci)

        if TRANSLATE_INTO_PROTEIN:
            try:
                new_seq = Bio.Seq.translate(new_seq, gap='-')
            except ValueError:
                # Biopython signals untranslatable codons with
                # TranslationError, a ValueError subclass; skip those
                # sequences but let unrelated errors surface.
                continue

        if new_seq in final_sequences:
            final_sequences[new_seq][0] += details[0]
        else:
            # Copy so that later pooling does not mutate the caller's lists.
            final_sequences[new_seq] = list(details)
    return final_sequences


# Remove the gap-containing codons from each sample (and translate to
# protein when TRANSLATE_INTO_PROTEIN is set).
final_trans = curate_seqs(seqs_trans)
final_recip = curate_seqs(seqs_recip)

# finally, write file
def write_to_file(seqs, file):
    """Write sequences as FASTA records, most frequent first.

    Each record is written as ``>{header}_{int(10000 * frequency)}`` followed
    by the sequence on the next line. The header is the stored identifier
    when ``PREFIX`` is empty, otherwise ``PREFIX`` followed by a running
    index over the records written by this call. Entries whose frequency is
    below ``MIN_FREQ`` are skipped.

    Args:
        seqs: dict mapping sequence string to ``[frequency, identifier]``.
        file: writable text handle; it is neither opened nor closed here.
    """
    # Sort on the pooled frequency. Sorting the dict itself would hand the
    # key function the sequence strings rather than their details.
    by_frequency = sorted(
        seqs.items(), key=lambda item: item[1][0], reverse=True)

    # NOTE(review): the index restarts at 0 on every call, so with --prefix
    # the transmitter and recipient records reuse the same PREFIX<i> names.
    i = 0
    for sequence, details in by_frequency:
        if details[0] < MIN_FREQ:
            continue

        if PREFIX == "":
            header = details[1]
        else:
            header = PREFIX + str(i)

        file.write(">{}_{}\n{}\n".format(
            header, int(10000 * details[0]), sequence))
        i += 1


# Open the output exactly once; re-opening it in "wt" mode while it is being
# written would truncate it. The context manager closes it even on error.
with open(OUTPUT_FILE, "wt") as out_file:
    write_to_file(final_trans, out_file)
    write_to_file(final_recip, out_file)
