#! /usr/bin/env python

from __future__ import print_function
import os,sys
import argparse
from sys import stdout
import shutil
import subprocess
import glob
import tempfile 
import errno
from time import time


import itertools

import signal
from multiprocessing import Pool
import multiprocessing as mp

# import math
import re

def mkdir_p(path):
    """Create ``path`` and any missing parent folders, like ``mkdir -p``.

    A "creating" message is printed only when the folder is actually made.
    If ``path`` already exists as a directory the call is a silent no-op;
    every other ``OSError`` is propagated to the caller.
    """
    try:
        os.makedirs(path)
    except OSError as exc:
        # An existing directory is the one failure that counts as success.
        already_a_directory = exc.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise
    else:
        print("creating", path)


def wccount(filename):
    """Return the number of newline characters in ``filename``.

    This is the same number ``wc -l`` reports (a final line without a
    trailing newline is not counted), but it is computed in-process so the
    function does not depend on an external ``wc`` binary and a missing or
    unreadable file raises a regular ``IOError``/``OSError`` instead of
    failing while parsing an error message.
    """
    newline_count = 0
    with open(filename, 'rb') as handle:
        # Fixed-size binary chunks: no decoding and no need to hold a large
        # fastq file in memory at once.
        for chunk in iter(lambda: handle.read(1 << 20), b''):
            newline_count += chunk.count(b'\n')
    return newline_count

def isoncorrect(data):
    """Run one isONcorrect job on a single fastq file as an external process.

    ``data`` is a sequence whose first five items are
    ``(isoncorrect_location, read_fastq_file, outfolder, batch_id,
    isoncorrect_algorithm_params)``:

    * ``isoncorrect_location`` -- folder containing the ``isONcorrect`` executable.
    * ``read_fastq_file`` -- path of the fastq file to correct.
    * ``outfolder`` -- folder for the results; created if missing. The child
      process' output is captured there in ``stdout.txt`` and ``stderr.txt``.
    * ``batch_id`` -- identifier used in progress messages and returned.
    * ``isoncorrect_algorithm_params`` -- dict with the keys
      ``set_w_dynamically``, ``use_racon``, ``exact_instance_limit``,
      ``max_seqs``, ``k``, ``w``, ``xmin``, ``xmax`` and ``T``.

    Returns ``batch_id``. Raises ``subprocess.CalledProcessError`` when the
    command exits with a non-zero status.
    """
    isoncorrect_location, read_fastq_file, outfolder, batch_id, params = data[:5]
    mkdir_p(outfolder)
    isoncorrect_exec = os.path.join(isoncorrect_location, "isONcorrect")

    # Build the command once; the two boolean options are simply appended
    # when enabled instead of duplicating the whole call per combination.
    command = ["/usr/bin/time", isoncorrect_exec,
               "--fastq", read_fastq_file,
               "--outfolder", outfolder,
               "--exact_instance_limit", str(params["exact_instance_limit"])]
    if params["set_w_dynamically"]:
        command.append("--set_w_dynamically")
    command.extend(["--max_seqs", str(params["max_seqs"])])
    if params["use_racon"]:
        command.append("--use_racon")
    command.extend(["--k", str(params["k"]),
                    "--w", str(params["w"]),
                    "--xmin", str(params["xmin"]),
                    "--xmax", str(params["xmax"]),
                    "--T", str(params["T"])])

    # Both log files are managed by the with-statement so they are closed
    # even when check_call raises.
    with open(os.path.join(outfolder, "stderr.txt"), "w") as error_file, \
            open(os.path.join(outfolder, "stdout.txt"), "w") as isoncorrect_out_file:
        print('Running isoncorrect batch_id:{0}...'.format(batch_id), end=' ')
        stdout.flush()
        subprocess.check_call(command, stderr=error_file, stdout=isoncorrect_out_file)
        print('Done with batch_id:{0}.'.format(batch_id))
        stdout.flush()
    return batch_id

def splitfile(indir, tmp_outdir, fname, chunksize):
    """Split ``indir/fname`` into consecutive files of ``chunksize`` lines.

    The pieces are written to ``tmp_outdir`` and named
    ``<stem>_<i>.<ext>`` where ``<stem>.<ext>`` is ``fname`` split at its
    last dot and ``i`` counts from 0. The last piece holds the remaining
    lines and may be shorter; no empty piece is ever created, so an empty
    input produces no output files. The path of every piece written is
    printed.

    Raises ``ValueError`` if ``chunksize`` is smaller than 1, or if
    ``fname`` contains no dot.
    """
    if chunksize < 1:
        raise ValueError("chunksize must be a positive integer, got {0}".format(chunksize))

    cl_id, ext = fname.rsplit('.', 1)
    infilepath = os.path.join(indir, fname)

    with open(infilepath) as infile:
        for i in itertools.count():
            # islice stops at EOF, so nothing is read past the end of the
            # file and an exhausted input yields an empty chunk.
            chunk = list(itertools.islice(infile, chunksize))
            if not chunk:
                break
            outfilepath = os.path.join(tmp_outdir, '{0}_{1}.{2}'.format(cl_id, i, ext))
            print(outfilepath)
            with open(outfilepath, 'w') as outfile:
                outfile.writelines(chunk)

import os, errno

def symlink_force(target, link_name):
    """Create the symlink ``link_name`` -> ``target``, replacing an existing entry.

    When ``link_name`` already exists it is removed and the link is created
    again. Any ``OSError`` other than "already exists" is re-raised.
    """
    try:
        os.symlink(target, link_name)
    except OSError as link_error:
        if link_error.errno != errno.EEXIST:
            raise link_error
        # Something already occupies link_name: drop it and link again.
        os.remove(link_name)
        os.symlink(target, link_name)

def split_cluster_in_batches(indir, outdir, tmp_work_dir, max_seqs):
    """Prepare a work folder holding every cluster as batches of at most ``max_seqs`` reads.

    The ``.fastq`` files in ``indir`` are visited in increasing order of the
    integer before the first dot of their name. A file with more than
    ``4 * max_seqs`` lines is split with ``splitfile``; otherwise it is
    symlinked as ``<cl_id>_0.fastq``. Once one file has been found small
    enough, all following files are symlinked without being counted.
    NOTE(review): this presumably relies on cluster ids being ordered by
    decreasing cluster size -- confirm against the tool that writes ``indir``.

    ``outdir`` is accepted for interface compatibility and is not used.

    Returns the path of the work folder,
    ``<tmp_work_dir>/split_in_batches``.
    """
    tmp_work_dir = os.path.join(tmp_work_dir, 'split_in_batches')
    mkdir_p(tmp_work_dir)
    lines_per_batch = 4 * max_seqs  # four lines per fastq record

    # Filter before sorting so unrelated files in indir cannot break the
    # integer sort key.
    fastq_files = [name for name in map(os.fsdecode, os.listdir(indir))
                   if name.endswith(".fastq")]
    fastq_files.sort(key=lambda name: int(name.split('.')[0]))

    smaller_than_max_seqs = False
    for fastq_file in fastq_files:
        fastq_path = os.path.join(indir, fastq_file)
        if not smaller_than_max_seqs:
            with open(fastq_path) as handle:
                num_lines = sum(1 for _ in handle)
            print(fastq_file, num_lines)
            smaller_than_max_seqs = num_lines <= lines_per_batch

        if not smaller_than_max_seqs:
            splitfile(indir, tmp_work_dir, fastq_file, lines_per_batch)
        else:
            cl_id, ext = fastq_file.rsplit('.', 1)
            print(fastq_file, "symlinking instead")
            # A relative symlink target is resolved against the folder that
            # holds the link, so the target must be absolute to stay valid
            # when indir was given as a relative path.
            symlink_force(os.path.abspath(fastq_path),
                          os.path.join(tmp_work_dir, '{0}_{1}.{2}'.format(cl_id, 0, ext)))
    return tmp_work_dir


def _batch_order(batch_dir):
    """Sort key placing ``<cl_id>_<n>`` batch folders in increasing integer ``n``."""
    suffix = os.path.basename(batch_dir).rsplit('_', 1)[-1]
    if suffix.isdigit():
        return (0, int(suffix), batch_dir)
    # Unexpected names go last, in plain string order.
    return (1, 0, batch_dir)


def join_back_corrected_batches_into_cluster(tmp_work_dir, outdir, split_mod, residual):
    """Concatenate the corrected reads of all batches of a cluster into one file.

    Cluster ids are taken from the entries of ``tmp_work_dir`` whose name has
    the form ``<cl_id>_<rest>`` (exactly one underscore); only ids with
    ``int(cl_id) % split_mod == residual`` are processed. For each such id the
    ``corrected_reads.fastq`` files of the folders ``<outdir>/<cl_id>_*`` are
    appended, in increasing batch number, to
    ``<outdir>/<cl_id>/corrected_reads.fastq``, and every batch folder is
    removed once it has been copied.

    Raises ``IOError``/``OSError`` if a batch folder has no readable
    ``corrected_reads.fastq``.
    """
    print(outdir, tmp_work_dir)
    outdir = os.fsdecode(outdir)

    unique_cl_ids = set()
    for entry in os.listdir(tmp_work_dir):
        name_parts = os.fsdecode(entry).split('_')
        if len(name_parts) != 2:
            continue
        cl_id = name_parts[0]
        if int(cl_id) % split_mod != residual:
            continue
        unique_cl_ids.add(cl_id)

    for cl_id in unique_cl_ids:
        cluster_dir = os.path.join(outdir, cl_id)
        # Escape outdir so glob metacharacters in the user's path are literal.
        batches_pattern = os.path.join(glob.escape(outdir), cl_id + '_*')
        mkdir_p(cluster_dir)

        # An empty 'cat.stderr' has always been written next to the joined
        # file; keep creating it, but do not leave the handle open.
        with open(os.path.join(cluster_dir, 'cat.stderr'), 'w'):
            pass

        outfilename = os.path.join(cluster_dir, 'corrected_reads.fastq')
        with open(outfilename, 'wb') as outfile:
            # Numeric order: a plain string sort would put batch 10 before batch 2.
            for batch_dir in sorted(glob.glob(batches_pattern), key=_batch_order):
                filename = os.path.join(batch_dir, 'corrected_reads.fastq')
                with open(filename, 'rb') as readfile:
                    shutil.copyfileobj(readfile, outfile)
                shutil.rmtree(batch_dir)



def _collect_instances(args, split_directory, isoncorrect_location):
    """Return the job tuples for ``isoncorrect``, sorted by batch id.

    One job is created per ``.fastq`` file in ``split_directory`` whose
    cluster id satisfies ``cl_id % args.split_mod == args.residual``. With
    ``args.keep_old`` a job is skipped when its ``corrected_reads.fastq``
    already exists with the same line count as the input file.
    """
    # Identical for every job, so built once.
    isoncorrect_algorithm_params = {
        "set_w_dynamically": args.set_w_dynamically,
        "exact_instance_limit": args.exact_instance_limit,
        "k": args.k, "w": args.w, "xmin": args.xmin, "xmax": args.xmax,
        "T": args.T, "max_seqs": args.max_seqs, "use_racon": args.use_racon,
    }
    instances = []
    for file_ in os.listdir(split_directory):
        read_fastq_file = os.fsdecode(file_)
        if not read_fastq_file.endswith(".fastq"):
            continue
        batch_id = read_fastq_file.split(".")[0]
        cl_id = batch_id.split("_")[0]
        if int(cl_id) % args.split_mod != args.residual:
            print('skipping {0} because args.split_mod:{1} and args.residual:{2} set.'.format(batch_id, args.split_mod, args.residual))
            continue

        outfolder = os.path.join(args.outfolder, batch_id)
        fastq_file_path = os.path.join(os.fsdecode(split_directory), read_fastq_file)
        if args.keep_old:
            candidate_corrected_file = os.path.join(outfolder, "corrected_reads.fastq")
            if (os.path.isfile(candidate_corrected_file)
                    and wccount(candidate_corrected_file) == wccount(fastq_file_path)):
                print("already computed cluster and complete file", batch_id)
                continue
        instances.append((isoncorrect_location, fastq_file_path, outfolder, batch_id, isoncorrect_algorithm_params))

    instances.sort(key=lambda x: x[3])
    return instances


def _run_instances(instances, nr_cores):
    """Run ``isoncorrect`` on every job in a process pool, printing each finished batch."""
    pool = Pool(processes=int(nr_cores))
    try:
        start = time()
        for x in pool.imap_unordered(isoncorrect, instances):
            print("{} (Time elapsed: {}s)".format(x, int(time() - start)))
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, terminating workers")
        pool.terminate()
        pool.join()
        sys.exit()
    except Exception:
        # A failed job must not leave the remaining workers running.
        pool.terminate()
        pool.join()
        raise
    else:
        pool.close()
        pool.join()


def main(args):
    """Correct every fastq file of ``args.fastq_folder`` with isONcorrect in parallel.

    With ``args.split_wrt_batches`` the clusters are first split into batches
    inside a temporary folder, each batch is corrected separately, and the
    corrected batches are joined back per cluster in ``args.outfolder``. The
    temporary folder is removed when the function exits, whether it
    succeeded or not.

    The ``isONcorrect`` executable is looked up in the folder containing
    this script.
    """
    directory = args.fastq_folder
    isoncorrect_location = os.path.dirname(os.path.realpath(__file__))
    tmp_work_dir = None
    try:
        if args.split_wrt_batches:
            tmp_work_dir = tempfile.mkdtemp()
            print("Temporary workdirektory:", tmp_work_dir)
            split_directory = split_cluster_in_batches(directory, args.outfolder, tmp_work_dir, args.max_seqs)
        else:
            split_directory = directory

        print(isoncorrect_location)
        instances = _collect_instances(args, split_directory, isoncorrect_location)
        for t in instances:
            print(t)

        mp.set_start_method('spawn')
        print("Environment set:", mp.get_context())
        print("Using {0} cores.".format(args.nr_cores))
        start_multi = time()
        _run_instances(instances, args.nr_cores)
        print("Time elapesd multiprocessing:", time() - start_multi)

        if args.split_wrt_batches:
            file_handling = time()
            join_back_corrected_batches_into_cluster(split_directory, args.outfolder, args.split_mod, args.residual)
            print("Joined back batched files in:", time() - file_handling)
    finally:
        # Remove the whole mkdtemp folder, not only its 'split_in_batches'
        # subfolder, and do so on failure or interruption as well.
        if tmp_work_dir is not None:
            shutil.rmtree(tmp_work_dir, ignore_errors=True)

    return


if __name__ == '__main__':
    # Command-line entry point: parse and validate the options, make sure the
    # output folder exists, then hand over to main().
    parser = argparse.ArgumentParser(description="Run isONcorrect in parallel over a folder of clustered fastq files", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--version', action='version', version='%(prog)s 0.0.6')
    parser.add_argument('--fastq_folder', type=str,  default=False, help='Path to input fastq folder with reads in clusters')
    parser.add_argument('--t', dest="nr_cores", type=int, default=8, help='Number of cores allocated for clustering')
    parser.add_argument('--k', type=int, default=9, help='Kmer size')
    parser.add_argument('--w', type=int, default=10, help='Window size')
    parser.add_argument('--xmin', type=int, default=14, help='Lower interval length')
    parser.add_argument('--xmax', type=int, default=80, help='Upper interval length')
    parser.add_argument('--T', type=float, default=0.1, help='Minimum fraction keeping substitution')
    parser.add_argument('--exact_instance_limit', type=int, default=50,  help='Do exact correction for clusters under this threshold')
    parser.add_argument('--keep_old', action="store_true", help='Do not recompute previous results if corrected_reads.fq is found and has the smae number of reads as input file (i.e., is complete).')
    parser.add_argument('--set_w_dynamically', action="store_true", help='Set w = k + max(2*k, floor(cluster_size/1000)).')
    parser.add_argument('--max_seqs', type=int, default=1000,  help='Maximum number of seqs to correct at a time (in case of large clusters).')
    parser.add_argument('--use_racon', action="store_true", help='Use racon to polish consensus after spoa (more time consuming but higher accuracy).')

    parser.add_argument('--split_mod', type=int, default=1, help='Splits cluster ids in n (default=1) partitions by computing residual of cluster_id divided by n.\
                                                                    this parameter needs to be combined with  --residual to take effect.')
    parser.add_argument('--residual', type=int, default=0, help='Run isONcorrect on cluster ids with residual (default 0) of cluster_id divided by --split_mod. ')

    parser.add_argument('--split_wrt_batches', action="store_true", help='Process reads per batch (of max_seqs sequences) instead of per cluster. Significantly decrease runtime when few very large clusters are less than the number of cores used.')
  
    parser.add_argument('--outfolder', type=str,  default=None, help='Outfolder with all corrected reads.')
    args = parser.parse_args()

    if len(sys.argv)==1:
        parser.print_help()
        sys.exit()

    # Validate with parser.error rather than assert: asserts are stripped
    # under "python -O", and the original check was skipped for split_mod <= 1.
    if not args.fastq_folder:
        parser.error('--fastq_folder is required.')
    if not os.path.isdir(args.fastq_folder):
        parser.error('--fastq_folder is not an existing folder: {0}'.format(args.fastq_folder))
    if not args.outfolder:
        parser.error('--outfolder is required.')
    if args.split_mod < 1:
        parser.error('--split_mod must be a positive integer.')
    if not 0 <= args.residual < args.split_mod:
        # cluster_id % split_mod is always within [0, split_mod), so any other
        # residual would silently skip every cluster.
        parser.error('--residual must satisfy 0 <= residual < split_mod.')

    if not os.path.exists(args.outfolder):
        os.makedirs(args.outfolder)

    main(args)
