#! /usr/bin/env python
from __future__ import print_function, division

import os
import sys
import time
import argparse
import shutil
import numpy as np
from Bio import SeqIO
from argparse import RawTextHelpFormatter
from sklearn.cluster import AffinityPropagation


__author__ = "Elaina Graham"
__license__ = "GPL 3.0"
__maintainer__ = "Elaina Graham"
__email__ = "egraham147@gmail.com"

###############################################################################################################
def create_fasta_dict(fastafile):
    """Build a lookup of the sequences in the FASTA file to be binned.

    `fastafile` is passed unchanged to ``SeqIO.parse`` with the "fasta"
    format. The returned dictionary maps each record's ``id`` to its ``seq``;
    when an id occurs more than once the last record read is kept.
    """
    return {
        entry.id: entry.seq
        for entry in SeqIO.parse(fastafile, "fasta")
    }
#################################################################################################################
def get_cov_data(h):
    """Makes a dictionary of coverage values for affinity propagation

    Every non-blank line of the coverage file `h` is split on whitespace.
    The first field is used as the dictionary key and the value is the full
    list of fields for that line (the first field included), all kept as
    strings. If the same first field appears on several lines, the last line
    wins. Blank or whitespace-only lines are skipped.
    """
    all_cov_data = {}
    # The context manager closes the coverage file once it has been read.
    with open(str(h), "r") as cov_file:
        for line in cov_file:
            cov_data = line.split()
            if not cov_data:
                # Nothing to index on an empty line (e.g. a trailing newline).
                continue
            all_cov_data[cov_data[0]] = cov_data
    return all_cov_data
###################################################################################################################
def cov_array(a, path, file_name, size):
    """Computes a coverage array based on the contigs in files and the contig names associated with the coverage file

    Parameters:
        a: dictionary keyed by contig id whose values are lists of strings,
            the first element being the contig name and the remaining
            elements the coverage values (as produced by get_cov_data).
        path: directory containing the FASTA file.
        file_name: name of the FASTA file inside `path`.
        size: minimum sequence length; shorter records are ignored.

    Returns:
        (coverage, names) where `coverage` is a 2-D float array with one row
        per retained contig (a single contig yields a single-row array) and
        `names` lists the contig names in the same row order.

    Raises:
        ValueError: if no record is both long enough and present in `a`, or
            if the retained rows do not all have the same number of values.
    """
    names = []
    rows = []
    min_length = int(size)
    for record in SeqIO.parse(os.path.join(path, file_name), "fasta"):
        if len(record.seq) < min_length or record.id not in a:
            continue
        data = a[record.id]
        names.append(data[0])
        # Slice instead of removing in place so the caller's dictionary is
        # left untouched.
        rows.append(np.array(data[1:], dtype=float))
    if not rows:
        raise ValueError(
            "No contigs of at least %s bp in %s were found in the coverage data"
            % (size, os.path.join(path, file_name)))
    # Stack once at the end rather than growing the array row by row.
    coverage = np.vstack(rows)
    print("          %s" % (coverage.shape,))
    return coverage, names
################################################################################################################
def _write_contigs(file_path, contig_names, fasta_dict, mode="w"):
    """Write the named contigs from `fasta_dict` to `file_path` as FASTA records."""
    with open(file_path, mode) as output_file:
        for x in contig_names:
            output_file.write(">"+str(x)+"\n"+str(fasta_dict[x])+"\n")


def affinity_propagation(array, names, file_name, damping, iterations, convergence, preference, path, output_directory, outname):
    """Uses affinity propagation to make putative bins

    Parameters:
        array: coverage array passed to ``AffinityPropagation.fit_predict``.
        names: contig names, in the same order as the rows of `array`.
        file_name: name of the FASTA file (inside `path`) holding the contigs.
        damping, iterations, convergence, preference: forwarded to
            ``AffinityPropagation`` as ``damping``, ``max_iter``,
            ``convergence_iter`` and ``preference``.
        path: directory containing `file_name`.
        output_directory: directory the bin files are written to.
        outname: prefix for the bin files; when None the FASTA file name
            without its last extension is used.

    Output files written to `output_directory`:
        * clusters with 5 or more contigs -> ``<prefix>_Bin-<label>.fna``
        * clusters with 2-4 contigs, at least one longer than 50000
          -> ``<prefix>_Bin-bin_<label>.fna``
        * clusters with a single contig are appended to ``unclustered.fna``;
          once all clusters are processed that file is removed if it holds
          fewer than 2 records.
    The size of every cluster and the total number of bins are printed.
    """
    # The preference is passed as a float so fractional values are honoured.
    apclust = AffinityPropagation(damping=float(damping), max_iter=int(iterations), convergence_iter=int(
        convergence), copy=True, preference=float(preference), affinity='euclidean', verbose=False).fit_predict(array)

    # Group contig names by the cluster label assigned to them.
    outfile_data = {}
    for label, name in zip(apclust, names):
        outfile_data.setdefault(label, []).append(name)

    if outname is None:
        out_name = file_name.rsplit(".", 1)[0]+"_Bin"
    else:
        out_name = outname+"_Bin"

    with open(os.path.join(path, file_name), "r") as input2_file:
        fasta_dict = create_fasta_dict(input2_file)

    unclustered_path = os.path.join(output_directory, 'unclustered.fna')
    count = 0
    for k in outfile_data:
        members = outfile_data[k]
        if len(members) >= 5:
            _write_contigs(os.path.join(
                output_directory, str(out_name)+"-%s.fna" % (k)), members, fasta_dict)
            count = count + 1
        elif len(members) > 1:
            # NOTE(review): clusters of 2-4 contigs with no contig longer
            # than 50000 are not written anywhere (neither as a bin nor to
            # unclustered.fna) - confirm this is intended.
            if any((len(fasta_dict[x]) > 50000) for x in members):
                _write_contigs(os.path.join(
                    output_directory, str(out_name)+"-bin_%s.fna" % (k)), members, fasta_dict)
                count = count + 1
        else:
            # Append so that all single-contig clusters end up in one file.
            _write_contigs(unclustered_path, members, fasta_dict, mode='a')
        print("          Cluster "+str(k)+": "+str(len(members)))

    # Only inspect unclustered.fna after every cluster has been handled;
    # checking inside the loop deleted the file right after each append.
    if os.path.isfile(unclustered_path):
        contig_number = sum(1 for _ in SeqIO.parse(unclustered_path, "fasta"))
        if contig_number < 2:
            os.remove(unclustered_path)
    print(("          Total Number of Bins: %i"%int(count)))
########################################################################################
class Logger(object):
    """Stream-like object that copies everything written to it to a log file.

    On construction it remembers the current ``sys.stdout`` as ``terminal``
    and opens ``logfile`` inside ``location`` in append mode as ``log``.
    It is meant to be assigned to ``sys.stdout`` so that ``print`` output
    goes both to the terminal and to the log file.
    """

    def __init__(self, logfile, location):
        self.terminal = sys.stdout
        self.log = open(os.path.join(location, logfile), "a")

    def write(self, message):
        """Write `message` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams so buffered output is not held back."""
        self.terminal.flush()
        self.log.flush()
###################################################################
if __name__ == '__main__':
    # Command-line entry point: parse the options, build the coverage array
    # from the coverage file and FASTA file, then cluster the contigs into bins.
    parser = argparse.ArgumentParser(prog='Binsanity', description="""
    *****************************************************************************
    *********************************BinSanity***********************************
    **  Binsanity uses Affinity Propagation to split metagenomic contigs into  **
    **  'bins' using contig coverage values. It takes as input a coverage file **
    **  and files containing the contigs to be binned, then outputs clusters   **
    **  of contigs in putative bins.                                           **
    **                                                                         **
    **  NOTE: BinSanity becomes highly memory intensive at 100,000 contigs or  **
    **  above. If you have greater than this number of contigs we recommend    **
    **  trying the beta-version for our script binsanity-lc.                   **
    *****************************************************************************
    """, usage='%(prog)s -f [/path/to/fasta] -l [fastafile] -c [coverage file]', formatter_class=RawTextHelpFormatter)
    parser.add_argument("-c", dest="inputCovFile", help="""Specify a Coverage File
    """)
    parser.add_argument("-f", dest="inputContigFiles", help="""Specify directory
    containing your contigs
    """)
    parser.add_argument("-p", type=float, dest="preference", default=-3, help="""
    Specify a preference (default is -3)

    Note: decreasing the preference leads to more lumping,
    increasing will lead to more splitting. If your range
    of coverages are low you will want to decrease the preference,
    if you have 10 or less replicates increasing the preference could
    benefit you.""")
    # The help texts below state the defaults actually configured here.
    parser.add_argument("-m", type=int, dest="maxiter", default=4000, help="""
    Specify a max number of iterations [default is 4000]""")
    parser.add_argument("-v", type=int, dest="conviter", default=400, help="""
    Specify the convergence iteration number (default is 400)
    e.g Number of iterations with no change in the number
    of estimated clusters that stops the convergence.""")
    parser.add_argument("-d", default=0.95, type=float, dest="damp", help="""
    Specify a damping factor between 0.5 and 1, default is 0.95""")
    parser.add_argument("-l", dest="fastafile", help="""
    Specify the fasta file containing contigs you want to cluster""")
    parser.add_argument("-x", dest="ContigSize", type=int, default=1000, help="""
    Specify the contig size cut-off [Default 1000 bp]""")
    parser.add_argument("-o", dest="outputdir", default="BINSANITY-RESULTS", help="""
    Give a name to the directory BinSanity results will be output in
    [Default is 'BINSANITY-RESULTS']""")
    parser.add_argument("--outPrefix", dest="outname", default=None, help="""
    Specify what prefix you want appended to final Bins {optional}""")
    parser.add_argument("--log", dest="logfile", default="binsanity-logfile.txt", help="""
    specify a name for the log file [Default is 'binsanity-logfile.txt']""")
    parser.add_argument('--version', action='version',
                        version='%(prog)s v0.4.4')

    args = parser.parse_args()
    if len(sys.argv) < 3:
        # print_help() writes the help itself and returns None, so it must
        # not be wrapped in print().
        parser.print_help()
    elif args.inputCovFile is None:
        if (args.inputContigFiles is None) and (args.fastafile is None):
            parser.print_help()
        else:
            print("Please indicate -c coverage file")
    elif args.inputContigFiles is None:
        print("Please indicate -f directory containing your contigs")
    elif not args.fastafile:
        parser.error('-l Need to identify file to be clustered')
    else:
        if not os.path.isdir(str(args.outputdir)):
            os.mkdir(args.outputdir)
        start_time = time.time()
        # From here on everything printed is also appended to the log file.
        sys.stdout = Logger(args.logfile, args.outputdir)
        print("""
        ******************************************************
        **********************BinSanity***********************
        |____________________________________________________|
        |                                                    |
        |             Computing Coverage Array               |
        |____________________________________________________|
        """)
        print("          Preference: " + str(args.preference))
        print("          Maximum Iterations: " + str(args.maxiter))
        print("          Convergence Iterations: " + str(args.conviter))
        print("          Contig Cut-Off: " + str(args.ContigSize))
        print("          Damping Factor: " + str(args.damp))
        print("          Coverage File: " + str(args.inputCovFile))
        print("          Fasta File: " + str(args.fastafile))
        print("          Output directory: " + str(args.outputdir))
        print("          logfile: " + str(args.logfile))

        val1, val2 = cov_array((get_cov_data(args.inputCovFile)),
                               args.inputContigFiles, args.fastafile, args.ContigSize)

        print("""
         ______________________________________________________
        |                                                      |
        |                 Clustering Contigs                   |
        |______________________________________________________|

        """)
        affinity_propagation(val1, val2, args.fastafile, args.damp, args.maxiter, args.conviter,
                             args.preference, args.inputContigFiles, args.outputdir, args.outname)

        # NOTE(review): a "Creating Bins" banner used to follow this print()
        # as a bare string expression, which never printed anything. Only the
        # blank line is emitted, so the output is unchanged.
        print()

        print(("""
         _____________________________________________________

                       Putative Bins Computed
                       in %s seconds
         _____________________________________________________""" % (time.time() - start_time)))
