#!/usr/bin/env python3

import pysam
import argparse
import numpy as np
from sklearn import manifold
import matplotlib.pyplot as plt

__author__ = "Susana Posada Cespedes"
__copyright__ = "Copyright 2017"
__credits__ = "Susana Posada Cespedes"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"


def hamming_dist(s1, s2):
    """
    Compare two equal-length sequences position by position.

    Positions where either sequence holds the symbol '*' are left out of
    the comparison altogether.

    Returns a tuple (number of mismatching positions, number of positions
    that were actually compared).
    """
    assert len(s1) == len(s2), \
        "Hamming distance is undefined for sequences differying in their lengths"
    # Keep only the columns in which neither sequence carries '*'
    compared = [(a, b) for a, b in zip(s1, s2) if a != '*' and b != '*']
    mismatches = sum(1 for a, b in compared if a != b)
    return (mismatches, len(compared))


def get_alignment_sequence(read, start, end):
    """
    Return the aligned bases of `read` restricted to the region [start:end).

    The sequence is expressed in reference coordinates: soft-clipped bases
    and insertions are excluded, while deletions w.r.t. the reference are
    included as gap symbols '-'. Insertions are returned separately.

    `read` must provide `reference_start`, `reference_end`, `cigartuples`
    and `query_sequence`. `start` and `end` are 0-based reference
    positions delimiting a half-open interval.

    Returns a tuple (alignment, insertions, insertions_len), where
    `insertions` is a list of [reference position, inserted bases] for the
    insertions located strictly between `start` and `end`, and
    `insertions_len` is the total number of bases in those insertions.
    """
    aligned_parts = []
    insertions = []
    insertions_len = 0
    idx = 0
    idx_ref = read.reference_start
    query = read.query_sequence

    # cigartuples: cigar string encoded as a list of tuples (operation,
    # length). Operations other than the ones handled below are ignored.
    for op, length in read.cigartuples:
        # Keep read bases if they correspond to an alignment match (0),
        # sequence match (7) or sequence mismatch (8)
        if op in (0, 7, 8):
            aligned_parts.append(query[idx:idx + length])
            idx += length
            idx_ref += length
        # Add gap symbol '-' for deletions (2). Do not increment the read
        # index because deletions are not reported in the query sequence
        elif op == 2:
            aligned_parts.append('-' * length)
            idx_ref += length
        # Skip read bases if they correspond to an insertion (1), but
        # record the insertion when it is located inside the region
        elif op == 1:
            if start < idx_ref < end:
                insertions_len += length
                insertions.append([idx_ref, query[idx:idx + length]])
            idx += length
        # Skip read bases if they correspond to soft-clipped bases (4)
        elif op == 4:
            idx += length

    alignment = ''.join(aligned_parts)

    # Trim the bases aligned upstream of the region
    if read.reference_start < start:
        alignment = alignment[start - read.reference_start:]
    # Trim the bases aligned downstream of the region. The number of bases
    # to keep is clamped at 0: a negative value would be interpreted as an
    # index relative to the end of the string and would return bases of a
    # read that does not overlap the region
    if read.reference_end > end:
        keep = max(len(alignment) - (read.reference_end - end), 0)
        alignment = alignment[:keep]

    return (alignment, insertions, insertions_len)


def compare_insertions(insertions_i, insertions_j, len_i, len_j):
    """
    Score the insertions of two sequences against each other.

    `insertions_i` and `insertions_j` are lists of
    [reference position, inserted bases]; `len_i` and `len_j` are the
    total insertion lengths of each sequence.

    Both returned values start at len_i + len_j. For every pair of
    insertions reported at the same reference position, the bases are
    compared over the length of the shorter insertion: the distance is
    reduced by two for each matching base and the length by the size of
    the compared stretch.

    Returns a tuple (distance, length).
    """
    total = len_i + len_j
    dist, ins_len = total, total
    for pos_i, bases_i in insertions_i:
        for pos_j, bases_j in insertions_j:
            if pos_i != pos_j:
                continue
            # Only the stretch shared by both insertions can be compared
            overlap = min(len(bases_i), len(bases_j))
            mismatches = hamming_dist(bases_i[:overlap], bases_j[:overlap])[0]
            dist -= 2 * (overlap - mismatches)
            ins_len -= overlap
    return (dist, ins_len)


def parse_args():
    """
    Build the command-line interface and parse the arguments.

    The only mandatory argument is -q, the file name prefix of the
    reconstructed haplotypes; all other arguments have defaults.

    Returns the namespace produced by argparse.
    """
    parser = argparse.ArgumentParser(
        description="Script for accuracy assesment",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    required_group = parser.add_argument_group('required named arguments')
    required_group.add_argument(
        "-q", required=True, metavar='filename', dest='haplotypes',
        help="File name prefix for reconstructed haplotypes")

    # Optional arguments, registered in the order they appear in the help
    optional_specs = [
        (("-s", "--start"),
         dict(default=2252, metavar='INT', dest='start', type=int,
              help="Starting position of the region of interest, 0-based indexing")),
        (("-e", "--end"),
         dict(default=2549, metavar='INT', dest='end', type=int,
              help="Ending position of the region of interest, 0-based indexing. Note a half-open interval is used, i.e, [start:end)")),
        (("-r",),
         dict(default=None, metavar='FASTA', dest='haplotypes_true',
              help="FASTA file containing msa for true haplotypes")),
        (("-t", "--thrd"),
         dict(default=0.02, metavar='FLOAT', dest='freq_thrd', type=float,
              help="Threshold on the haplotype frequency")),
        (("-d", "--sim"),
         dict(default=0.9, metavar='FLOAT', dest='sim_thrd', type=float,
              help="Threshold on the similarity between true haplotypes and reconstructed haplotypes")),
        (("-p", "--plot_outname"),
         dict(default='mds_plot.pdf', metavar='FILENAME', dest='plot_outname',
              help="File name for the output plot")),
        (("-o", "--outname"),
         dict(default='mapping.tsv', metavar='FILENAME', dest='outname',
              help="File name for the output file containing mapping from reconstructed haplotypes to true haplotypes")),
    ]
    for flags, options in optional_specs:
        parser.add_argument(*flags, required=False, **options)

    return parser.parse_args()


def _load_reconstructed_haplotypes(bamfile, start, end):
    """
    Parse reconstructed haplotypes from a BAM file using the same indexing
    as the reference.

    Every record is restricted to the region [start:end) and padded with
    '*' where it does not cover the region. Consecutive records sharing
    the same query name are merged into a single haplotype.

    Returns a tuple (haplotype_IDs, sequences, insertions, insertion_len)
    of lists holding one entry per haplotype.
    """
    region_length = end - start
    haplotype_IDs = []
    sequences = []
    insertions = []
    insertion_len = []
    locus_last = None

    with pysam.AlignmentFile(bamfile, mode='rb') as alnfile:
        for hap_recons in alnfile.fetch(until_eof=True):
            sequence, hap_insertions, hap_insertions_len = get_alignment_sequence(
                hap_recons, start, end)
            # Pad with '*' the part of the region that is not covered
            if len(sequence) < region_length:
                if hap_recons.reference_start > start:
                    sequence = '*' * (hap_recons.reference_start - start) + sequence
                if hap_recons.reference_end < end:
                    sequence = sequence + '*' * (end - hap_recons.reference_end)
            assert len(sequence) == region_length, "Reconstructed sequence {} is shorter than expected".format(
                hap_recons.query_name)

            # Cliques with "gaps" are reported as a single sequence in the
            # fasta and as separate sequences in the bam. Here, it is
            # assumed that the second sequence in the bam file can be
            # appended after the first.
            if haplotype_IDs and hap_recons.query_name == haplotype_IDs[-1]:
                assert locus_last <= hap_recons.reference_start, "It is assumed that second sequence in the bam file can be appended after first"
                offset = hap_recons.reference_start - start
                sequences[-1] = sequences[-1][:offset] + sequence[offset:]
                insertions[-1] = insertions[-1] + hap_insertions
                insertion_len[-1] += hap_insertions_len
            else:
                haplotype_IDs.append(hap_recons.query_name)
                sequences.append(sequence)
                insertions.append(hap_insertions)
                insertion_len.append(hap_insertions_len)
            locus_last = hap_recons.reference_end

    return (haplotype_IDs, sequences, insertions, insertion_len)


def _pairwise_distances(sequences, insertions, insertion_len):
    """
    Return the square matrix of distances between reconstructed haplotypes:
    Hamming distance over the region plus the contribution of insertions.
    """
    num_haplotypes_recons = len(sequences)
    dist = np.zeros(shape=(num_haplotypes_recons,
                           num_haplotypes_recons), dtype=int)
    for i in range(num_haplotypes_recons):
        for j in range(num_haplotypes_recons):
            dist[i, j] = hamming_dist(sequences[i], sequences[j])[0]
            if insertion_len[i] > 0 and insertion_len[j] > 0:
                # Both haplotypes have insertions: discount the inserted
                # nucleotides they share
                dist[i, j] += compare_insertions(
                    insertions[i], insertions[j], insertion_len[i], insertion_len[j])[0]
            else:
                # At most one haplotype has insertions: every inserted
                # base counts as a difference
                dist[i, j] += insertion_len[i] + insertion_len[j]
    return dist


def _parse_frequencies(fastafile, haplotype_IDs):
    """
    Return the relative abundance of each haplotype in `haplotype_IDs`,
    parsed from the record names of `fastafile`.

    Record names are split on '|': the first field is the haplotype ID and
    the abundance is the value following the last ':' of the second field,
    or of the third field when the second one is 'paired'.
    """
    # Only the record names are needed, so the file is closed right away
    with pysam.FastaFile(fastafile) as haplotypes_recons:
        headers = list(haplotypes_recons.references)

    # Entries are addressed by position in haplotype_IDs, hence the array
    # has one slot per parsed haplotype
    haplotype_freqs = np.zeros(len(haplotype_IDs))
    for header in headers:
        fields = header.split('|')
        idx = [i for i, ID in enumerate(haplotype_IDs) if ID == fields[0]]
        freq_field = fields[2] if fields[1] == 'paired' else fields[1]
        haplotype_freqs[idx] = float(freq_field.split(':')[-1])
    return haplotype_freqs


def _plot_mds(dist, haplotype_freqs, haplotype_IDs, plot_outname):
    """
    Embed the haplotypes in two dimensions with MDS on the precomputed
    distance matrix and save a scatter plot as PDF. Marker sizes are
    proportional to the haplotype frequencies.
    """
    # Fixed seed, so that the embedding is reproducible
    seed = np.random.RandomState(seed=3)
    mds = manifold.MDS(n_components=2, max_iter=3000, eps=1e-9, random_state=seed,
                       dissimilarity="precomputed", n_jobs=1)
    pos = mds.fit(dist).embedding_

    fig = plt.figure(figsize=(5, 4.5))
    colors = plt.cm.jet(np.linspace(0, 1, len(haplotype_IDs)))
    ax = fig.add_subplot(111)
    ax.set_autoscaley_on(False)
    ax.scatter(pos[:, 0], pos[:, 1], s=haplotype_freqs *
               1000, c=colors, alpha=0.7, lw=0)
    for idx, label in enumerate(haplotype_IDs):
        ax.annotate(label, (pos[idx, 0], pos[idx, 1]), xytext=(-2, 2),
                    textcoords='offset points', ha='center', va='bottom')
    plt.axis('tight')

    plt.savefig(plot_outname, format='pdf')
    plt.close()


def _distances_to_true_haplotypes(fastafile, start, end, sequences, insertion_len):
    """
    Compare every reconstructed haplotype with every true haplotype over
    the region [start:end).

    Returns a tuple (true_IDs, dist, alignment_len). Both matrices have
    one row per reconstructed haplotype and one column per true haplotype;
    `dist` is the Hamming distance plus the insertion length of the
    reconstructed haplotype and `alignment_len` the number of compared
    positions.
    """
    with pysam.FastaFile(fastafile) as haplotypes_true:
        true_IDs = list(haplotypes_true.references)
        true_sequences = [
            haplotypes_true.fetch(reference=ID, start=start, end=end).upper()
            for ID in true_IDs]

    shape = (len(sequences), len(true_IDs))
    dist = np.zeros(shape=shape, dtype=int)
    alignment_len = np.zeros(shape=shape, dtype=int)
    for j, hap in enumerate(true_sequences):
        for i, sequence in enumerate(sequences):
            mismatches, compared = hamming_dist(hap, sequence)
            dist[i, j] = mismatches + insertion_len[i]
            alignment_len[i, j] = compared
    return (true_IDs, dist, alignment_len)


def main():
    """
    Assess the reconstructed haplotypes given on the command line.

    Reads <prefix>.bam and <prefix>.fasta, computes the pairwise distances
    between reconstructed haplotypes and saves an MDS plot. When a FASTA
    file with true haplotypes is given, each true haplotype is also mapped
    to its closest reconstructed haplotype and the mapping is written to a
    tab-separated file.

    Returns the tuple (recall, precision, weighted precision) when true
    haplotypes are given, None otherwise.
    """
    args = parse_args()
    region_length = args.end - args.start

    # Load sequences with pysam
    haplotype_IDs, hap_alignment_sequences, insertions, insertion_len = \
        _load_reconstructed_haplotypes(
            args.haplotypes + ".bam", args.start, args.end)

    # Compute pairwise distances
    dist = _pairwise_distances(
        hap_alignment_sequences, insertions, insertion_len)

    # Parse relative abundances
    haplotype_freqs = _parse_frequencies(
        args.haplotypes + ".fasta", haplotype_IDs)

    # Compute MDS and plot
    _plot_mds(dist, haplotype_freqs, haplotype_IDs, args.plot_outname)

    # Compute mapping to true haplotypes (if available)
    if args.haplotypes_true is None:
        return None

    true_IDs, dist, alignment_len = _distances_to_true_haplotypes(
        args.haplotypes_true, args.start, args.end,
        hap_alignment_sequences, insertion_len)
    num_haplotypes = len(true_IDs)

    # Closest true haplotype of every reconstructed haplotype
    dist_min = np.min(dist, axis=1)
    dist_min_idxs = np.argmin(dist, axis=1)

    mapping = {}
    TP = 0
    TP_w = 0
    for i, true_ID in enumerate(true_IDs):
        # Reconstructed haplotypes whose closest true haplotype is i
        haplotype_idxs = np.argwhere(dist_min_idxs == i)
        if haplotype_idxs.size == 0:
            continue
        # Among them, keep the ones at minimum distance
        haplotype_dist = dist_min[haplotype_idxs]
        idxs = haplotype_idxs[haplotype_dist == np.min(haplotype_dist)]
        # Compute sequence similarity between reconstructed haplotype and
        # true haplotype
        sim = (alignment_len[idxs, i] - dist[idxs, i]) / region_length
        best_sim = np.max(sim)
        idx = idxs[sim == best_sim]
        assert idx.size == 1, "More than one reconstructed haplotype mapped to the same true haplotype"
        idx = idx[0]
        if best_sim >= args.sim_thrd and haplotype_freqs[idx] > args.freq_thrd:
            TP += 1
            TP_w += haplotype_freqs[idx]
            mapping[true_ID] = (
                haplotype_IDs[idx], dist[idx, i], alignment_len[idx, i], haplotype_freqs[idx])
        else:
            print("Reconstructed haplotypes mapped to {} - Relative abundance(s):{} and corresponding similarity: {}".format(
                true_ID, haplotype_freqs[idxs], sim))

    # Write mapping to output
    with open(args.outname, 'w') as outfile:
        outfile.write(
            "# True haplotype\tHaplotype ID\tHamming distance\tReconstructed length\tRelative abundance\n")
        for ID, values in mapping.items():
            outfile.write('{}\t{}\t{}\t{}\t{:.6e}\n'.format(
                ID, values[0], values[1], values[2], values[3]))

    # Weighted Precision: sum of relative abundancies of reconstructed
    # haplotypes that map to a true haplotype up to certain similarity
    # threshold divided by the sum of relative frequencies discarding low
    # frequency haplotypes.
    # Lander-Waterman model: Freq_min = -G ln (1 - p^(1/G)) / NL
    above_thrd = haplotype_freqs > args.freq_thrd
    recall = TP / num_haplotypes
    precision = TP / np.sum(above_thrd)
    precision_w = TP_w / np.sum(haplotype_freqs[above_thrd])
    return (recall, precision, precision_w)


if __name__ == '__main__':
    # main() only returns the accuracy measures when true haplotypes were
    # provided; otherwise there is nothing to report
    results = main()
    if results is not None:
        recall, precision, precision_w = results
        print("Recall: {}".format(recall))
        print("Precision: {}".format(precision))
        print("Weighted precision: {}".format(precision_w))
