#!/usr/bin/env python3

# Copyright (C) 2017, Weizhi Song, Torsten Thomas.
# songwz03@gmail.com or t.thomas@unsw.edu.au

# Binning_refiner is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# Binning_refiner is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import shutil
import argparse
from Bio import SeqIO
from datetime import datetime
from Binning_refiner.Binning_refiner_config import config_dict


def force_create_folder(folder_to_create):
    """Create ``folder_to_create`` as a fresh, empty directory.

    If a directory already exists at that path it is deleted with
    ``shutil.rmtree`` (errors ignored). Deletion is attempted at most four
    times, stopping as soon as the directory is gone. The directory is then
    created with ``os.mkdir``, which raises ``OSError`` if the old directory
    could not be removed or if the parent directory does not exist.
    """
    attempts_left = 4
    while attempts_left > 0 and os.path.isdir(folder_to_create):
        shutil.rmtree(folder_to_create, ignore_errors=True)
        attempts_left -= 1
    os.mkdir(folder_to_create)


def get_no_hidden_folder_list(wd):
    """Return the sorted names of the entries in ``wd`` that are not hidden.

    Every name returned by ``os.listdir(wd)`` is kept unless it starts with
    a dot; files and sub-directories are treated alike.
    """
    return sorted(entry for entry in os.listdir(wd) if not entry.startswith('.'))


######################################## CONFIGURATION #######################################

# Command-line interface. Optional settings default to False so that the
# rest of the script can test "was this given?" with `is False`.
parser = argparse.ArgumentParser()

parser.add_argument('-i',       required=True,  help='input bin folder')
parser.add_argument('-p',       required=False, default='Refined',      help='output prefix, default: Refined')
parser.add_argument('-m',       required=False, default=512,  type=int, help='minimal size (Kbp) of refined bin, default: 512')
parser.add_argument('-plot',    required=False, default=False,          action="store_true", help='visualize refinement with Sankey plot')
parser.add_argument('-x',       required=False, default=False,          help='the width of sankey plot')
parser.add_argument('-y',       required=False, default=False,          help='the height of sankey plot')
# -q is meant to be a switch: nargs='?' with const=True makes a bare `-q`
# valid, while the older `-q <value>` form is still accepted. Either way the
# stored value is no longer False, which is what the progress reports test.
parser.add_argument('-q',       required=False, default=False,          nargs='?', const=True, help='silent progress report')

args = vars(parser.parse_args())

input_bin_folder =      args['i']
output_bin_prefix =     args['p']
minBin_size_Kbp =       args['m']
plot_sankey =           args['plot']
sankey_plot_width =     args['x']
sankey_plot_height =    args['y']
keep_quiet =            args['q']


# Drop one trailing slash so that paths built as '<folder>/<name>' do not
# contain a double slash. The length guard keeps an empty value from raising
# IndexError and keeps a bare '/' from collapsing to an empty path.
if len(input_bin_folder) > 1 and input_bin_folder.endswith('/'):
    input_bin_folder = input_bin_folder[:-1]

# path of the R script used for the Sankey plot, taken from the package config
pwd_plot_sankey_R = config_dict['get_sankey_plot_R']


################################### define output filename ###################################

# Everything is written below a single working directory named after the prefix.
Binning_refiner_wd = '%s_Binning_refiner_outputs' % output_bin_prefix

# All report paths are formatted from the same (working directory, prefix) pair.
wd_and_prefix = (Binning_refiner_wd, output_bin_prefix)
pwd_report_file_sources_and_length = '%s/%s_sources_and_length.txt' % wd_and_prefix
pwd_report_file_contigs = '%s/%s_contigs.txt' % wd_and_prefix
pwd_report_file_for_sankey = '%s/%s_sankey.csv' % wd_and_prefix
pwd_plot_file_for_sankey = '%s/%s_sankey.html' % wd_and_prefix
pwd_output_bin_folder = '%s/%s_refined_bins' % wd_and_prefix

# Recreate the output folders from scratch; the working directory must come
# first because the refined-bin folder lives inside it.
for output_folder in (Binning_refiner_wd, pwd_output_bin_folder):
    force_create_folder(output_folder)

# strftime pattern used as the prefix of every progress message
time_format = '[%Y-%m-%d %H:%M:%S]'


################################### precheck of input files ##################################

# Each non-hidden entry of the input folder is treated as one set of bins
# (one sub-folder per binning run).
input_bin_subfolders = get_no_hidden_folder_list(input_bin_folder)

for bin_subfolder in input_bin_subfolders:

    # get bin file list in each input bin subfolder
    pwd_bin_subfolder = '%s/%s' % (input_bin_folder, bin_subfolder)
    bin_file_list = get_no_hidden_folder_list(pwd_bin_subfolder)

    # Text after the last '.' of every bin file name; a name without any '.'
    # contributes the whole name.
    bin_ext_list = {each_bin.split('.')[-1] for each_bin in bin_file_list}

    if len(bin_ext_list) > 1:
        print('Program exited, please make sure all bins within %s folder have the same extension' % bin_subfolder)
        # Stop with a non-zero status so that calling shells and pipelines see
        # the failure. SystemExit is raised directly because the exit()
        # builtin is only installed by the site module.
        raise SystemExit(1)


##################################### refine input bins ######################################

# progress report
if keep_quiet is False:
    print('%s Detected %s input bin folders' % ((datetime.now().strftime(time_format)), len(input_bin_subfolders)))

# ctg_length_dict: contig id -> sequence length (first occurrence wins)
# ctg_to_bin_dict: contig id -> bin file names containing it, in reading order
ctg_length_dict = {}
ctg_to_bin_dict = {}
for bin_subfolder in input_bin_subfolders:

    # bin files of the current bin set
    pwd_bin_subfolder = '%s/%s' % (input_bin_folder, bin_subfolder)
    bin_file_list = get_no_hidden_folder_list(pwd_bin_subfolder)

    # progress report
    if keep_quiet is False:
        print('%s Read in %s %s bins' % ((datetime.now().strftime(time_format)), len(bin_file_list), bin_subfolder))

    for each_bin in bin_file_list:
        pwd_each_bin = '%s/%s/%s' % (input_bin_folder, bin_subfolder, each_bin)

        for each_seq in SeqIO.parse(pwd_each_bin, 'fasta'):
            each_seq_id = str(each_seq.id)

            # setdefault leaves an already recorded length untouched
            ctg_length_dict.setdefault(each_seq_id, len(each_seq.seq))

            # start a new list on first sight, then append the bin name
            ctg_to_bin_dict.setdefault(each_seq_id, []).append(each_bin)


# progress report
if keep_quiet is False:
    print('%s Refine input bins' % (datetime.now().strftime(time_format)))


# Keep only contigs recorded once per input bin set; the value is the
# '___'-joined list of bins the contig was found in.
bin_set_num = len(input_bin_subfolders)
ctg_to_bin_dict_shared = {
    each_ctg: '___'.join(source_bins)
    for each_ctg, source_bins in ctg_to_bin_dict.items()
    if len(source_bins) == bin_set_num
}


# Group shared contigs by their bin combination and sum up their lengths.
concatenated_bins_to_ctg_dict = {}
concatenated_bins_length_dict = {}
for each_shared_ctg, concatenated_bins in ctg_to_bin_dict_shared.items():
    concatenated_bins_to_ctg_dict.setdefault(concatenated_bins, []).append(each_shared_ctg)
    concatenated_bins_length_dict[concatenated_bins] = (
        concatenated_bins_length_dict.get(concatenated_bins, 0) + ctg_length_dict[each_shared_ctg])


# Discard combinations whose total length is below the size cutoff
# (the cutoff is given in Kbp and converted with a factor of 1024).
min_bin_size_bp = minBin_size_Kbp * 1024
concatenated_bins_length_dict_short_removed = {
    bin_combination: total_length
    for bin_combination, total_length in concatenated_bins_length_dict.items()
    if total_length >= min_bin_size_bp
}
concatenated_bins_to_ctg_dict_short_removed = {
    bin_combination: concatenated_bins_to_ctg_dict[bin_combination]
    for bin_combination in concatenated_bins_length_dict_short_removed
}
maximum_length = max(concatenated_bins_length_dict_short_removed.values(), default=0)


# Zero-pad every length to the width of the largest one so that sorting the
# strings below is equivalent to sorting by length.
length_format = "{:0%sd}" % len(str(maximum_length))
concatenated_bins_length_dict_short_removed_str = {
    bin_combination: length_format.format(total_length)
    for bin_combination, total_length in concatenated_bins_length_dict_short_removed.items()
}


# '<padded length>___<bin combination>' strings, largest first
concatenated_bins_list = [
    '%s___%s' % (padded_length, bin_combination)
    for bin_combination, padded_length in concatenated_bins_length_dict_short_removed_str.items()
]
concatenated_bins_list_sorted = sorted(concatenated_bins_list, reverse=True)


# progress report
if keep_quiet is False:
    print('%s Got %s refined bins with size larger than %s Kbp' % ((datetime.now().strftime(time_format)), len(concatenated_bins_list_sorted), minBin_size_Kbp))


# progress report
if keep_quiet is False:
    print('%s %s' % ((datetime.now().strftime(time_format)), 'Extract sequences for refined bins'))


# Write the three report files and one FASTA file per refined bin. The
# `with` blocks guarantee that every handle is closed even if parsing or
# writing fails part-way through.
with open(pwd_report_file_sources_and_length, 'w') as pwd_report_file_sources_and_length_handle, \
        open(pwd_report_file_contigs, 'w') as pwd_report_file_contigs_handle, \
        open(pwd_report_file_for_sankey, 'w') as pwd_report_file_for_sankey_handle:

    pwd_report_file_sources_and_length_handle.write('Refined_bin\tSize(Kbp)\tSource\n')
    pwd_report_file_contigs_handle.write('Refined_bin\tContigs\n')
    pwd_report_file_for_sankey_handle.write('C1,C2,Length_Kbp\n')

    # refined bins are numbered from 1 in sorted (largest first) order
    for refined_bin_index, refined_bin in enumerate(concatenated_bins_list_sorted, start=1):

        # each entry is '<padded length>___<bin 1>___<bin 2>...'; the part
        # after the length is the key of the refined-bin dictionaries
        refined_bin_source = refined_bin.split('___')[1:]
        refined_bin_key = '___'.join(refined_bin_source)

        refined_bin_length = concatenated_bins_length_dict_short_removed[refined_bin_key]
        refined_bin_length_Kbp = float("{0:.2f}".format(refined_bin_length/1024))
        refined_bin_ctgs = concatenated_bins_to_ctg_dict[refined_bin_key]
        refined_bin_id = '%s_%s' % (output_bin_prefix, refined_bin_index)
        pwd_refined_bin = '%s/%s_%s.fasta' % (pwd_output_bin_folder, output_bin_prefix, refined_bin_index)

        # sequences are copied from the first listed source bin, looked up
        # in the first input sub-folder
        pwd_ctg_source = '%s/%s/%s' % (input_bin_folder, input_bin_subfolders[0], refined_bin_source[0])

        # write out sources_and_length file
        pwd_report_file_sources_and_length_handle.write('%s\t%s\t%s\n' % (refined_bin_id, refined_bin_length_Kbp, ','.join(refined_bin_source)))

        # write out contig file
        pwd_report_file_contigs_handle.write('%s\t%s\n' % (refined_bin_id, ','.join(sorted(refined_bin_ctgs))))

        # write out file for sankey plot: one row per pair of adjacent sources
        for source_left, source_right in zip(refined_bin_source, refined_bin_source[1:]):
            pwd_report_file_for_sankey_handle.write('%s\n' % ','.join([source_left, source_right, str(refined_bin_length_Kbp)]))

        # extract sequences for refined bin; a set makes each membership test
        # constant-time instead of scanning the contig list per record
        refined_bin_ctg_set = set(refined_bin_ctgs)
        with open(pwd_refined_bin, 'w') as pwd_refined_bin_handle:
            for each_source_ctg in SeqIO.parse(pwd_ctg_source, 'fasta'):
                if str(each_source_ctg.id) in refined_bin_ctg_set:
                    SeqIO.write(each_source_ctg, pwd_refined_bin_handle, 'fasta')


# get sankey plot (only when requested with -plot)
if plot_sankey is True:

    # progress report; printed only when a plot is actually being prepared
    if keep_quiet is False:
        print('%s %s' % ((datetime.now().strftime(time_format)), 'Prepare Sankey plot'))

    # default plot size: 200 units of width per input bin set and 12 units
    # of height per refined bin, unless -x / -y were given
    if sankey_plot_width is False:
        sankey_plot_width = len(input_bin_subfolders) * 200
    if sankey_plot_height is False:
        sankey_plot_height = len(concatenated_bins_list_sorted) * 12

    cmd_sankey = 'Rscript %s -f %s -x %s -y %s' % (pwd_plot_sankey_R, pwd_report_file_for_sankey, sankey_plot_width, sankey_plot_height)

    # os.system does not raise when the command fails, so report a non-zero
    # status instead of silently claiming success
    sankey_exit_status = os.system(cmd_sankey)
    if sankey_exit_status != 0:
        print('%s Sankey plot command returned non-zero exit status %s: %s' % ((datetime.now().strftime(time_format)), sankey_exit_status, cmd_sankey))


# progress report
if keep_quiet is False:
    print('%s %s' % ((datetime.now().strftime(time_format)), 'Done!'))
