#! /usr/bin/env python

from __future__ import print_function
import os,sys
import argparse
from sys import stdout
import shutil
import subprocess

import errno
from time import time
import itertools

import signal
from multiprocessing import Pool
import multiprocessing as mp

# import math
import re

def mkdir_p(path):
    """Create ``path`` (including missing parents), tolerating an existing directory.

    Prints a "creating" message when the directory is actually created.
    An ``OSError`` is re-raised unless it reports ``EEXIST`` and ``path``
    is a directory.
    """
    try:
        os.makedirs(path)
        print("creating", path)
    except OSError as err:
        directory_already_present = (
            err.errno == errno.EEXIST and os.path.isdir(path)
        )
        if not directory_already_present:
            raise


def wccount(filename):
    """Return the number of newline characters in ``filename``.

    This matches what ``wc -l`` reports (a final line that lacks a trailing
    newline is not counted), but is computed in-process so it does not
    depend on an external ``wc`` binary and cannot mistake an error message
    for a count.

    Args:
        filename: Path of the file to count lines in.

    Returns:
        The number of ``\\n`` bytes in the file, as an int.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    newline_count = 0
    with open(filename, "rb") as handle:
        # Read fixed-size binary chunks so large fastq files are never
        # loaded into memory at once.
        for chunk in iter(lambda: handle.read(1 << 20), b""):
            newline_count += chunk.count(b"\n")
    return newline_count

def isoncorrect(data):
    """Run the isONcorrect executable on a single cluster fastq file.

    The executable's stdout and stderr are written to ``stdout.txt`` and
    ``stderr.txt`` inside the cluster's output folder, which is created if
    it does not exist.

    Args:
        data: Sequence whose first five items are, in order:
            the directory containing the ``isONcorrect`` executable,
            the path of the cluster's fastq file, the output folder for
            this cluster, the cluster id, and a dict with the keys
            ``exact_instance_limit``, ``k``, ``w``, ``xmin``, ``xmax``
            and ``T``.

    Returns:
        The cluster id taken from ``data``.

    Raises:
        subprocess.CalledProcessError: If the command exits with a
            non-zero status.
    """
    isoncorrect_location, read_fastq_file, outfolder, cl_id, isoncorrect_algorithm_params = data[:5]
    mkdir_p(outfolder)
    isoncorrect_exec = os.path.join(isoncorrect_location, "isONcorrect")
    isoncorrect_error_file = os.path.join(outfolder, "stderr.txt")
    isoncorrect_out_path = os.path.join(outfolder, "stdout.txt")

    # NOTE(review): "--set_w_dynamically" is passed unconditionally; a
    # "set_w_dynamically" entry in isoncorrect_algorithm_params is never
    # consulted here. Confirm whether that is intended.
    command = [
        "/usr/bin/time", isoncorrect_exec,
        "--fastq", read_fastq_file,
        "--outfolder", outfolder,
        "--exact_instance_limit", str(isoncorrect_algorithm_params["exact_instance_limit"]),
        "--set_w_dynamically",
        "--k", str(isoncorrect_algorithm_params["k"]),
        "--w", str(isoncorrect_algorithm_params["w"]),
        "--xmin", str(isoncorrect_algorithm_params["xmin"]),
        "--xmax", str(isoncorrect_algorithm_params["xmax"]),
        "--T", str(isoncorrect_algorithm_params["T"]),
    ]

    # Both log files are managed by the with-statement so they are closed
    # even when check_call raises.
    with open(isoncorrect_error_file, "w") as error_file, \
            open(isoncorrect_out_path, "w") as isoncorrect_out_file:
        print('Running isoncorrect cl_id:{0}...'.format(cl_id), end=' ')
        stdout.flush()

        subprocess.check_call(command, stderr=error_file, stdout=isoncorrect_out_file)

        print('Done with cl_id:{0}.'.format(cl_id))
        stdout.flush()
    return cl_id

def _collect_instances(args, isoncorrect_location):
    """Build the sorted list of isoncorrect() work items for args.fastq_folder."""
    directory = os.fsdecode(args.fastq_folder)
    # The algorithm parameters are identical for every cluster, so build
    # the dict once instead of once per file.
    isoncorrect_algorithm_params = {
        "set_w_dynamically": args.set_w_dynamically,
        "exact_instance_limit": args.exact_instance_limit,
        "k": args.k,
        "w": args.w,
        "xmin": args.xmin,
        "xmax": args.xmax,
        "T": args.T,
    }

    instances = []
    for read_fastq_file in os.listdir(directory):
        if not read_fastq_file.endswith(".fastq"):
            continue

        # The cluster id is the part of the file name before the first dot;
        # it must be an integer (int() raises ValueError otherwise).
        cl_id = read_fastq_file.split(".")[0]
        if int(cl_id) % args.split_mod != args.residual:
            print('skipping {0} because args.split_mod:{1} and args.residual:{2} set.'.format(cl_id, args.split_mod, args.residual))
            continue

        outfolder = os.path.join(args.outfolder, cl_id)
        fastq_file_path = os.path.join(directory, read_fastq_file)

        if args.keep_old:
            # An earlier result counts as complete when the corrected file
            # has as many lines as the input file.
            candidate_corrected_file = os.path.join(outfolder, "corrected_reads.fastq")
            if (os.path.isfile(candidate_corrected_file)
                    and wccount(candidate_corrected_file) == wccount(fastq_file_path)):
                print("already computed cluster and complete file", cl_id)
                continue

        instances.append((isoncorrect_location, fastq_file_path, outfolder, int(cl_id), isoncorrect_algorithm_params))

    instances.sort(key=lambda x: x[3])
    return instances


def main(args):
    """Run isoncorrect() on every selected cluster fastq file using a process pool.

    Files in ``args.fastq_folder`` ending in ``.fastq`` are selected when
    their integer cluster id satisfies
    ``cl_id % args.split_mod == args.residual``. With ``args.keep_old``,
    clusters whose ``corrected_reads.fastq`` already has the same line
    count as the input are skipped. Work items are processed in a pool of
    ``args.nr_cores`` processes, started with the 'spawn' method.

    Args:
        args: Parsed command-line namespace providing ``fastq_folder``,
            ``outfolder``, ``split_mod``, ``residual``, ``keep_old``,
            ``nr_cores``, ``set_w_dynamically``, ``exact_instance_limit``,
            ``k``, ``w``, ``xmin``, ``xmax`` and ``T``.

    Raises:
        SystemExit: After a KeyboardInterrupt, once workers are terminated.
        Exception: Any exception raised by a worker is re-raised after the
            pool has been terminated.
    """
    isoncorrect_location = os.path.dirname(os.path.realpath(__file__))
    print(isoncorrect_location)

    instances = _collect_instances(args, isoncorrect_location)
    for t in instances:
        print(t)

    # NOTE(review): SIGINT is set to "ignore" and restored again before the
    # pool is created, so these two calls have no net effect on the pool.
    original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    signal.signal(signal.SIGINT, original_sigint_handler)
    mp.set_start_method('spawn')
    print(mp.get_context())
    print("Environment set:", mp.get_context())
    print("Using {0} cores.".format(args.nr_cores))
    start_multi = time()
    pool = Pool(processes=int(args.nr_cores))
    try:
        start = time()
        for x in pool.imap_unordered(isoncorrect, instances):
            print("{} (Time elapsed: {}s)".format(x, int(time() - start)))
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, terminating workers")
        pool.terminate()
        sys.exit()
    except Exception:
        # A failed worker task must not leave the remaining workers running
        # through the queued tasks; stop them before propagating the error.
        pool.terminate()
        pool.join()
        raise
    else:
        pool.close()
    pool.join()

    print("Time elapesd multiprocessing:", time() - start_multi)

    return


# Command-line entry point: parse and validate the options, make sure the
# output folder exists, then hand over to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run isONcorrect on every cluster fastq file in a folder, in parallel", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--version', action='version', version='%(prog)s 0.0.3')
    parser.add_argument('--fastq_folder', type=str,  default=False, help='Path to input fastq folder with reads in clusters')
    parser.add_argument('--t', dest="nr_cores", type=int, default=8, help='Number of cores allocated for clustering')
    parser.add_argument('--k', type=int, default=9, help='Kmer size')
    parser.add_argument('--w', type=int, default=10, help='Window size')
    parser.add_argument('--xmin', type=int, default=14, help='Lower interval length')
    parser.add_argument('--xmax', type=int, default=80, help='Upper interval length')
    parser.add_argument('--T', type=float, default=0.1, help='Minimum fraction keeping substitution')
    parser.add_argument('--exact_instance_limit', type=int, default=50,  help='Do exact correction for clusters under this threshold')
    # parser.add_argument('--w_equal_k_limit', type=int, default=100,  help='Do not recompute previous results')
    parser.add_argument('--keep_old', action="store_true", help='Do not recompute previous results if corrected_reads.fastq is found and has the same number of lines as input file (i.e., is complete).')
    parser.add_argument('--set_w_dynamically', action="store_true", help='Set w = k + max(2*k, floor(cluster_size/1000)).')

    parser.add_argument('--split_mod', type=int, default=1, help='Splits cluster ids in n (default=1) partitions by computing residual of cluster_id divided by n.\
                                                                    this parameter needs to be combined with  --residual to take effect.')
    parser.add_argument('--residual', type=int, default=0, help='Run isONcorrect on cluster ids with residual (default 0) of cluster_id divided by --split_mod. ')

    # parser.add_argument('--exact', action="store_true", help='Get exact solution for WIS for every read (recalculating weights for each read (much slower but slightly more accuracy,\
    #                                                              not to be used for clusters with over ~500 reads)')

    parser.add_argument('--outfolder', type=str,  default=None, help='Path to the output folder; one subfolder per cluster id is created in it.')
    args = parser.parse_args()

    if len(sys.argv)==1:
        parser.print_help()
        sys.exit()

    # Report missing or inconsistent options as usage errors instead of
    # failing later inside main() with a TypeError or ZeroDivisionError.
    if not args.fastq_folder:
        parser.error("--fastq_folder is required")
    if not args.outfolder:
        parser.error("--outfolder is required")
    if args.split_mod < 1:
        parser.error("--split_mod must be a positive integer")
    # Explicit check rather than assert, which is stripped under python -O.
    if not 0 <= args.residual < args.split_mod:
        parser.error("--residual must satisfy 0 <= residual < split_mod")

    os.makedirs(args.outfolder, exist_ok=True)

    main(args)
