#!/usr/bin/env python3

'''
input:  alignment file(s) in BAM format and either a reference sequence in
        FASTA format or, if only one genetic region is of interest, the name
        of the reference. An index for each BAM file should exist in the same
        directory as the corresponding BAM file.
output: tab-separated file. Minority alleles on the rows and samples as columns.
        Loci are reported using 0-based indexing.
        The majority base/amino acid per locus is identified by summing
        relative abundances across samples.
'''

import pysam
import argparse
import os
import numpy as np
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from multiprocessing import Pool
from smallgenomeutilities.__checkPath__ import CheckPath
from smallgenomeutilities.__pileup__ import AlignedRead, ascii2idx, get_counts

# Module metadata: authorship, licensing and maintenance status.
__author__ = "Susana Posada Cespedes"
__copyright__ = "Copyright 2017"
__credits__ = "Susana Posada Cespedes"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    The reference (``-r``) is the only required option and is listed in its
    own help group; every other option is optional, and one or more BAM files
    are taken as positional arguments.

    Returns:
        argparse.Namespace: the parsed arguments, with attributes
        ``reference``, ``start``, ``end``, ``frames``, ``min_coverage``,
        ``patientIDs``, ``thrds``, ``freqs``, ``coverage``, ``outdir`` and
        ``FILES``.
    """
    cli = argparse.ArgumentParser(
        description="Script to extract minority alleles per samples",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required option, shown under a dedicated heading in the help output.
    required_group = cli.add_argument_group('required named arguments')
    required_group.add_argument(
        "-r",
        required=True,
        metavar='FASTA',
        dest='reference',
        help="Either a fasta file containing a reference sequence or the reference name of the region/chromosome of interest. The latter is expected if a region is specified",
    )

    # Optional switches, declared as (flags, keyword options) pairs and
    # registered in this order so that the help listing keeps its layout.
    optional_specs = (
        (("-s", "--start"),
         dict(default=None, metavar='INT', dest='start', type=int,
              help="Starting position of the region of interest, 0-based indexing")),
        (("-e", "--end"),
         dict(default=None, metavar='INT', dest='end', type=int,
              help="Ending position of the region of interest, 0-based indexing. Note a half-open interval is used, i.e, [start:end)")),
        (("-p", "--config"),
         dict(default=None, metavar='file.config', dest='frames',
              help="Report minority aminoacids - a .config file specifying reading frames expected")),
        (("-c",),
         dict(default=100, metavar='INT', dest='min_coverage', type=int,
              help="Minimum read depth for reporting variants per locus and sample")),
        (("-N",),
         dict(default=None, metavar='name1,name2,...', dest="patientIDs",
              help="Patient/sample identifiers as comma separated strings")),
        (("-t",),
         dict(default=1, metavar='INT', dest='thrds', type=int,
              help="Number of threads")),
        (("-f", "--freqs"),
         dict(action='store_true', dest='freqs',
              help="Indicates whether or not all frequencies should be stored")),
        (("-d",),
         dict(action='store_true', dest='coverage',
              help="Indicates whether coverage per locus should be written to output file")),
        (("-o",),
         dict(default=os.getcwd(), action=CheckPath, metavar='PATH', dest='outdir',
              help="Output directory")),
    )
    for flags, options in optional_specs:
        cli.add_argument(*flags, required=False, **options)

    # Positional arguments: at least one BAM file.
    cli.add_argument("FILES", nargs='+', metavar='BAM', help="BAM file(s)")

    return cli.parse_args()


def main():
    """Extract minority alleles per sample and write the results to disk.

    Workflow:
      1. Parse and validate the command-line arguments.
      2. Resolve the region to analyse: either the single sequence of the
         FASTA file given with ``-r`` or the half-open interval
         ``[start, end)`` on the reference named by ``-r``.
      3. Collect per-base read counts for every BAM file (``get_counts``,
         run in a process pool) and turn them into per-sample frequencies,
         discarding positions below the minimum coverage.
      4. For each locus covered by at least one sample, take the base with
         the largest summed frequency as the majority base and report all
         other observed bases as minority variants.

    Files written to the output directory:
      * ``minority_variants.tsv`` - minority variants (rows) per sample
        (columns); samples without sufficient coverage are reported as nan.
      * ``cohort_consensus.fasta`` - majority base per covered locus.
      * ``frequencies.npy`` - all frequencies (only with ``-f``).
      * ``coverage.tsv`` - raw coverage per locus (only with ``-d``).

    Raises:
        ValueError: if a start is given without an end, the region is empty
            or negative, the number of sample identifiers differs from the
            number of BAM files, or the FASTA file does not contain exactly
            one sequence.
        SystemExit: if the (not yet implemented) amino-acid mode is requested.
    """
    args = parse_args()

    # Explicit exceptions instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let invalid input through.
    if args.start is not None and args.end is None:
        raise ValueError(
            'Minority variants are extracted from a region of interest. An ending position was expected')

    if args.end is not None and args.start is None:
        print('Starting position was expected. Setting it to 0')
        args.start = 0

    if args.start is not None and (args.start < 0 or args.end <= args.start):
        raise ValueError(
            'Invalid region [{}:{}): expected 0 <= start < end'.format(
                args.start, args.end))

    if args.frames is None:
        # Nucleotides - Alphabet = {A, C, G, T, -}, including deletions
        alphabet = np.array(['A', 'C', 'G', 'T', '-'])
        alphabet_len = alphabet.size
    else:
        # Amino-acid alphabet, kept for the future implementation.
        alphabet = np.array(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                             'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])
        alphabet_len = alphabet.size
        print("Option not implemented yet")
        # NOTE(review): exits with status 0 although nothing was produced -
        # confirm whether a non-zero status is wanted here.
        raise SystemExit(0)

    num_samples = len(args.FILES)

    if args.patientIDs is not None:
        patientIDs = args.patientIDs.split(",")
        if len(patientIDs) != num_samples:
            raise ValueError(
                'Number of patient/sample identifiers do not match number of BAM files')

    if args.start is None and args.end is None:
        # Whole reference: its length and name are read from the FASTA file.
        # Only one reference sequence expected
        start = None
        end = None
        reference = pysam.FastaFile(args.reference)
        try:
            if reference.nreferences != 1:
                raise ValueError('Only one reference sequence expected')
            region_len = reference.lengths[0]
            reference_name = reference.references[0]
            # Lower case marks positions not replaced by a consensus base.
            cohort_consensus = reference.fetch(
                reference=reference_name).lower()
        finally:
            # Release the file handle even if validation fails.
            reference.close()
        header = reference_name
    else:
        # Region of interest: convert the half-open interval [start, end) to
        # an inclusive end position.
        start = args.start
        end = args.end - 1
        region_len = end - start + 1
        reference_name = args.reference
        header = "{}:{}-{}".format(reference_name, start, end)
        cohort_consensus = 'n' * region_len

    # Index offset translating absolute loci into rows of the count arrays.
    offset = 0 if start is None else start

    args_list = [(bamfile, reference_name, start, end, region_len,
                  alphabet_len) for bamfile in args.FILES]
    # The context manager tears the pool down even if a worker raises.
    with Pool(processes=args.thrds) as pool:
        res = pool.map(get_counts, args_list)

    # One column per sample; rows are indexed below as
    # locus * alphabet_len + base index.
    nt_counts = np.vstack(res).T

    coverage = np.zeros(shape=(region_len, num_samples), dtype=int)
    coverage_raw = np.zeros(shape=(region_len, num_samples), dtype=int)
    nt_freqs = np.zeros(shape=(region_len * alphabet_len, num_samples))

    # A position without reads can never yield a frequency, so the effective
    # threshold is at least 1. This avoids dividing by zero (and the nan
    # frequencies that would hide real variants) when -c is 0 or negative.
    coverage_floor = max(args.min_coverage, 1)

    for i in range(region_len):
        rows = slice(i * alphabet_len, (i + 1) * alphabet_len)
        coverage_raw[i, ] = np.sum(nt_counts[rows, ], axis=0)
        # Samples below the threshold keep zero coverage and zero frequencies,
        # i.e. their read counts are discarded.
        covered = coverage_raw[i, ] >= coverage_floor
        coverage[i, covered] = coverage_raw[i, covered]
        nt_freqs[rows, covered] = nt_counts[rows, covered] / \
            coverage[i, covered]

    if start is None and end is None:
        pos = np.arange(region_len)
    else:
        pos = np.arange(start, end + 1)

    # Exclude loci for which all samples report zero counts
    loci = pos[np.sum(coverage, axis=1) > 0]
    variant_loci = np.zeros(shape=loci.size * alphabet_len).astype(bool)
    minor_variants_freqs = np.zeros(shape=(variant_loci.size, num_samples))

    # Write matrix of frequencies per locus and per sample
    if args.freqs:
        minor_variants = np.tile(alphabet, loci.size)
        # NOTE(review): positions stored in frequencies.npy are relative to
        # the region start, whereas minority_variants.tsv reports absolute
        # positions - confirm this is intended.
        loci_tile = np.repeat(loci, alphabet_len) - offset
        aux = nt_freqs[(loci_tile * alphabet_len) +
                       np.tile(np.arange(alphabet_len), loci.size)]

        # Structured record: position, variant, one frequency per sample.
        type_string = 'i8,U8,' + ','.join('f8' for _ in range(num_samples))
        out = np.zeros(loci_tile.size, dtype=type_string)
        for i_idx in range(loci_tile.size):
            out[i_idx][0] = loci_tile[i_idx]
            out[i_idx][1] = minor_variants[i_idx]
            for j_idx in range(num_samples):
                out[i_idx][2 + j_idx] = aux[i_idx, j_idx]

        np.save(os.path.join(args.outdir, 'frequencies.npy'), out)

    # Extract minority variants. The consensus is edited as a list of
    # characters and joined once, instead of rebuilding the whole string for
    # every locus.
    consensus = list(cohort_consensus)
    for i, locus in enumerate(loci):
        rel = locus - offset

        # Frequencies per base for the current locus and for all samples
        rows = slice(rel * alphabet_len, (rel + 1) * alphabet_len)
        out_rows = slice(i * alphabet_len, (i + 1) * alphabet_len)
        nt_freqs_locus = nt_freqs[rows, ]

        # Identify variants: store 'True' if the sum of nucleotide frequencies
        # across samples is larger than 0.
        variant_loci[out_rows] = np.sum(nt_freqs_locus, axis=1) > 0
        minor_variants_freqs[out_rows, ] = nt_freqs_locus

        # Samples for which the current locus doesn't reach the minimum
        # coverage are reported as nan.
        uncovered = coverage[rel, ] == 0
        minor_variants_freqs[out_rows, uncovered] = np.nan

        # Identify the majority base per locus, omitting samples for which
        # the locus is not covered
        idx_major = np.sum(nt_freqs_locus[:, ~uncovered], axis=1).argmax()

        # Store 'False' for the majority variant
        variant_loci[i * alphabet_len + idx_major] = False
        consensus[rel] = str(alphabet[idx_major])
    cohort_consensus = ''.join(consensus)

    # Exclude bases with zero counts for all samples, as well as majority bases
    minor_variants = np.tile(alphabet, loci.size)[variant_loci]
    minor_variants_freqs = minor_variants_freqs[variant_loci, ]
    variant_positions = np.repeat(loci, alphabet_len)[variant_loci]

    # Column names: user-supplied identifiers or the sample index.
    if args.patientIDs is None:
        sample_header = "\t".join(str(x) for x in range(num_samples))
    else:
        sample_header = "\t".join(patientIDs)

    # Write minority variants, one row per (position, variant) pair
    with open(os.path.join(args.outdir, 'minority_variants.tsv'), 'w') as outfile:
        outfile.write("# pos\tvariant\t" + sample_header + "\n")
        for locus, variant, freqs in zip(variant_positions, minor_variants,
                                         minor_variants_freqs):
            fields = ['{:d}'.format(locus), '{:s}'.format(variant)]
            fields.extend('{:.6e}'.format(freq) for freq in freqs)
            outfile.write('\t'.join(fields) + '\n')

    # Write to output file cohort-consensus. Consensus is built with respect to the reference
    consensus_record = SeqRecord(
        Seq(cohort_consensus), id=header, description="")
    with open(os.path.join(args.outdir, 'cohort_consensus.fasta'), 'w') as outfile:
        SeqIO.write(consensus_record, outfile, "fasta")

    # Write to output file coverage per locus (raw, i.e. before thresholding)
    if args.coverage:
        out = np.concatenate((pos[:, np.newaxis], coverage_raw), axis=1)
        coverage_header = "pos\t" + sample_header

        np.savetxt(os.path.join(args.outdir, 'coverage.tsv'), out, fmt='%d',
                   delimiter='\t', header=coverage_header)


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
