#!/usr/bin/env python3

import collections
import sys
from operator import itemgetter

import Bio.SeqIO
import argparse

__author__ = "David Seifert"
__copyright__ = "Copyright 2016-2017"
__credits__ = "David Seifert"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"

# Command-line interface: one positional MSA file plus output/filter options.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-o",
    dest="OUTPUT",
    metavar="dest",
    required=True,
    help="Name of output file to write sequences to",
)
parser.add_argument(
    "-p",
    dest="MIN_COV",
    type=float,
    default=0.97,
    help="Minimum case coverage to keep locus",
)
parser.add_argument(
    "-a",
    dest="ALL_BASES",
    action='store_true',
    default=False,
    help="Use all bases, not just uppercase ones",
)
parser.add_argument(
    "-q",
    dest="QUIET",
    action='store_true',
    default=False,
    help="Do not print which loci were discarded",
)
# nargs=1 makes argparse store the positional as a one-element list.
parser.add_argument(
    "FILES",
    nargs=1,
    metavar="MSA_file",
    help="file containing MSA",
)
args = parser.parse_args()

# Unpack the parsed options into the module-level names used below.
(INPUT_FILE,) = args.FILES
OUTPUT_FILE = args.OUTPUT
MIN_COV = args.MIN_COV
ALL_BASES = args.ALL_BASES

# One alignment row: its name, the full aligned string, and the half-open
# interval [start, end) left after trimming leading/trailing '-' characters.
fasta_record = collections.namedtuple("fasta_record", "id seq start end")

seqs = []
seq_ids = []
seq_len = -1

# Set mirror of seq_ids so the duplicate-name check is O(1) per record.
seen_ids = set()

# Read records in input file. Every record must have the same length and a
# unique id; otherwise the script exits with an error message.
for record in Bio.SeqIO.parse(INPUT_FILE, "fasta"):
    sequence = str(record.seq)

    if seq_len != -1 and len(sequence) != seq_len:
        sys.exit("Sequence '{}' does not have same length (={}) as the other sequences (={})".format(
            record.id, len(sequence), seq_len))
    seq_len = len(sequence)

    if record.id in seen_ids:
        sys.exit("The file '{}' already contains a sequence called '{}'".format(
            INPUT_FILE, record.id))

    # Number of leading '-' characters == index of the first non-gap symbol.
    start = len(sequence) - len(sequence.lstrip('-'))
    # One past the last non-gap symbol. For an all-gap (or empty) sequence
    # rstrip leaves nothing, so clamp to `start` to get an empty interval
    # instead of indexing past the end of the sequence.
    end = max(start, len(sequence.rstrip('-')))

    seen_ids.add(record.id)
    seq_ids.append(record.id)
    seqs.append(fasta_record(id=str(record.id),
                             seq=sequence, start=start, end=end))

# Classify every alignment column as kept or discarded.
keep_indices = []
discard_indices = []

for column in range(seq_len):
    covered = 0         # counted symbols (gaps included) from rows spanning this column
    covered_nongap = 0  # subset of `covered` that are counted bases

    for row in seqs:
        # Rows whose trimmed interval does not reach this column are ignored.
        if not (row.start <= column < row.end):
            continue

        symbol = row.seq[column]
        if symbol == '-':
            covered += 1
        elif symbol.isupper() or (ALL_BASES and symbol.islower()):
            # Lowercase symbols are only counted when ALL_BASES is set;
            # anything else is skipped entirely.
            covered += 1
            covered_nongap += 1

    # Keep the column only if something was counted and the non-gap
    # fraction reaches the MIN_COV threshold.
    if covered != 0 and covered_nongap / covered >= MIN_COV:
        keep_indices.append(column)
    else:
        discard_indices.append(column)

# Write new fasta file containing only the kept columns of each sequence.
# The context manager closes the file even if a write fails.
with open(OUTPUT_FILE, "wt") as out_file:
    for s in seqs:
        # Plain indexing instead of itemgetter(*keep_indices): itemgetter
        # raises TypeError when called with no indices (all columns discarded).
        new_seq = ''.join(s.seq[i] for i in keep_indices)
        out_file.write(">{}\n{}\n".format(s.id, new_seq))

# Unless -q was given, report how many and which columns were dropped.
if not args.QUIET:
    summary = "The following {} positions have been removed:".format(
        len(discard_indices))
    print(summary, *discard_indices)
