import openpyxl, gzip, datetime, psutil, time
from Bio.Data.IUPACData import ambiguous_dna_values
from itertools import product
from Bio.SeqIO.QualityIO import FastqGeneralIterator
from pathlib import Path
from joblib import Parallel, delayed

def extend_ambiguous_dna(seq):
    """Return a list of all concrete sequences encoded by an ambiguous DNA sequence.

    Every character of ``seq`` is looked up in Biopython's IUPAC
    ``ambiguous_dna_values`` table and expanded to the bases it stands for
    (e.g. ``N`` -> ``G``/``A``/``T``/``C``); the result is the cartesian product
    of those choices. Matching is case-insensitive and the returned sequences
    use the table's (upper-case) bases.

    Raises:
        ValueError: if ``seq`` contains a character that is not in the table
            (e.g. whitespace or a line break left over from file parsing).
    """
    choices_per_position = []
    for position, base in enumerate(seq.upper()):
        try:
            choices_per_position.append(ambiguous_dna_values[base])
        except KeyError:
            # previously d.get returned None here and product() failed with an opaque TypeError
            raise ValueError('invalid DNA character {!r} at position {} in {!r}'.format(base, position, seq)) from None
    return ["".join(bases) for bases in product(*choices_per_position)]

## expands a {(fwd_tag, rev_tag): site} dict so that every tag containing ambiguous
## bases is replaced by all concrete tags it can stand for
def extend_tag_to_site_dict(dict):
    """Expand ambiguous tag pairs into every concrete (fwd, rev) pair.

    ``dict`` maps ``(fwd_tag, rev_tag)`` keys to a site. Each tag is expanded with
    ``extend_ambiguous_dna`` and every combination of an expanded forward tag with an
    expanded reverse tag is mapped to the site of the key it came from. When two keys
    expand to the same concrete pair, the site of the later key wins.
    """
    return {
        concrete_pair: dict[tag_pair]
        for tag_pair in dict
        for concrete_pair in product(extend_ambiguous_dna(tag_pair[0]),
                                     extend_ambiguous_dna(tag_pair[1]))
    }

## the core demultiplexing function. only processes one line of the tagging scheme
## so that main can easily run several lines in parallel
def demultiplex(primerset, tagging_scheme_header, tagging_scheme_line, output_folder, tag_removal):
    """Demultiplex one forward/reverse pair of gzipped FASTQ files by tag combination.

    Args:
        primerset: dict mapping primer names to (possibly ambiguous) tag sequences,
            as built by ``main``.
        tagging_scheme_header: first row of the tagging scheme. Every cell from
            index 4 on names a tag combination as ``"<fwd name> - <rev name>"``.
        tagging_scheme_line: one data row of the tagging scheme. Cells 0 and 1 are
            the paths of the gzipped forward/reverse input files. Cells 2 and 3
            (with up to two suffixes removed) name the no-match output files and
            appear in the progress message. Cells from index 4 on hold the sample
            name for the combination in the same header column; an empty cell
            means that combination is not used.
        output_folder: directory the gzipped output files are written to.
        tag_removal: if truthy, the matched tags are cut off both reads (sequence
            and quality) before writing.

    A read pair goes to a sample when the forward read starts with the forward tag
    and the reverse read starts with the reverse tag. Combinations are tried in the
    insertion order of the expanded tag dict and the first hit wins. Pairs without a
    hit are written to the ``no_match_*`` files. All input and output files are
    closed on return, also when an error occurs.
    """
    from contextlib import ExitStack

    output_folder = Path(output_folder)

    ## read the combinations from the header, replace primer names with sequences
    name_pairs = [tuple(cell.split(' - ')) for cell in tagging_scheme_header[4:]]
    combinations = [(primerset[names[0]], primerset[names[1]]) for names in name_pairs]

    ## connect combinations and sites if the site was used
    tag_to_site = {tag: site for tag, site in zip(combinations, tagging_scheme_line[4:]) if site}

    ## blow up the tag to site dict by taking ambiguities into account. every ambiguous
    ## base multiplies the number of combinations checked per read, so avoid them in tags
    tag_to_site = extend_tag_to_site_dict(tag_to_site)
    tag_pairs = list(tag_to_site.items())  # built once, scanned for every read pair

    fastq_record = '@{}\n{}\n+\n{}\n'.format
    nomatch_stem_fwd = Path(tagging_scheme_line[2]).with_suffix('').with_suffix('')
    nomatch_stem_rev = Path(tagging_scheme_line[3]).with_suffix('').with_suffix('')

    ## the ExitStack closes every file opened below when the block is left,
    ## including on an exception halfway through, so no handle is leaked
    with ExitStack() as stack:
        fwd_reads = FastqGeneralIterator(stack.enter_context(gzip.open(Path(tagging_scheme_line[0]), 'rt')))
        rev_reads = FastqGeneralIterator(stack.enter_context(gzip.open(Path(tagging_scheme_line[1]), 'rt')))

        def open_output(file_name):
            """Open a gzipped text output file in output_folder and register it for closing."""
            return stack.enter_context(gzip.open(output_folder.joinpath(file_name), 'wt', compresslevel = 6))

        ## one pair of output files per sample. a sample may be assigned to several
        ## combinations, so its files are opened only once (opening the same path
        ## twice would truncate it and leave an orphaned handle behind)
        out_handles = {}
        for sample in tagging_scheme_line[4:]:
            if sample and sample not in out_handles:
                out_handles[sample] = (open_output('{}_r1.fastq.gz'.format(sample)),
                                       open_output('{}_r2.fastq.gz'.format(sample)))

        ## everything that does not match a tag combination is written here. kept apart
        ## from out_handles so a sample named 'nomatch' cannot clash with it
        nomatch_fwd = open_output('no_match_{}_r1.fastq.gz'.format(nomatch_stem_fwd))
        nomatch_rev = open_output('no_match_{}_r2.fastq.gz'.format(nomatch_stem_rev))

        ## core loop: check every read pair against the combinations in order and write it
        ## to the first matching sample (optionally without its tags) or to the nomatch files.
        ## count is the number of read pairs processed; zip stops at the end of the shorter file
        count = 0
        for (title_f, seq_f, qual_f), (title_r, seq_r, qual_r) in zip(fwd_reads, rev_reads):
            for (fwd_tag, rev_tag), site in tag_pairs:
                if seq_f.startswith(fwd_tag) and seq_r.startswith(rev_tag):
                    ## tags can only be cut off once a matching one was found
                    fwd_cut, rev_cut = (len(fwd_tag), len(rev_tag)) if tag_removal else (0, 0)
                    out_fwd, out_rev = out_handles[site]
                    out_fwd.write(fastq_record(title_f, seq_f[fwd_cut:], qual_f[fwd_cut:]))
                    out_rev.write(fastq_record(title_r, seq_r[rev_cut:], qual_r[rev_cut:]))
                    break
            else:
                nomatch_fwd.write(fastq_record(title_f, seq_f, qual_f))
                nomatch_rev.write(fastq_record(title_r, seq_r, qual_r))
            count += 1

    ## show user output STILL TO CHANGE
    print('{}: {} - {}: {} reads demultiplexed.'.format(datetime.datetime.now().strftime("%H:%M:%S"), tagging_scheme_line[2], tagging_scheme_line[3], count))

## main function to handle the control of the demultiplexing process
## accepts a primerset, tagging scheme, output folder from the main script
## also an optional removal of the tags as well as a print handle for pretty output
def main(primerset, tagging_scheme, output_folder, tag_removal, print_handle, window):
    """Demultiplex every file pair listed in a tagging scheme, in parallel.

    Args:
        primerset: path to a text file with one ``name,sequence`` primer per line.
        tagging_scheme: path to an Excel workbook. Its active sheet holds the
            tagging scheme: a header row, then one row per file pair.
        output_folder: directory the demultiplexed files are written to.
        tag_removal: forwarded to ``demultiplex``; if truthy, tags are cut off.
        print_handle: object with a ``print`` method used for status messages.
        window: window object whose ``Refresh()`` is called so the start message is
            shown before the blocking parallel run.
    """
    ## map primer names to their sequences. fields are stripped so the line break does
    ## not end up in the sequence ('\n' is no DNA base and breaks the ambiguity
    ## expansion); blank lines are skipped
    primer_sequences = {}
    with open(primerset, 'r') as primer_file:
        for line in primer_file:
            if not line.strip():
                continue
            fields = line.split(',')
            primer_sequences[fields[0].strip()] = fields[1].strip()

    ## load the tagging scheme and collect all rows
    wb = openpyxl.load_workbook(tagging_scheme)
    ws = wb.active
    rows = [[cell.value for cell in row] for row in ws.iter_rows()]
    header, data_rows = (rows[0], rows[1:]) if rows else ([], [])

    ## rows without any value describe no file pair (the sheet may report trailing
    ## empty rows); dispatching them would crash the worker on Path(None)
    data_rows = [row for row in data_rows if any(value not in (None, '') for value in row)]

    ## physical cores - 1 to use for demultiplexing. cpu_count may return None and
    ## joblib rejects n_jobs=0, so never go below one worker
    cores_to_use = max(1, (psutil.cpu_count(logical = False) or 1) - 1)

    ## run the demultiplex function on every line in the tagging scheme in parallel
    print_handle.print('{}: Starting to demultiplex {} file pairs. Output will be routed to the terminal. The window will freeze during this process'.format(datetime.datetime.now().strftime("%H:%M:%S"), len(data_rows)))
    window.Refresh()
    Parallel(n_jobs = cores_to_use)(delayed(demultiplex)(primer_sequences, header, row, output_folder, tag_removal) for row in data_rows)
    print_handle.print('{}: Done'.format(datetime.datetime.now().strftime("%H:%M:%S")))
