#! /usr/bin/env python
from __future__ import print_function, division
from Bio import SeqIO
import sys, subprocess, argparse,glob, csv, time
import pandas, os
import numpy as np
from argparse import RawTextHelpFormatter
__author__ = "Elaina Graham"
__license__ = "GPL 3.0"
__maintainer__ = "Elaina Graham"
__email__ = "egraham147@gmail.com"
def _bam_sample_name(bam_filename):
    """Return *bam_filename* without its trailing ``.bam`` extension.

    ``str.lstrip('.bam')`` removes any leading run of the characters
    ``.``, ``b``, ``a`` and ``m`` instead of the suffix, so names such as
    ``ab.bam`` and ``ma.bam`` both collapsed to an empty string and
    ``map1.bam`` became ``p1.bam``.  Slicing the suffix off keeps one
    distinct, recognisable output name per BAM file.  Callers must only
    pass names that end with ``.bam``.
    """
    return bam_filename[:-len(".bam")]


def format_for_feature(inputfile, mapfile, outputpath):
    """Write one SAF annotation file per BAM file found in *mapfile*.

    Every record of the fasta file becomes a single SAF row that spans the
    whole sequence: start 1, end equal to the sequence length, strand '+',
    with the record id used for both the GeneID and Chr columns.

    Parameters:
        inputfile: path of the fasta file to describe.
        mapfile: directory that is scanned for files ending in ``.bam``.
        outputpath: directory receiving ``<bam name without .bam>.saf``.

    Returns:
        None.  Nothing is read or written when *mapfile* holds no BAM file.
    """
    bam_names = [name for name in os.listdir(mapfile)
                 if name.endswith(".bam")]
    if not bam_names:
        return
    # The SAF rows depend only on the fasta, so build them once instead of
    # re-parsing the fasta for every BAM file.
    saf_lines = ['GeneID\tChr\tStart\tEnd\tStrand\n']
    for record in SeqIO.parse(inputfile, 'fasta'):
        contig_id = str(record.id)
        fields = [contig_id, contig_id, '1', str(len(record.seq)), '+']
        saf_lines.append('\t'.join(fields) + "\n")
    for bam_name in bam_names:
        saf_path = os.path.join(outputpath, _bam_sample_name(bam_name) + ".saf")
        # 'with' guarantees the handle is closed even if a write fails.
        with open(saf_path, 'w') as handle:
            handle.writelines(saf_lines)


def feature_counts(mapfile, outputdir, threads):
    """Run featureCounts on every BAM file found in *mapfile*.

    For each ``<name>.bam`` the SAF file ``<outputdir>/<name>.saf`` (as
    written by ``format_for_feature``) is used as annotation and the result
    is written to ``<outputdir>/<name>.readcounts``.  featureCounts is
    invoked with ``-M -O -F SAF -T <threads>``.

    Parameters:
        mapfile: directory that is scanned for files ending in ``.bam``.
        outputdir: directory holding the SAF files and receiving the counts.
        threads: value passed to featureCounts' ``-T`` option.

    Returns:
        None.  A non-zero featureCounts exit status is reported on stderr
        and processing continues with the next BAM file.
    """
    # The child's stdout is discarded.  The original passed
    # stdout=subprocess.PIPE to subprocess.call, which never reads the pipe
    # and can therefore block the child once the pipe buffer fills up.
    with open(os.devnull, 'w') as devnull:
        for bam_name in os.listdir(mapfile):
            if not bam_name.endswith('.bam'):
                continue
            sample = _bam_sample_name(bam_name)
            bam_path = os.path.join(mapfile, bam_name)
            command = [
                "featureCounts", "-M", "-O", "-F", "SAF",
                "-T", str(threads),
                "-a", os.path.join(outputdir, sample + ".saf"),
                "-o", os.path.join(outputdir, sample + ".readcounts"),
                str(bam_path),
            ]
            status = subprocess.call(command, stdout=devnull)
            if status != 0:
                sys.stderr.write(
                    "Warning: featureCounts exited with status %d for %s\n"
                    % (status, bam_path))
def make_coverage(fasta, outfile, outputdir):
    """Combine all ``*.readcounts`` files into one coverage table.

    For every contig id in *fasta* the value written per count file is the
    read count divided by the contig length, formatted with ten decimals.
    Length and count are taken from the 6th and 7th whitespace-separated
    fields of each count line (the ``Length`` column and the first sample
    column of featureCounts output).

    Parameters:
        fasta: path of the fasta file whose record ids select the rows.
        outfile: output prefix; the table is written to ``<outfile>.cov``.
        outputdir: directory scanned for files ending in ``.readcounts``.

    Returns:
        None.  Output rows follow fasta order and columns follow the sorted
        count-file names.  Contigs for which no count line was found are
        left out of the table and reported on stderr (the original raised
        IndexError on them).
    """
    contig_order = []
    cov_data = {}  # contig id -> [length, count, count, ...]
    for record in SeqIO.parse(fasta, "fasta"):
        if record.id not in cov_data:
            contig_order.append(record.id)
            cov_data[record.id] = []
    # os.listdir order is arbitrary; sorting makes the sample column order
    # reproducible between runs and machines.
    count_filenames = sorted(
        f for f in os.listdir(outputdir) if f.endswith('.readcounts'))
    for filename in count_filenames:
        with open(os.path.join(outputdir, filename), "r") as handle:
            for line in handle:
                if line.startswith("#"):
                    continue
                data = line.split()
                # Skip blank/short lines and the column header row.  The
                # original test used 'or' and was therefore always true.
                if len(data) < 7 or data[:2] == ["Geneid", "Chr"]:
                    continue
                values = cov_data.get(data[0])
                if values is None:
                    # Feature that is not a contig of the fasta file.
                    continue
                if not values:
                    # First count file seen for this contig: keep its length.
                    values.append(data[5])
                values.append(data[6])
    missing = 0
    with open(str(outfile) + ".cov", "w") as out1:
        for contig in contig_order:
            values = cov_data[contig]
            if not values:
                missing += 1
                continue
            contig_length = float(values[0])
            coverage_list = ["%.10f" % (float(count) / contig_length)
                             for count in values[1:]]
            out1.write(contig + "\t" + "\t".join(coverage_list) + "\n")
    if missing:
        sys.stderr.write(
            "Warning: %d contig(s) had no read counts and were left out of "
            "%s.cov\n" % (missing, outfile))
def get_log(file_, outfile):
    """Write a log10(x + 1) transformed copy of a tab-separated table.

    The first field of every row is copied unchanged; every remaining field
    is converted to float and replaced by ``log10(value + 1)``.

    Parameters:
        file_: path of the tab-separated input table.
        outfile: path of the tab-separated output (no header, no index).

    Raises:
        ValueError: if a field after the first one is not a number.
    """
    # Read everything up front inside 'with' so the input handle is closed
    # (the original left it open).
    with open(str(file_), 'r') as handle:
        cov_file = list(csv.reader(handle, delimiter='\t'))
    for row in cov_file:
        row[1:] = [np.log10(float(value) + 1) for value in row[1:]]
    frame = pandas.DataFrame(cov_file)
    frame.to_csv(outfile, sep="\t", index=False, header=False)
def get_multiple(file_, outfile, num):
    """Write a copy of a tab-separated table with its values multiplied.

    The first field of every row is copied unchanged; every remaining field
    is converted to float and multiplied by ``int(num)``.

    Parameters:
        file_: path of the tab-separated input table.
        outfile: path of the tab-separated output (no header, no index).
        num: multiplication factor; it is passed through ``int()``.

    Raises:
        ValueError: if a field after the first one is not a number.
    """
    factor = int(num)  # convert once instead of once per field
    # Read everything up front inside 'with' so the input handle is closed
    # (the original left it open).
    with open(str(file_), 'r') as handle:
        cov_file = list(csv.reader(handle, delimiter='\t'))
    for row in cov_file:
        row[1:] = [float(value) * factor for value in row[1:]]
    frame = pandas.DataFrame(cov_file)
    frame.to_csv(outfile, sep="\t", index=False, header=False)
def get_squareroot(file_, outfile):
    """Write a sqrt(x + 1) transformed copy of a tab-separated table.

    The first field of every row is copied unchanged; every remaining field
    is converted to float and replaced by ``sqrt(value + 1)``.

    Parameters:
        file_: path of the tab-separated input table.
        outfile: path of the tab-separated output (no header, no index).

    Raises:
        ValueError: if a field after the first one is not a number.
    """
    # Read everything up front inside 'with' so the input handle is closed
    # (the original left it open).
    with open(str(file_), 'r') as handle:
        cov_file = list(csv.reader(handle, delimiter='\t'))
    for row in cov_file:
        row[1:] = [np.sqrt(float(value) + 1) for value in row[1:]]
    frame = pandas.DataFrame(cov_file)
    frame.to_csv(outfile, sep="\t", index=False, header=False)
def get_100x_log(file_, outfile):
    """Write a log10(100 * x + 1) transformed copy of a tab-separated table.

    The first field of every row is copied unchanged; every remaining field
    is converted to float and replaced by ``log10(value * 100 + 1)``.

    Parameters:
        file_: path of the tab-separated input table.
        outfile: path of the tab-separated output (no header, no index).

    Raises:
        ValueError: if a field after the first one is not a number.
    """
    # Read everything up front inside 'with' so the input handle is closed
    # (the original left it open).
    with open(str(file_), 'r') as handle:
        cov_file = list(csv.reader(handle, delimiter='\t'))
    for row in cov_file:
        row[1:] = [np.log10(float(value) * 100 + 1) for value in row[1:]]
    frame = pandas.DataFrame(cov_file)
    frame.to_csv(outfile, sep="\t", index=False, header=False)
########################################################################################
class Logger(object):
    """Write-only stream that copies every message to stdout and a log file.

    The instance captures the ``sys.stdout`` that is current when it is
    created and opens ``<location>/<logfile>`` in append mode.  It offers
    the ``write``/``flush`` pair needed to stand in for ``sys.stdout``.
    """

    def __init__(self, logfile, location):
        """Open the log file *logfile* inside the directory *location*."""
        self.terminal = sys.stdout
        self.log = open(os.path.join(location, logfile), "a")

    def write(self, message):
        """Send *message* to the captured stdout and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both targets.

        The original method was a no-op, so text buffered for the log file
        was not written out when the stream was flushed.
        """
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the captured stdout is left open."""
        self.log.close()
if __name__ == "__main__":
    # Command-line entry point: build the SAF files, run featureCounts on
    # every BAM file, merge the counts into <outCov>.cov and optionally
    # write a transformed copy of that table.
    # The usage line lists only options that the parser actually defines
    # (the original advertised a non-existent '--id' option).
    parser = argparse.ArgumentParser(prog='Binsanity-profile', usage='%(prog)s -i fasta_file -s bam_directory -c output_file', description="""
    ***********************************************************************
    ******************************BinSanity********************************
    **                                                                   **
    **  Binsanity-profile is used to generate coverage files for         **
    **  input to BinSanity. This uses Featurecounts to generate a        **
    **  a coverage profile and transforms data for input into Binsanity, **
    **  Binsanity-refine, and Binsanity-wf                               **
    **                                                                   **
    ***********************************************************************
    ***********************************************************************""", formatter_class=RawTextHelpFormatter)
    parser.add_argument("-i", dest="inputFasta",
                        help="Specify fasta file being profiled")
    parser.add_argument("-s", dest="inputMapLoc", help="""
    identify location of BAM files
    BAM files should be indexed and sorted""")
    parser.add_argument("-c", dest="outCov", help="""
    Identify name of output file for coverage information""")
    # The help text now states the real default ("scale"); it used to
    # claim "log".
    parser.add_argument("--transform", dest="transform", default="scale", help="""
    Indicate what type of data transformation you want in the final file [Default:scale]:
    scale --> Scaled by multiplying by 100 and log transforming
    log --> Log transform
    None --> Raw Coverage Values
    X5 --> Multiplication by 5
    X10 --> Multiplication by 10
    X100 --> Multiplication by 100
    SQR --> Square root
    We recommend using a scaled log transformation for initial testing.
    Other transformations can be useful on a case by case basis""")
    parser.add_argument('-T', dest='Threads', default=1, type=int,
                        help="Specify Number of Threads For Feature Counts [Default: 1]")
    parser.add_argument('-o', dest="outDirectory", default=".",
                        help="Specify directory for output files to be deposited [Default: Working Directory]")
    parser.add_argument('--version', action='version',
                        version='%(prog)s v0.4.4')
    args = parser.parse_args()
    # Argument validation.  The original chain repeated some tests in
    # branches that could never be reached; each required option is now
    # checked once, in the same priority order and with the same messages.
    if len(sys.argv) < 3:
        parser.print_help()
    elif not args.inputFasta:
        print("Need to identify fasta file using '-i'")
    elif args.inputMapLoc is None:
        print("Need to identify fasta file using '-i' and directory containing SAM/BAM files using '-s'.")
    elif args.outCov is None:
        print("Need to identify fasta file using '-i' and output file for coverage profile using '-c'.")
    ###########################################
    else:
        # The original compared with 'is not "."', an identity test on a
        # string literal.  Checking for existence covers "." as well, and
        # makedirs also creates missing parent directories.
        if not os.path.exists(args.outDirectory):
            os.makedirs(args.outDirectory)
        # NOTE: logging to a file is currently disabled:
        # sys.stdout = Logger('Binsanity-Profile.log', args.outDirectory)
        format_for_feature(
            args.inputFasta, args.inputMapLoc, args.outDirectory)

        print("""
        ******************************************************
                    Contigs formated to generate counts
        ******************************************************
        """)
        feature_counts(args.inputMapLoc, args.outDirectory, args.Threads)

        make_coverage(args.inputFasta, args.outCov, args.outDirectory)
    ###########################################
        # make_coverage wrote the raw table to this path.
        x = args.outCov + ".cov"
        if args.transform == "scale":
            get_100x_log(x, x+".x100.lognorm")
        elif args.transform == "log":
            print("Transforming your combined coverage profile")
            get_log(x, x+".lognorm")
        elif args.transform == "X5":
            print("Transforming your combined coverage profile")
            get_multiple(x, x+".x5", 5)
        elif args.transform == "X10":
            print("Transforming your combined coverage profile")
            get_multiple(x, x+".x10", 10)
        elif args.transform == "X100":
            print("Transforming your combined coverage profile")
            get_multiple(x, x+".x100", 100)
        elif args.transform == "SQR":
            print("Transforming your combined coverage profile")
            get_squareroot(x, x+".sqrt")
        elif args.transform != "None":
            # "None" deliberately keeps only the raw table; anything else
            # is most likely a typo and used to be ignored silently.
            print("Unrecognized --transform value '%s'; only the raw coverage file was written." % args.transform)
        print("""

        ********************************************************
                        Coverage profile produced
        ********************************************************

        """)
