#! /usr/bin/env python
from __future__ import print_function, division
import collections
import os
import sys
import time
import argparse
import pandas
import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq
from sklearn.cluster import AffinityPropagation
from Bio.SeqUtils import GC
from Bio.Alphabet import IUPAC
from argparse import RawTextHelpFormatter
import csv
__author__ = "Elaina Graham"
__license__ = "GPL 3.0"
__maintainer__ = "Elaina Graham"
__email__ = "egraham147@gmail.com"
##########################Get GC Count#######################################
def GC_Fasta_file(b, name, outdir):
    """Write a tab-delimited table of the GC value of every contig.

    Reads the FASTA file ``name`` located in directory ``b`` and writes
    ``GC_count.txt`` into ``outdir``. The file starts with a
    ``contig<TAB>GC`` header line followed by one line per FASTA record.

    Args:
        b: Directory that contains the FASTA file.
        name: File name of the FASTA file inside ``b``.
        outdir: Directory in which ``GC_count.txt`` is (over)written.

    Returns:
        None. The result is the file written to ``outdir``.
    """
    fasta_path = os.path.join(b, name)
    table_path = os.path.join(outdir, 'GC_count.txt')
    # Context managers close both handles even if parsing fails part-way
    # through the FASTA file.
    with open(fasta_path, 'r') as input_file, \
            open(table_path, 'w') as output_file:
        output_file.write("%s\t%s\n" % ('contig', 'GC'))
        for cur_record in SeqIO.parse(input_file, "fasta"):
            GC_percent = float(GC(cur_record.seq))
            # '%i' truncates the value to an integer; this is kept unchanged
            # so the table consumed by the later steps stays the same.
            output_file.write('%s\t%i\n' % (cur_record.id, GC_percent))
###########################Get tetramer frequencies#####################################
def kmer_list(dna, k):
    """Return the k-mers of ``dna`` and of its reverse complement.

    The sequence is upper-cased first. Windows of length ``k`` are taken
    from the forward strand, then from the reverse complement, and every
    window containing a character other than A, C, G or T is discarded.

    Args:
        dna: Nucleotide sequence as a string.
        k: Window (k-mer) length.

    Returns:
        List of k-mer strings, forward-strand windows first.
    """
    strand = Seq(dna.upper(), IUPAC.unambiguous_dna)
    both_strands = (str(strand), str(strand.reverse_complement()))
    return [
        text[start:start + k]
        for text in both_strands
        for start in range(len(text) + 1 - k)
        # strip() leaves an empty string only when every character is one
        # of A/G/T/C, i.e. the window is unambiguous.
        if not text[start:start + k].strip('AGTC')
    ]
#####################################
def kmer_counts(kmer, b, name):
    """Compute length-normalised k-mer frequencies for each FASTA record.

    Args:
        kmer: The k-mer length passed on to ``kmer_list``.
        b: Directory that contains the FASTA file.
        name: File name of the FASTA file inside ``b``.

    Returns:
        Dict mapping each record id to a list of frequency profiles (one
        profile per record carrying that id). A profile maps a k-mer to its
        count divided by the length of the record's sequence.
    """
    frequencies = {}
    fasta_path = os.path.join(b, name)
    for entry in SeqIO.parse(fasta_path, "fasta"):
        sequence = str(entry.seq)
        seq_length = float(len(sequence))
        tallies = collections.Counter(kmer_list(sequence, kmer))
        # Counts are divided by the sequence length, not by the number of
        # k-mers that were counted.
        profile = {
            str(word): float(hits) / seq_length
            for word, hits in tallies.items()
        }
        frequencies.setdefault(str(entry.id), []).append(profile)
    return frequencies
##################################
def output(a, outdir, kmer):
    """Write the k-mer frequency dictionary as a tab-delimited table.

    The file ``<kmer>mer-frequencies.txt`` is created in ``outdir``. Its
    first row is ``contig`` followed by the sorted set of all k-mers seen in
    any contig; each following row holds one contig and its frequency for
    every k-mer (0 when the contig lacks that k-mer).

    Args:
        a: Dict mapping a contig name to a list whose first element is a
            dict of ``k-mer -> frequency`` (as built by ``kmer_counts``).
        outdir: Directory in which the table is (over)written.
        kmer: The k-mer length, used only to build the file name.

    Returns:
        None. The result is the file written to ``outdir``.
    """
    headers = sorted(set(key for row in a.values() for key in row[0]))
    out_path = os.path.join(outdir, "%smer-frequencies.txt" % str(kmer))
    # csv.writer needs a binary file on Python 2 but a text file opened with
    # newline='' on Python 3; the old unconditional 'wb' raised TypeError on
    # Python 3 as soon as the first row was written.
    if sys.version_info[0] < 3:
        handle = open(out_path, 'wb')
    else:
        handle = open(out_path, 'w', newline='')
    # Closing explicitly guarantees the table is flushed to disk before the
    # next pipeline step reads it back.
    with handle:
        writer = csv.writer(handle, delimiter='\t')
        writer.writerow(['contig'] + headers)
        for key in a:
            profile = a[key][0]
            writer.writerow(
                [key] + [profile.get(header, 0) for header in headers])
#########################Combine tetramer frequencies and GC counts in tab-delimited format and normalize################################
def _read_numeric_table(path):
    """Load a tab-delimited table, drop its 'contig' header row, return numeric values indexed by contig."""
    with open(path, "r") as handle:
        table = pandas.read_csv(handle, sep="\t", index_col=0, header=None)
    # The file is read without a header, so the header line arrives as an
    # ordinary row labelled 'contig' and has to be removed by label.
    table = table.drop('contig')
    return table.apply(pandas.to_numeric)


def _log_scale(table):
    """Return log10(100 * value + 1) for every cell of ``table``."""
    return np.log10(table * 100 + 1)


def get_contigs(c, outdir, kmer):
    """Merge GC, k-mer and coverage data into ``kmer-GC-out.txt``.

    Reads ``GC_count.txt`` and ``<kmer>mer-frequencies.txt`` from ``outdir``,
    transforms both with log10(100 * x + 1), and joins them column-wise with
    the coverage table ``c`` on the contig name (inner join). Coverage values
    are passed through exactly as they appear in the coverage file. The
    merged table is written without a header line, tab-delimited, to
    ``kmer-GC-out.txt`` in ``outdir``.

    Args:
        c: Path of the tab-delimited coverage file (first column is the
            contig name, no header line).
        outdir: Directory holding the intermediate tables; also receives the
            merged output.
        kmer: The k-mer length, used to locate the frequency table.

    Returns:
        None. The result is the file written to ``outdir``.
    """
    dataframe = _read_numeric_table(os.path.join(outdir, "GC_count.txt"))
    # Preview of the GC table, emitted to stdout exactly as before.
    print(dataframe.head())
    dataframe = _log_scale(dataframe)
    dataframe2 = _log_scale(_read_numeric_table(
        os.path.join(outdir, '%smer-frequencies.txt' % (str(kmer)))))
    with open(str(c), "r") as cov:
        reader = list(csv.reader(cov, delimiter='\t'))
    # Prepend a one-cell header row so the first column can be addressed as
    # "contig"; the remaining columns end up unnamed.
    reader.insert(0, ["contig"])
    dataframe3 = pandas.DataFrame(reader)
    dataframe3.columns = dataframe3.iloc[0]
    dataframe3 = dataframe3[1:]
    dataframe3 = dataframe3.set_index("contig")
    result = pandas.concat(
        [dataframe, dataframe2, dataframe3], axis=1, join='inner')
    result.to_csv(path_or_buf=str(os.path.join(
        outdir, "kmer-GC-out.txt")), header=None, sep="\t")
##########################Run AP for refinement##############################
def create_fasta_dict(fastafile):
    """Map every record id in a FASTA source to its sequence.

    Args:
        fastafile: FASTA file name or open handle, as accepted by
            ``SeqIO.parse``.

    Returns:
        Dict of ``record id -> sequence``. When an id occurs more than once
        the last record wins.
    """
    return {entry.id: entry.seq for entry in SeqIO.parse(fastafile, "fasta")}
def get_cov_data(h):
    """Read a whitespace-delimited profile file into a dictionary.

    Every non-blank line is split on whitespace; the first field is used as
    the key and the complete list of fields (key included) as the value.
    Blank lines are skipped instead of raising ``IndexError``.

    Args:
        h: Path of the file to read.

    Returns:
        Dict mapping the first field of each line to the list of all fields
        on that line (all values are strings).
    """
    all_cov_data = {}
    with open(str(h), "r") as inf:
        for line in inf:
            cov_data = line.split()
            if not cov_data:
                # Blank or whitespace-only line: nothing to index.
                continue
            all_cov_data[cov_data[0]] = cov_data
    return all_cov_data
#####################
def cov_array(a, b, name, size):
    """Build the feature matrix for the contigs that pass the size cut-off.

    Iterates over the FASTA file ``name`` in directory ``b`` and, for every
    record at least ``size`` bases long whose id is a key of ``a``, converts
    the profile values of that contig to a row of floats.

    Args:
        a: Dict mapping a contig id to a list of strings whose first element
            is the contig name and whose remaining elements are the numeric
            profile values (as returned by ``get_cov_data``). The lists are
            not modified.
        b: Directory that contains the FASTA file.
        name: File name of the FASTA file inside ``b``.
        size: Minimum sequence length for a contig to be included.

    Returns:
        Tuple ``(matrix, names)``. ``matrix`` is a 2-D float array with one
        row per selected contig (shape ``(0, 0)`` when nothing was selected)
        and ``names`` lists the contig names in matching row order.
    """
    names = []
    rows = []
    min_length = int(size)
    for record in SeqIO.parse(os.path.join(b, name), "fasta"):
        if len(record.seq) < min_length or record.id not in a:
            continue
        data = a[record.id]
        names.append(data[0])
        # Slice rather than remove in place so the caller's lists stay
        # intact (a repeated record id would otherwise lose a value).
        rows.append(np.fromstring(" ".join(data[1:]), dtype=float, sep=' '))
    # One vstack at the end avoids re-copying the matrix for every contig and
    # yields a 2-D result even when only a single contig was selected.
    matrix = np.vstack(rows) if rows else np.empty((0, 0))
    print("          %s" % (matrix.shape,))
    return matrix, names
#################################
def _write_fasta_records(path, contig_names, fasta_dict, mode="w"):
    """Write the named contigs from ``fasta_dict`` to ``path`` in FASTA format."""
    with open(path, mode) as handle:
        for contig in contig_names:
            handle.write(
                ">" + str(contig) + "\n" + str(fasta_dict[contig]) + "\n")


def affinity_propagation(array, names, file_name, damping, iterations, convergence, preference, path, output_directory, outname):
    """Cluster contigs with affinity propagation and write the bins.

    Clusters are written to ``output_directory`` as follows:

    * 5 or more contigs: ``<prefix>_Bin-refined_<label>.fna``
    * 2-4 contigs, at least one longer than 50000 bases:
      ``<prefix>_Bin-bin_<label>.fna``
    * a single contig: appended to ``unclustered.fna``

    After all clusters are processed, ``unclustered.fna`` is removed if it
    holds fewer than two records. The size of every cluster and the number
    of bins written are printed.

    Args:
        array: 2-D feature matrix, one row per contig.
        names: Contig names in the same order as the rows of ``array``.
        file_name: FASTA file name inside ``path`` holding the sequences.
        damping: Damping factor passed to ``AffinityPropagation``.
        iterations: Maximum number of iterations.
        convergence: Number of convergence iterations.
        preference: Preference value (converted with ``int()``).
        path: Directory that contains ``file_name``.
        output_directory: Directory that receives the bin files.
        outname: Prefix for bin files; when ``None`` the FASTA file name
            without its last extension is used.

    Returns:
        None. Results are the files written to ``output_directory``.
    """
    # NOTE(review): int(preference) truncates fractional preferences even
    # though the command line parses -p as a float; left unchanged here.
    apclust = AffinityPropagation(damping=float(damping), max_iter=int(iterations), convergence_iter=int(
        convergence), copy=True, preference=int(preference), affinity='euclidean', verbose=False).fit_predict(array)
    # Group contig names by the cluster label assigned to their row.
    outfile_data = {}
    for label, contig in zip(apclust, names):
        outfile_data.setdefault(label, []).append(contig)
    if outname is None:
        out_name = file_name.rsplit(".", 1)[0] + "_Bin"
    else:
        out_name = outname + "_Bin"
    with open(os.path.join(path, file_name), "r") as input2_file:
        fasta_dict = create_fasta_dict(input2_file)
    unclustered_path = os.path.join(output_directory, 'unclustered.fna')
    count = 0
    for k, members in outfile_data.items():
        if len(members) >= 5:
            _write_fasta_records(
                os.path.join(output_directory,
                             str(out_name) + "-refined_%s.fna" % (k)),
                members, fasta_dict)
            count = count + 1
        elif len(members) > 1:
            # NOTE(review): clusters of 2-4 contigs without a member longer
            # than 50000 bases are not written to any file.
            if any((len(fasta_dict[x]) > 50000) for x in members):
                _write_fasta_records(
                    os.path.join(output_directory,
                                 str(out_name) + "-bin_%s.fna" % (k)),
                    members, fasta_dict)
                count = count + 1
        else:
            # Append mode: singletons from all clusters accumulate here.
            _write_fasta_records(unclustered_path, members, fasta_dict,
                                 mode='a')
        print("          Cluster " + str(k) + ": " + str(len(members)))
    # This check used to run inside the loop, which deleted the file right
    # after each singleton was appended. It must run once all clusters have
    # been written.
    if os.path.isfile(unclustered_path):
        contig_number = sum(
            1 for _ in SeqIO.parse(unclustered_path, "fasta"))
        if contig_number < 2:
            os.remove(unclustered_path)
    print(("          Total Number of Bins: %i" % int(count)))
class Logger(object):
    """File-like object that duplicates everything written to stdout into a log file.

    Intended to be assigned to ``sys.stdout`` so that ``print`` output
    appears on the terminal and is appended to the log file.
    """

    def __init__(self, logfile, location):
        """Capture the current ``sys.stdout`` and open the log for appending.

        Args:
            logfile: File name of the log file.
            location: Directory in which the log file lives.
        """
        self.terminal = sys.stdout
        self.log = open(os.path.join(location, logfile), "a")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both streams so buffered output reaches the terminal and disk."""
        self.terminal.flush()
        self.log.flush()
if __name__ == '__main__':
    # Command-line entry point: compute GC and k-mer profiles, merge them
    # with coverage, cluster with affinity propagation and write the bins.
    parser = argparse.ArgumentParser(prog='Binsanity-refine', description="""
    ***********************************************************************
    *****************************BinSanity*********************************
    ** Binsanity-refine uses combined coverage and composition           **
    ** (in the form of tetramer frequencies and GC%) to cluster          **
    ** contigs. The published workflow uses this to refine bins          **
    ** initially clustered soley on coverage.                            **
    **                                                                   **
    ** Binsanity-refine can be used as a stand alone script if you       **
    ** don't have more than 2 sample replicates                          **
    **                                                                   **
    ***********************************************************************""", usage='%(prog)s -f [/path/to/fasta] -l [fastafile] -kmer [kmer type] -c [coverage file] -p [preference] -m [maxiter] -v [convergence iterations] -d [damping factor] -x [contig size] -o [output directory]', formatter_class=RawTextHelpFormatter)
    parser.add_argument("-c", dest="inputCovFile", help="""
    Specify a Coverage File
    """)
    parser.add_argument("-f", dest="inputContigFiles", help="""
    Specify directory containing your contigs
    """)
    parser.add_argument("-p", type=float, dest="preference", default=-25, help="""
    Specify a preference (default is -25)
    Note: decreasing the preference leads to more lumping,
    increasing will lead to more splitting. If your range
    of coverages are low you will want to decrease the
    preference, if you have 10 or less replicates increasing
    the preference could benefit you. For complex datasets
    with low abundance organisms a preference
    of -25 was found to be optimal
    """)
    parser.add_argument("-m", type=int, dest="maxiter", default=4000, help="""
    Specify a max number of iterations (default is 4000)
    """)
    # Help text corrected to state the default that is actually used (400).
    parser.add_argument("-v", type=int, dest="conviter", default=400, help="""
    Specify the convergence iteration number (default is 400)
    e.g Number of iterations with no change in the number
    of estimated clusters that stops the convergence.
    """)
    parser.add_argument("-kmer", type=int, dest="inputKmer", default=4, help="""
    Specify a number for kmer calculation. Default is 4.
    Tetramer frequencies are recommended
    """)
    # Help text corrected to state the default that is actually used (0.95).
    parser.add_argument("-d", default=0.95, type=float, dest="damp", help="""
    Specify a damping factor between 0.5 and 1, default is 0.95
    """)
    parser.add_argument("-l", dest="fastafile", help="""
    Specify the fasta file containing contigs you want to cluster
    """)
    parser.add_argument("-x", dest="ContigSize", type=int, default=1000, help="""
    Specify the contig size cut-off (Default 1000 bp)
    """)
    parser.add_argument("-o", dest="outputdir", default="BINSANITY-REFINEMENT", help="""
    Give a name to the directory BinSanity results will be output in
    [Default is 'BINSANITY-REFINEMENT']
    """)
    parser.add_argument("--log", dest="LogFile", default="binsanity-refine.log", help="""
    Specify an output name for the log file. [Default: binsanity-refine.log]""")
    parser.add_argument('--version', action='version',
                        version='%(prog)s v0.4.4')
    parser.add_argument("--outPrefix", dest="outname", default=None, help="""
    Sepcify what prefix you want appended to final Bins {optional}""")

    args = parser.parse_args()
    if len(sys.argv) < 3:
        parser.print_help()
    if args.inputCovFile is None:
        if (args.inputContigFiles is None) and (args.fastafile is None):
            parser.print_help()
    if (args.inputCovFile is None):
        print("Please indicate -c coverage file")
    if args.inputContigFiles is None:
        print("Please indicate -f directory containing your contigs")
    elif args.inputContigFiles and not args.fastafile:
        parser.error('-l Need to identify file to be clustered')
    elif args.inputCovFile is None:
        # The missing-coverage message was printed above. Previously the run
        # continued and crashed while opening the coverage file, after the
        # GC and k-mer tables had already been computed; stop here instead,
        # still with a non-zero exit status.
        sys.exit(1)
    else:
        start_time = time.time()
        if not os.path.isdir(str(args.outputdir)):
            os.mkdir(args.outputdir)
        # From here on everything printed is also appended to the log file.
        sys.stdout = Logger(args.LogFile, args.outputdir)
        print("""
        ******************************************************
        **********************BinSanity***********************
        |____________________________________________________|
        |               Binsanity refinement                 |
        |                                                    |
        |              Calculating GC content                |
        |____________________________________________________|
        """)
        GC_time = time.time()
        GC_Fasta_file(args.inputContigFiles, args.fastafile, args.outputdir)
        print("          GC content calculated in %s seconds" %
              (time.time() - GC_time))
        print("""
         ______________________________________________________

                     Calculating %smer frequencies
         ______________________________________________________""" % str(args.inputKmer))
        kmer_time = time.time()
        output(kmer_counts(args.inputKmer, args.inputContigFiles,
                           args.fastafile), args.outputdir, args.inputKmer)
        print("           %smer frequency calculated in %s seconds" %
              (args.inputKmer, time.time() - kmer_time))
        print("""
        ______________________________________________________

                   Creating Combined Profile
        ______________________________________________________""")
        combine_time = time.time()
        # Build the profile before reporting the elapsed time; the report
        # used to be printed first and therefore always showed ~0 seconds.
        get_contigs(args.inputCovFile, args.outputdir, args.inputKmer)
        print("            Combined profile created in %s seconds" %
              (time.time() - combine_time))

        print("""
        ______________________________________________________

         Clustering Using %smers, GC percentage, and Coverage
        ______________________________________________________""" % (args.inputKmer))
        print("          Preference: " + str(args.preference))
        print("          Maximum Iterations: " + str(args.maxiter))
        print("          Convergence Iterations: " + str(args.conviter))
        print("          Contig Cut-Off: " + str(args.ContigSize))
        print("          Damping Factor: " + str(args.damp))
        print("          Coverage File: " + str(args.inputCovFile))
        print("          Fasta File: " + str(args.fastafile))

        # The combined profile file is re-read as the clustering input.
        val1, val2 = cov_array((get_cov_data(os.path.join(
            args.outputdir, 'kmer-GC-out.txt'))), args.inputContigFiles, args.fastafile, args.ContigSize)

        affinity_propagation(val1, val2, args.fastafile, args.damp, args.maxiter, args.conviter,
                             args.preference, args.inputContigFiles, args.outputdir, args.outname)

        print(("""
        ______________________________________________________
                   Bins Computed in
                   %s seconds
        ______________________________________________________""" % (time.time() - start_time)))
        # Remove the intermediate tables; note this deletes every file ending
        # in ".txt" in the output directory.
        for f in os.listdir(args.outputdir):
            if f.endswith(".txt"):
                os.remove(os.path.join(args.outputdir, f))
