#!/usr/bin/env python3

import argparse
import collections
import csv
import sys

import Bio.SeqIO
import matplotlib.pyplot as plt
from Bio import Phylo
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor
from Bio.Phylo.TreeConstruction import _DistanceMatrix
from matplotlib import figure

__author__ = "David Seifert"
__copyright__ = "Copyright 2017"
__credits__ = "David Seifert"
__license__ = "GPL2+"
__maintainer__ = "Susana Posada Cespedes"
__email__ = "hiv-cbg@bsse.ethz.ch"
__status__ = "Development"

# Command-line interface: all three options are required.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    dest="TSV",
    metavar="tsv",
    required=True,
    help="tsv file of patient/sample/cluster/type mapping",
)
parser.add_argument(
    "-s",
    dest="MSA",
    metavar="MSA_file",
    required=True,
    help="MSA file containing the alignments between patient-sample sequences",
)
parser.add_argument(
    "-o",
    dest="OUTPUT_FILE",
    metavar="pairs",
    required=True,
    help="Output file for final pairs",
)
args = parser.parse_args()

# Module-level aliases used by the rest of the script.
TSV_FILE, MSA_FILE, OUTPUT_FILE = args.TSV, args.MSA, args.OUTPUT_FILE

# One row of the sample sheet. `seq` is created empty here; the script fills
# it in once the alignment file has been read.
patient_full_record = collections.namedtuple(
    "patient_full_record", "patient sample cluster type seq")
data = []

# parse samples
with open(TSV_FILE, newline='') as csvfile:
    tsv_file = csv.reader(csvfile, delimiter='\t')
    for row in tsv_file:
        # Validate explicitly: an `assert` would be stripped under `python -O`
        # and malformed rows would then be mis-parsed or fail obscurely.
        if len(row) != 4:
            sys.exit("{}: line {}: expected 4 tab-separated columns, found {}".format(
                TSV_FILE, tsv_file.line_num, len(row)))

        patient, sample, cluster, sample_type = row
        data.append(patient_full_record(patient=patient, sample=sample,
                                        cluster=int(cluster), type=sample_type, seq=''))

# parse sequences
ids = []
sequences = []
for record in Bio.SeqIO.parse(MSA_FILE, "fasta"):
    ids.append(str(record.id))
    sequences.append(str(record.seq))

# Index sequences by record id for constant-time lookup. `setdefault` keeps the
# first occurrence of a repeated id, which is what `list.index` would return.
sequence_by_id = {}
for seq_id, sequence in zip(ids, sequences):
    sequence_by_id.setdefault(seq_id, sequence)

# assign sequence to data (records are immutable, so each one is replaced)
for position, entry in enumerate(data):
    record_id = entry.patient + '-' + entry.sample
    if record_id not in sequence_by_id:
        sys.exit("no sequence named {} found in {}".format(record_id, MSA_FILE))
    data[position] = entry._replace(seq=sequence_by_id[record_id])


# distance helpers used to build the pairwise distance matrix
# IUPAC nucleotide codes mapped to the concrete bases each code may stand for.
_IUPAC_BASES = {
    'A': 'A',
    'C': 'C',
    'G': 'G',
    'T': 'T',
    'R': 'AG',
    'Y': 'CT',
    'S': 'GC',
    'W': 'AT',
    'K': 'GT',
    'M': 'AC',
    'B': 'CGT',
    'D': 'AGT',
    'H': 'ACT',
    'V': 'ACG',
    'N': 'ACGT',
}


def calculate_distance_between_bases(baseA, baseB):
    """Return the mismatch fraction between two IUPAC nucleotide codes.

    Both codes are upper-cased and expanded to the concrete bases they may
    represent. The result is the number of differing (baseA, baseB) base
    pairs divided by the total number of such pairs, so it lies in [0, 1]
    and is symmetric in its arguments.

    Raises:
        AssertionError: if either argument is the gap character '-'.
        KeyError: if either argument is not a known IUPAC code.
    """
    baseA = baseA.upper()
    baseB = baseB.upper()

    assert baseA != '-'
    assert baseB != '-'

    candidates_a = _IUPAC_BASES[baseA]
    candidates_b = _IUPAC_BASES[baseB]

    mismatches = sum(1 for a in candidates_a for b in candidates_b if a != b)

    # Normalise by the number of pairs actually compared: |A| * |B|.
    return mismatches / (len(candidates_a) * len(candidates_b))


def calculate_distance_between_sequences(seqA, seqB):
    """Compare two aligned sequences position by position.

    Only positions where both characters are upper-case are compared (gaps
    and lower-case characters are skipped). Returns a tuple
    ``(distance, mismatches, valid_loci)`` where ``distance`` is
    ``mismatches / valid_loci``. If no position can be compared the result
    is ``(1.0, 1, 1)``.
    """
    assert len(seqA) == len(seqB)

    compared = 0
    total = 0

    for position, base_a in enumerate(seqA):
        base_b = seqB[position]
        if not (base_a.isupper() and base_b.isupper()):
            continue
        compared += 1
        total += calculate_distance_between_bases(base_a, base_b)

    # Nothing comparable: report the maximal distance over one pseudo-locus.
    if compared == 0:
        compared = 1
        total = total + 1

    return (total / compared, total, compared)


# Row labels and lower-triangular rows (diagonal included) of the distance matrix.
names = []
matrix = []

transmit_pair = collections.namedtuple(
    "transmit_pair", "cluster T Tsample R Rsample")
TFs = []

# Distance triple used while no neighbour has been found yet.
NO_DISTANCE = (float("inf"), float("inf"), float("inf"))


def _neighbour_fields(record):
    """Return (patient, sample, cluster) of *record*, or "NA" placeholders if it is None."""
    if record is None:
        return ("NA", "NA", "NA")
    return (record.patient, record.sample, record.cluster)


for I, i in enumerate(data):
    # Reset every per-sample result so a sample without a suitable neighbour
    # never inherits the neighbour found for the previous sample.
    best = None
    best_cluster = None
    best_cluster_transmitter = None
    best_distance = NO_DISTANCE
    best_distance_cluster = NO_DISTANCE
    best_distance_cluster_transmitter = NO_DISTANCE

    dists = []
    for J, j in enumerate(data):
        # Compare by position, not by value, so the diagonal is exactly the
        # sample itself.
        if J == I:
            dists.append(0)
            continue

        ij_distance = calculate_distance_between_sequences(i.seq, j.seq)

        # Only the part of the row left of the diagonal goes into the matrix.
        if J < I:
            dists.append(ij_distance[0])

        if ij_distance[0] < best_distance[0]:
            best_distance = ij_distance
            best = j

        if i.cluster != j.cluster:
            continue

        if ij_distance[0] < best_distance_cluster[0]:
            best_distance_cluster = ij_distance
            best_cluster = j

        if 'R' in i.type and 'T' in j.type:
            if ij_distance[0] < best_distance_cluster_transmitter[0]:
                best_distance_cluster_transmitter = ij_distance
                best_cluster_transmitter = j

    new_name = i.patient + '-' + i.sample
    if new_name in names:
        sys.exit("{} already exists".format(new_name))
    names.append(new_name)
    matrix.append(dists)

    # Only samples whose type contains 'R' are reported.
    if 'R' not in i.type:
        continue

    best_fields = _neighbour_fields(best)
    cluster_fields = _neighbour_fields(best_cluster)
    transmitter_fields = _neighbour_fields(best_cluster_transmitter)

    # diagnostic info
    print("{}-{} ({}):".format(i.patient, i.sample, i.type), file=sys.stderr)
    print("\tBest neighbour:             {}-{} (dist={})".format(
        best_fields[0], best_fields[1], best_distance), file=sys.stderr)
    print("\tBest cluster neighbour:     {}-{} (dist={})".format(
        cluster_fields[0], cluster_fields[1], best_distance_cluster), file=sys.stderr)
    print("\tBest transmitter neighbour: {}-{} (dist={})\n".format(
        transmitter_fields[0], transmitter_fields[1], best_distance_cluster_transmitter),
        file=sys.stderr)

    # write new tsv: sample, then (neighbour, distance triple) for each of the
    # three neighbour categories, tab-separated on stdout
    print(
        i.patient, i.sample, i.cluster,
        *best_fields, *best_distance,
        *cluster_fields, *best_distance_cluster,
        *transmitter_fields, *best_distance_cluster_transmitter,
        sep="\t"
    )

    # A pair is recorded only when the overall nearest neighbour lies in the
    # sample's own cluster and that cluster actually contains a transmitter.
    if best is not None and i.cluster == best.cluster and best_cluster_transmitter is not None:
        TFs.append(transmit_pair(cluster=i.cluster, T=best_cluster_transmitter.patient,
                                 Tsample=best_cluster_transmitter.sample, R=i.patient, Rsample=i.sample))

# Sort pairs by cluster so that pairs of the same cluster are adjacent.
TFs.sort(key=lambda x: x.cluster)
# One suffix per pair within a cluster; a-z instead of a-d so that clusters
# with more than four pairs no longer fail with an IndexError.
SUFFIXES = tuple("abcdefghijklmnopqrstuvwxyz")

# Walk over runs of equal cluster ids. A cluster with a single pair keeps its
# id unchanged; a cluster with several pairs gets "<id>a", "<id>b", ...
run_start = 0
while run_start < len(TFs):
    run_end = run_start + 1
    while run_end < len(TFs) and TFs[run_end].cluster == TFs[run_start].cluster:
        run_end += 1

    run_length = run_end - run_start
    if run_length > 1:
        if run_length > len(SUFFIXES):
            sys.exit("cluster {} has {} pairs, more than the {} supported".format(
                TFs[run_start].cluster, run_length, len(SUFFIXES)))
        for offset in range(run_length):
            pair = TFs[run_start + offset]
            TFs[run_start + offset] = pair._replace(
                cluster=str(pair.cluster) + SUFFIXES[offset])

    run_start = run_end

# Write the transmitter/recipient pairs as a tab-separated table with a header.
with open(OUTPUT_FILE, "wt") as out_file:
    out_file.write("Cluster\tTransmitter\tT_sample\tRecipient\tR_sample\n")
    row_template = "{}\t{}\t{}\t{}\t{}\n"
    out_file.writelines(row_template.format(*pair) for pair in TFs)

# Build a neighbour-joining tree from the pairwise distance matrix.
DistMatrix = _DistanceMatrix(names, matrix)
constructor = DistanceTreeConstructor()
tree = constructor.nj(DistMatrix)

# Draw the tree onto an explicit 8x6 figure and save it to disk.
# `do_show=False` stops Phylo.draw from opening an interactive window, and
# passing `axes` makes it draw into our figure rather than a new one.
PhyloPlot = plt.figure(figsize=(8, 6))
Phylo.draw(tree, axes=PhyloPlot.add_subplot(1, 1, 1), do_show=False)
PhyloPlot.savefig('foo.pdf')
