#!/usr/bin/env python3

'''
input:  alignment file as BAM file
output: consensus sequences including either the majority base or the ambiguous bases as FASTA files
'''

import collections
import pysam
import argparse
import os
import numpy as np
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from smallgenomeutilities.__checkPath__ import CheckPath
from smallgenomeutilities.__pileup__ import AlignedRead, ascii2idx, get_counts

# Module metadata (authorship, licensing and maintenance status).
__author__ = "Susana Posada Cespedes"
__copyright__ = "Copyright 2017"
__credits__ = "Susana Posada Cespedes"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"


def parse_args():
    """Set up the parsing of command-line arguments and parse them.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()`` with
        the attributes:

        * ``bamfile`` (``-i``, required): input BAM file.
        * ``reference`` (``-f``): FASTA file with the reference sequence;
          ``None`` when omitted.
        * ``region`` (``-r``): region string ``NAME:START-END``; ``None``
          when omitted.
        * ``min_coverage`` (``-c``, default 50): minimum read depth.
        * ``qual_thrd`` (``-q``, default 15): minimum phred quality score.
        * ``min_freq`` (``-a``, default 0.05): minimum frequency for an
          ambiguous nucleotide.
        * ``sampleID`` (``-N``, default ``"CONSENSUS"``): sample identifier.
        * ``outdir`` (``-o``, default: current working directory): output
          directory.
    """
    parser = argparse.ArgumentParser(
        description="Script to construct consensus sequences",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Options flagged as required are listed under their own help section.
    requiredNamed = parser.add_argument_group('required named arguments')
    requiredNamed.add_argument(
        "-i", required=True, metavar='BAM', dest='bamfile',
        help="Input BAM file")

    parser.add_argument(
        "-f", required=False, metavar='FASTA', dest='reference', type=str,
        help="Fasta file containing the reference sequence")
    parser.add_argument(
        "-r", required=False, default=None, metavar='BED', dest='region',
        type=str,
        help="Region of interest in BED format, e.g. HXB2:2253-3869. Loci are interpreted using 0-based indexing, and a half-open interval is used, i.e, [start:end)")
    parser.add_argument(
        "-c", required=False, default=50, metavar='INT', dest='min_coverage',
        type=int, help="Minimum read depth for reporting variants per locus")
    parser.add_argument(
        "-q", required=False, default=15, metavar='INT', dest='qual_thrd',
        type=int,
        help="Minimum phred quality score a base has to reach to be counted")
    parser.add_argument(
        "-a", required=False, default=0.05, metavar='FLOAT', dest='min_freq',
        type=float, help="Minimum frequency for an ambiguous nucleotide")
    parser.add_argument(
        "-N", required=False, default="CONSENSUS", type=str, metavar='STR',
        dest="sampleID", help="Patient/sample identifier")
    # CheckPath is a custom argparse action imported from smallgenomeutilities;
    # presumably it validates the given directory -- confirm in that module.
    parser.add_argument(
        "-o", required=False, default=os.getcwd(), action=CheckPath,
        metavar='PATH', dest='outdir', help="Output directory")

    return parser.parse_args()


def majority_vote(counts, cov_thrd, alphabet):
    """Select the most frequent symbol at every locus.

    counts: per-locus counts; the maximum is taken along the first axis, so
        rows are symbols (in the order of `alphabet`) and columns are loci.
    cov_thrd: kept in the signature for callers; not used by this function.
    alphabet: array of symbols, indexed by row number of `counts`.

    Returns a tuple (consensus symbols, row index of the winning symbol per
    locus). Ties go to the lowest row index, as np.argmax reports the first
    maximum.
    """
    winning_rows = np.argmax(counts, axis=0)
    return alphabet[winning_rows], winning_rows


def ambiguous_bases(counts, min_freq, majority=None):
    """
    Build a per-locus consensus using IUPAC ambiguity codes.

    counts: nucleotide counts per locus. Columns correspond to loci and rows to nucleotides A, C, G, T and deletions, in this order.
    min_freq: minimum fraction of the locus coverage a nucleotide needs in order to be included in the ambiguity code.
    majority: row index of the majority symbol per locus; must be given when `counts` has more than four rows (i.e. carries deletions).

    Returns an array with one character per locus: an IUPAC code, 'n' when no nucleotide reaches the threshold, or '-' where the deletion is both the majority symbol and above the threshold.
    """
    # Lookup table addressed by a 4-bit mask with A=8, C=4, G=2, T=1:
    # {0: 'n', 1: 'T', 2: 'G', 3: 'K', 4: 'C', 5: 'Y', 6: 'S', 7: 'B', 8: 'A',
    #  9: 'W', 10: 'R', 11: 'D', 12: 'M', 13: 'H', 14: 'V', 15: 'N'}
    iupac_table = np.array(list('nTGKCYSBAWRDMHVN'))
    bit_weights = np.array((8, 4, 2, 1))

    depth = np.sum(counts, 0)
    threshold = np.floor(min_freq * depth)
    # A threshold of zero would accept symbols that were never observed,
    # so at least one supporting read is required.
    threshold = np.where(threshold == 0, 1, threshold)

    # 1 where a symbol reaches the threshold at a locus, 0 otherwise.
    supported = (counts >= threshold).astype(int)

    # Weight the four nucleotide rows and add them up to get the table index.
    mask_value = np.sum(supported[:4] * bit_weights[:, np.newaxis], 0)
    consensus = iupac_table[mask_value]

    # Take into account deletions, when information is available
    if counts.shape[0] > 4:
        assert majority is not None
        deletion_wins = np.logical_and(majority == 4, supported[4, :] == 1)
        consensus[deletion_wins] = '-'

    return consensus


def main():
    """Build consensus sequences from a BAM file and write them as FASTA.

    Four files are written to the output directory:

    * ``ref_majority.fasta`` / ``ref_ambig.fasta``: majority-vote and
      ambiguous-base consensus from the A/C/G/T counts reported by
      ``pysam.AlignmentFile.count_coverage``.
    * ``ref_majority_dels.fasta`` / ``ref_ambig_dels.fasta``: the same, but
      from the counts returned by ``get_counts``, which carry an extra row
      for deletions.

    Loci whose total count is below ``--min_coverage`` (``-c``) are replaced
    by the lower-cased reference base when a reference FASTA is given, and by
    ``'n'`` otherwise.

    Raises:
        ValueError: if the first sequence of the reference FASTA is not named
            like the region/contig being processed.
    """
    alphabet = np.array(['A', 'C', 'G', 'T', '-'])
    args = parse_args()

    # Region strings look like NAME:START-END; the coordinates are also
    # appended to the sample identifier used in the FASTA headers.
    reference_name = None
    if args.region is not None:
        aux = args.region.split(":")
        reference_name = aux[0]
        sampleID = ':'.join((args.sampleID, aux[1]))
        aux = aux[1].split('-')
        start = int(aux[0])
        end = int(aux[1])
    else:
        start = None
        end = None
        sampleID = args.sampleID

    # 1. Load BAM file and get counts per loci
    with pysam.AlignmentFile(args.bamfile, 'rb') as alnfile:
        if reference_name is None:
            # No region requested: fall back to the first contig of the BAM.
            reference_name = alnfile.references[0]
        counts = np.array(alnfile.count_coverage(
            contig=reference_name, start=start, stop=end,
            quality_threshold=args.qual_thrd))

        # Account for deletions w.r.t reference
        # NOTE(review): get_counts is assumed to return a flat array laid out
        # column-major (one locus after another), hence order='F' -- confirm
        # against smallgenomeutilities.__pileup__.
        region_len = counts.shape[1]
        alphabet_len = alphabet.size
        counts_dels = get_counts(
            [args.bamfile, reference_name, start, end, region_len, alphabet_len])
        counts_dels = counts_dels.reshape(
            (alphabet_len, region_len), order='F')

    # 2. Build majority consensus
    # Only the majority indices of the deletion-aware counts are needed later.
    cons_majority, _ = majority_vote(counts, args.min_coverage, alphabet)
    cons_majority_dels, majority_idx = majority_vote(
        counts_dels, args.min_coverage, alphabet)

    # 3. Build consensus including ambiguous bases
    cons_ambig = ambiguous_bases(counts, args.min_freq)
    cons_ambig_dels = ambiguous_bases(counts_dels, args.min_freq, majority_idx)

    # Loci with insufficient coverage, for each of the two count tables.
    mask = np.sum(counts, 0) < args.min_coverage
    mask_dels = np.sum(counts_dels, 0) < args.min_coverage
    if args.reference is not None:
        # The context manager closes the FASTA handle once the sequence has
        # been read.
        with pysam.FastaFile(args.reference) as fasta:
            # Explicit check instead of assert, so it also runs under -O.
            if fasta.references[0] != reference_name:
                raise ValueError(
                    "the name of the genomic region and the reference differ")
            reference = fasta.fetch(
                reference=reference_name, start=start, end=end).lower()

        # Lower-case reference bases mark low-coverage positions.
        aux = [base for idx, base in enumerate(reference) if mask[idx]]
        cons_majority[mask] = aux
        cons_ambig[mask] = aux

        aux = [base for idx, base in enumerate(reference) if mask_dels[idx]]
        cons_majority_dels[mask_dels] = aux
        cons_ambig_dels[mask_dels] = aux
    else:
        cons_majority[mask] = 'n'
        cons_ambig[mask] = 'n'

        cons_majority_dels[mask_dels] = 'n'
        cons_ambig_dels[mask_dels] = 'n'

    # 4. Write to output
    outputs = (
        ("ref_majority.fasta", cons_majority, "| Majority-vote rule"),
        ("ref_ambig.fasta", cons_ambig, "| Ambiguous bases"),
        ("ref_majority_dels.fasta", cons_majority_dels, "| Majority-vote rule"),
        ("ref_ambig_dels.fasta", cons_ambig_dels, "| Ambiguous bases"),
    )
    for filename, bases, description in outputs:
        record = SeqRecord(
            Seq(''.join(bases)), id=sampleID, description=description)
        with open(os.path.join(args.outdir, filename), "w") as outfile:
            SeqIO.write(record, outfile, "fasta")


# Run the command-line tool only when executed as a script, not on import.
if __name__ == '__main__':
    main()
