#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pathlib import Path
from snakemake.logging import logger
import logging
from pprint import pprint as pp
# load own functions
import culebrONT
from culebrONT.module import CulebrONT
from culebrONT.snakeWrapper import *

# GLOBAL VARIABLES
# NOTE(review): get_version is not defined in this file; presumably it comes
# from the `culebrONT.snakeWrapper` star import -- confirm.
version_CulebrONT = get_version()
try:
    # Build the pipeline object from the Snakemake `workflow` and `config`.
    culebront = CulebrONT(workflow, config)
    # debug: pp(culebront)

except Exception as e:
    # Log the error as fatal without a traceback (exc_info=False), flush the
    # logging handlers and abort the run.
    logger.logger.fatal(e, exc_info=False)
    logging.shutdown()
    raise SystemExit()

# Tool and cluster settings exposed by the CulebrONT object.
tools_config = culebront.tools_config
cluster_config = culebront.cluster_config

# Debug helpers, kept commented out:
# print(culebront.string_to_dag)
# pp(culebront.export_use_yaml)
# logger.debug(pp(culebront))
# exit()

# Paths read from the DATA section of the config.
# NOTE(review): every pattern below concatenates these values directly
# (e.g. f"{output_dir}FINAL_REPORT/..."), so they presumably end with a path
# separator -- confirm where the config is normalised.
output_dir = config['DATA']['OUTPUT']
fastq_dir = config['DATA']['FASTQ']
reference_file = config['DATA']['REF']

################ WILDCARDS  ################
# One `fastq` wildcard value per file found in the FASTQ directory that ends
# with the extension stored in culebront.fastq_files_ext.
FASTQ, = glob_wildcards(f"{config['DATA']['FASTQ']}{{fastq}}{culebront.fastq_files_ext}")
nb_racon_rounds = culebront.nb_racon_rounds
nb_pilon_rounds = culebront.nb_pilon_rounds

# Load some variables from the culebront object.
add_circular_name = culebront.add_circular_name

# Output of Minipolish -> (input of Tag Circular Minipolish)
TMP = culebront.TMP
# Output of Tag Circular Minipolish -> (input of fixstart)
TCM = culebront.TCM
# Output of raven, shasta, flye, circlator -> (input of tag_circular)
TAG = culebront.TAG

#######################################
# Change the working directory to the output path
workdir:config["DATA"]["OUTPUT"]
#######################################
#######################################

def get_threads(rule, default):
    """
    Return the number of threads to allocate to a rule.

    Lookup order:
        1. cluster mode (``cluster_config`` is not empty): the section named
           after the rule, then the ``'__default__'`` section. A rule section
           that exists but defines neither ``threads`` nor ``cpus-per-task``
           now falls back to ``'__default__'`` instead of silently skipping it.
        2. local mode: the ``_cores`` global resource of the workflow, except
           for rules declared with ``default == 1`` (not able to multi-thread).
        3. ``default`` when nothing else applies.

    Args:
        rule: name of the rule, used as section key in ``cluster_config``.
        default: value defined on the rule itself, returned as last resort.

    Returns:
        The thread count; values read from ``cluster_config`` are cast to int.

    Examples:
        rule rule_graph:
            threads: get_threads('rule_graph', 1)
    """
    # if cluster mode
    if cluster_config:
        # The key priority inside each section is kept exactly as it was:
        # 'threads' first for a rule section, 'cpus-per-task' first for
        # the '__default__' section.
        lookup_order = (
            (rule, ('threads', 'cpus-per-task')),
            ('__default__', ('cpus-per-task', 'threads')),
        )
        for section, keys in lookup_order:
            if section not in cluster_config:
                continue
            for key in keys:
                if key in cluster_config[section]:
                    return int(cluster_config[section][key])
    # if local
    elif workflow.global_resources["_cores"]:
        # a rule declared with a single thread must not grab every core
        return default if default == 1 else workflow.global_resources["_cores"]
    # cluster without any usable section, or local without _cores
    return default


def get_fastq(wildcards):
    """
    Return the path of the fastq file matching the ``fastq`` wildcard:
    fastq directory + sample name + fastq extension.
    """
    sample_name = wildcards.fastq
    extension = culebront.fastq_files_ext
    return f"{fastq_dir}{sample_name}{extension}"


def draft_to_circlator(wildcards):
    """
    Return the draft fasta matching the ``assemblers`` wildcard: the canu
    output when it contains 'CANU', the smartdenovo output when it contains
    'SMARTDENOVO', otherwise the literal string 'None'.
    """
    assembler = wildcards.assemblers
    # CANU is tested first, so it wins if both names appear in the wildcard
    if 'CANU' in assembler:
        return rules.run_canu.output.fasta
    return rules.run_smartdenovo.output.fasta if 'SMARTDENOVO' in assembler else 'None'

def draft_to_racon(wildcards):
    """
    Return the path pattern of the draft used by racon round ``nb``.

    In circular mode the draft is always the rotated assembly of the same
    round. Otherwise round 1 starts from the assembler output and later
    rounds start from the racon output of the previous round. The
    ``{fastq}`` and ``{assemblers}`` placeholders are left in the returned
    string.

    Raises:
        ValueError: when the round number is lower than 1.
    """
    round_nb = int(wildcards.nb)
    if round_nb < 1:
        raise ValueError(f"loop numbers must be 1 or greater: received {round_nb}")

    if config['CIRCULAR']:
        return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/POLISHING/ROTATE/rotate_{round_nb}/assembly.racon{round_nb}.fasta"

    if round_nb == 1:
        return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/ASSEMBLER/assembly{add_circular_name}.fasta"

    previous = round_nb - 1
    return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/POLISHING/RACON/racon_{previous}/assembly.racon{previous}.fasta"

def draft_to_rotate(wildcards):
    """
    Return the path pattern of the assembly to rotate at round ``nb``.

    Round 1 starts from the circularised, circ-tagged assembler output; later
    rounds start from the racon output of the previous round. The ``{fastq}``
    and ``{assemblers}`` placeholders are left in the returned string.

    Raises:
        ValueError: when the round number is lower than 1.
    """
    round_nb = int(wildcards.nb)
    if round_nb < 1:
        raise ValueError(f"loop numbers must be 1 or greater: received {round_nb}")

    if round_nb > 1:
        previous = round_nb - 1
        return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/POLISHING/RACON/racon_{previous}/assembly.racon{previous}.fasta"

    return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/ASSEMBLER/assemblyCIRCULARISED_circTag.fasta"

def draft_to_correction(wildcards):
    """
    Return the path pattern of the assembly handed to the correction step.

    * MINIASM: output of the last racon round, with the TCM suffix;
    * racon polishing enabled: output of the last racon round;
    * otherwise: the assembler output.

    The ``{fastq}`` and ``{assemblers}`` placeholders are left in the
    returned string.
    """
    if wildcards.assemblers == 'MINIASM':
        return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/POLISHING/RACON/racon_{nb_racon_rounds}/assembly.racon{nb_racon_rounds}{TCM}.fasta"

    if config['POLISHING']['RACON']:
        return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/POLISHING/RACON/racon_{nb_racon_rounds}/assembly.racon{nb_racon_rounds}.fasta"

    return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/ASSEMBLER/assembly{add_circular_name}.fasta"


def draft_to_pilon(wildcards):
    """
    Return the path pattern of the pilon assembly associated with round
    ``nbp``: the round 1 file for round 1, otherwise the file of the previous
    round. The ``{fastq}`` and ``{assemblers}`` placeholders are left in the
    returned string.

    Raises:
        ValueError: when the round number is lower than 1.
    """
    round_nb = int(wildcards.nbp)
    if round_nb < 1:
        raise ValueError(f"loop numbers for pilon must be 1 or greater: received {round_nb}")

    if round_nb > 1:
        previous = round_nb - 1
        return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/CORRECTION/PILON/pilon_{previous}/assembly.pilon{previous}.fasta"

    # NOTE(review): round 1 points at the pilon_1 file itself, so rounds 1 and
    # 2 resolve to the same path -- confirm this is intended by the caller.
    return f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/CORRECTION/PILON/pilon_1/assembly.pilon1.fasta"

def get_illumina(wildcards):
    """
    Return "<R1> <R2>": the entries of culebront.R1 and culebront.R2 that
    contain the ``fastq`` wildcard as a substring.

    When several entries match, the last one wins; when none matches, the
    corresponding side is an empty string.
    """
    forward_hits = [path for path in culebront.R1 if wildcards.fastq in path]
    reverse_hits = [path for path in culebront.R2 if wildcards.fastq in path]
    r1 = forward_hits[-1] if forward_hits else ""
    r2 = reverse_hits[-1] if reverse_hits else ""
    return f"{r1} {r2}"

################################ FINAL ####################################
# Target rule, declared before any include: it only requests the final HTML
# report, the file produced by rule run_report at the end of this file.
rule final:
    input:
        f"{output_dir}FINAL_REPORT/CulebrONT_report.html"


# Sub-snakefiles are loaded from the snakefiles/ directory under INSTALL_PATH.
# NOTE(review): INSTALL_PATH is not defined in this file; presumably it comes
# from the `culebrONT.snakeWrapper` star import -- confirm. The rules used
# below through `rules.<name>` but not defined in this file presumably come
# from these includes.

################################ ASSEMBLY ####################################
include: f"{INSTALL_PATH}/snakefiles/assemblers.snake"

############################### CIRCULARISATION ##############################
include: f"{INSTALL_PATH}/snakefiles/circular.snake"

################################ POLISHING ####################################
include: f"{INSTALL_PATH}/snakefiles/polishers.snake"

################################ CORRECTION ####################################
include: f"{INSTALL_PATH}/snakefiles/correction.snake"

################################ QUALITY ####################################
include: f"{INSTALL_PATH}/snakefiles/quality.snake"

################################ REPORT ####################################

# Draw the workflow graph: run the command stored in culebront.string_to_dag
# against the corrected config file, then convert its output to PNG with dot.
rule rule_graph:
    """
    run dag on {rule}
    """
    threads: get_threads('rule_graph', 1)
    input:
        config = f"{output_dir}config_corrected.yaml"
    params:
        # intermediate file: receives the stdout of params.cmd and is fed to dot
        tmp = f"{output_dir}FINAL_REPORT/dag.tmp",
        # NOTE(review): presumably a `snakemake ... --dag`-like command line -- confirm
        cmd = culebront.string_to_dag
    output:
        dag = f"{output_dir}FINAL_REPORT/dag.png",
    log:
        output = f"{output_dir}FINAL_REPORT/LOGS/GRAPH.o",
        error = f"{output_dir}FINAL_REPORT/LOGS/GRAPH.e"
    message:
        """
        making dag ...
        {params.cmd} --configfile {input.config} > {params.tmp}
        dot -Tpng {params.tmp} > {output.dag}
        """
    # both commands run in one subshell so that their stdout/stderr not already
    # redirected to a file end up in the log files
    shell:
        """
        ({params.cmd} --configfile {input.config} > {params.tmp}
        dot -Tpng {params.tmp} > {output.dag}) 1>{log.output} 2>{log.error}
        """

# Build the native Snakemake HTML report of the workflow.
# The displayed message now shows the command that is really executed
# (`-s {SNAKEFILE}`); it used to advertise `-n`, the dry-run flag.
# NOTE(review): SNAKEFILE and INSTALL_PATH are not defined in this file;
# presumably they come from the `culebrONT.snakeWrapper` star import -- confirm.
rule run_report_snakemake:
    """
    build the snakemake html report of the workflow
    """
    threads: get_threads('run_report_snakemake', 1)
    input:
        conf = str(culebront.path_config)
    output:
        report_snakemake = f"{output_dir}FINAL_REPORT/snake_report.html"
    log:
        output = f"{output_dir}FINAL_REPORT/LOGS/REPORT-SNAKE.o",
        error = f"{output_dir}FINAL_REPORT/LOGS/REPORT-SNAKE.e"
    message:
        """
        making report ...
        snakemake -s {SNAKEFILE} --configfile {input.conf} --report {output.report_snakemake}
        """
    # snakemake is re-invoked from the installation directory
    shell:
        """
        cd {INSTALL_PATH}
        snakemake -s {SNAKEFILE} --configfile {input.conf} --report {output.report_snakemake} 1>{log.output} 2>{log.error}
        """

# Aggregate, for one sample, the flagstat outputs of every activated assembler
# and every step of culebront.last_steps_list into a single CSV.
rule run_flagstats_stats:
    """
    run stat for flagstats
    """
    # the thread lookup uses the 'run_stats' key, shared with run_busco_stats
    threads: get_threads('run_stats', 1)
    input:
        # fastq='{fastq}' keeps the sample as a wildcard of this rule
        summary = expand(rules.run_flagstat.output.txt, fastq='{fastq}', assemblers=culebront.assembly_tools_activated, quality_step=culebront.last_steps_list)
    output:
        stat = f"{output_dir}{{fastq}}/REPORT/{{fastq}}_flagstats.csv"
    # NOTE(review): inputs are expanded over culebront.last_steps_list while
    # quality_list passes culebront.quality_step to the script -- confirm that
    # both lists are meant to differ here.
    params:
        quality_list = culebront.quality_step,
        assembly_list = culebront.assembly_tools_activated,
        out_dir = directory(f"{output_dir}{{fastq}}/")
    log:
        output = f"{output_dir}{{fastq}}/REPORT/LOGS/run_flagstats_stats/{{fastq}}.o",
        error = f"{output_dir}{{fastq}}/REPORT/LOGS/run_flagstats_stats/{{fastq}}.e"
    message:
        """
        make stats from flagstats
        """
    script:
        f"{SNAKEMAKE_SCRIPTS}/get_stats_flagstats.py"

# Aggregate, for one sample, the BUSCO summaries of every activated assembler
# and every step of culebront.quality_step into a single CSV.
rule run_busco_stats:
    """
    run stat
    """
    # the thread lookup uses the 'run_stats' key, shared with run_flagstats_stats
    threads: get_threads('run_stats', 1)
    input:
        # fastq='{fastq}' keeps the sample as a wildcard of this rule
        summary = expand(rules.run_busco.output.summary, fastq='{fastq}', assemblers=culebront.assembly_tools_activated, quality_step=culebront.quality_step)
    output:
        stat = f"{output_dir}{{fastq}}/REPORT/{{fastq}}_busco.csv"
    # values handed to the aggregation script
    params:
        quality_list = culebront.quality_step,
        assembly_list = culebront.assembly_tools_activated,
        out_dir = directory(f"{output_dir}{{fastq}}/")
    log:
        output = f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}_BUSCO-STATS.o",
        error = f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}_BUSCO-STATS.e"
    message:
        """
        make stats
        """
    script:
        f"{SNAKEMAKE_SCRIPTS}/get_stats_busco.py"

# Record the racon version for one sample and link it in the shared
# versions/ directory read by rule run_get_versions.
rule run_racon_version:
    """
    launch Racon version
    """
    threads: get_threads('run_racon_version', 1)
    input:
        fastq = get_fastq
    output:
        version = f"{output_dir}{{fastq}}/REPORT/RACON-version.txt"
    log:
        output =  f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}-RACON-VERSION.o",
        error =  f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}-RACON-VERSION.e"
    singularity:
        tools_config['SINGULARITY']['TOOLS']
    envmodules:
        tools_config["ENVMODULE"]["RACON"],
        tools_config["ENVMODULE"]["MINIMAP2"]
    # The versions directory is created at the very location used by `ln`
    # (it used to be a path relative to the current directory), and the
    # declared log files are now written, as run_busco_version does.
    shell:
        """
        (mkdir -p {output_dir}versions
        racon --version > {output.version}
        ln -s -f {output.version} {output_dir}versions/RACON-version.txt) 1>{log.output} 2>{log.error}
        """

# Record the BUSCO and AUGUSTUS versions for one sample and copy both files
# into the shared versions/ directory under output_dir.
rule run_busco_version:
    """
    busco version
    """
    threads: get_threads('run_busco_version', 1)
    # the BUSCO stats file is required as input, so this rule runs after
    # run_busco_stats for the same sample
    input:
        stats = rules.run_busco_stats.output.stat
    output:
        busco_version = f"{output_dir}{{fastq}}/REPORT/BUSCO-version.txt",
        augustus_version = f"{output_dir}{{fastq}}/REPORT/BUSCO_AUGUSTUS-version.txt"
    log:
        output = f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}_BUSCO-VERSION.o",
        error = f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}_BUSCO-VERSION.e"
    message:
        """
        busco version
        """
    singularity:
        tools_config['SINGULARITY']['TOOLS']
    envmodules:
        tools_config["ENVMODULE"]["BUSCO"]
    # NOTE(review): `mkdir -p versions` is relative to the current directory
    # while `cp` targets {output_dir}versions/; both match only if the job runs
    # from output_dir -- confirm.
    # augustus output is read from stderr as well (2>&1) and filtered on 'AUGUSTUS'
    shell:
        """
        ([ ! -f {output.busco_version} ] &&  [ ! -d versions ] && mkdir -p versions; busco --version > {output.busco_version}
        cp -f {output.busco_version} {output_dir}versions/BUSCO-version.txt
        [ ! -f {output.augustus_version} ] &&  [ ! -d versions ] && mkdir -p versions; augustus --version 2>&1 | grep 'AUGUSTUS' > {output.augustus_version}
        cp -f {output.augustus_version} {output_dir}versions/BUSCO_AUGUSTUS-version.txt) 1>{log.output} 2>{log.error}
        """

def get_inputs_benchmark(wildcards):
    """
    Build the dictionary of benchmark files needed for one sample.

    Keys:
        assembly_list: one file per activated assembler; MINIASM is listed
            under the MINIASM_MINIPOLISH name.
        polishers_list: only when polishing tools are activated; one file per
            assembler (MINIASM excluded), polisher and racon round.
        correction_list: only when correction tools are activated; one file
            per assembler and non-PILON correction tool, plus one file per
            assembler and pilon round when PILON is activated.

    The PILON files are now appended to ``correction_list``; they used to
    replace it, which dropped the benchmarks of every other correction tool
    as soon as PILON was activated.

    Args:
        wildcards: wildcards of the calling rule (not read here; every
            returned path keeps a ``{fastq}`` placeholder).

    Returns:
        dict mapping each key above to a list of path patterns.
    """
    assemblers = culebront.assembly_tools_activated
    assemblers_without_miniasm = [ass for ass in assemblers if ass != "MINIASM"]

    dico_benchmark_inputs = {
            "assembly_list": expand(f"{output_dir}{{{{fastq}}}}/BENCHMARK/ASSEMBLER/{{assemblers}}.txt",
                                    assemblers=assemblers_without_miniasm),
    }
    if "MINIASM" in assemblers:
        dico_benchmark_inputs["assembly_list"].append(f"{output_dir}{{fastq}}/BENCHMARK/ASSEMBLER/MINIASM_MINIPOLISH.txt")

    if culebront.polishing_tools_activated:
        dico_benchmark_inputs["polishers_list"] = expand(f"{output_dir}{{{{fastq}}}}/BENCHMARK/POLISHING/{{assemblers}}_{{polishers}}{{nb}}.txt",
                    assemblers=assemblers_without_miniasm,
                    polishers=culebront.polishing_tools_activated,
                    nb = range(1, int(nb_racon_rounds)+1))

    corrections = culebront.correction_tools_activated
    if corrections:
        correction_list = expand(f"{output_dir}{{{{fastq}}}}/BENCHMARK/CORRECTION/{{correction}}/{{assemblers}}.txt",
                    assemblers=assemblers,
                    correction=[tool for tool in corrections if tool != "PILON"])
        if "PILON" in corrections:
            # pilon writes one benchmark per round: extend, do not overwrite
            correction_list.extend(expand(f"{output_dir}{{{{fastq}}}}/BENCHMARK/CORRECTION/PILON/{{assemblers}}_PILON{{nb}}.txt",
                        assemblers=assemblers,
                        nb = range(1, int(nb_pilon_rounds)+1)))
        dico_benchmark_inputs["correction_list"] = correction_list
    return dico_benchmark_inputs

# Gather the benchmark files of one sample into a single CSV of timings.
rule run_benchmark_time:
    """
    run benchmark
    """
    threads: get_threads('run_benchmark_time', 1)
    # each key of the dictionary returned by get_inputs_benchmark becomes a
    # named input of the rule
    input:
        unpack(get_inputs_benchmark)
    output:
        stat_time = f"{output_dir}{{fastq}}/REPORT/{{fastq}}_time.csv"
    params:
        out_dir = f"{output_dir}{{fastq}}",
        sample = f"{{fastq}}",
        # the same dictionary is also handed to the script as a parameter
        dico = get_inputs_benchmark
    log:
        output = f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}_BENCHMARK-TIME.o",
        error = f"{output_dir}{{fastq}}/REPORT/LOGS/{{fastq}}_BENCHMARK-TIME.e"
    message:
        """
        unpacking final dico and benchmark
        """
    script:
        f"{SNAKEMAKE_SCRIPTS}/get_benchmark.py"

# Collect the version files written by the tools into one CSV.
# Most inputs are requested for the first sample only (FASTQ[0]); the racon
# version file is requested for every sample.
# NOTE(review): FASTQ[0] raises IndexError when no fastq file was found --
# confirm that an empty FASTQ is rejected earlier.
rule run_get_versions:
    """
    recovery soft versions
    """
    threads: get_threads('run_get_versions', 1)
    input:
        assemblers = expand(f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/ASSEMBLER/{{assemblers}}-version.txt", fastq = FASTQ[0], assemblers=culebront.assembly_tools_activated),
        polishers = expand(rules.run_racon_version.output.version, fastq=FASTQ),
        correction = expand(f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/CORRECTION/{{correction}}/{{correction}}-version.txt", fastq=FASTQ[0], assemblers=culebront.assembly_tools_activated, correction=culebront.correction_tools_activated),
        # circlator version: only for CANU / SMARTDENOVO and only in circular mode
        circular = expand(f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/ASSEMBLER/CIRCLATOR-version.txt", fastq=FASTQ[0], assemblers=[ass for ass in culebront.assembly_tools_activated if ass in ["CANU","SMARTDENOVO"] and bool(culebront.config['CIRCULAR'])]),
        # quality tools except QUAST, MAUVE and BUSCO, for the first activated assembler only
        quality = expand(f"{output_dir}{{fastq}}/ASSEMBLERS/{{assemblers}}/QUALITY/{{quality_step}}/{{quality}}/{{quality}}-version.txt", fastq=FASTQ[0], assemblers=culebront.assembly_tools_activated[0], quality_step=culebront.last_steps_list, quality=[qual for qual in culebront.quality_tools_activated if qual not in ["QUAST", "MAUVE", "BUSCO"]]),
        # NOTE(review): this pattern puts a '/' between output_dir and
        # 'versions' whereas params.dir and the other patterns do not -- confirm
        # it matches the output declared by the rule writing the MAUVE version.
        mauve = expand(f"{output_dir}/versions/{{quality}}-version.txt", quality=[qual for qual in culebront.quality_tools_activated if qual in ["MAUVE"]])
    params:
        dir =f'{output_dir}versions/'
    output:
        csv = f"{output_dir}versions.csv"
    message:
        """
        picking software versions used by CulebrONT
        """
    log:
        output = f'{output_dir}versions/LOGS/GET-VERSIONS.o',
        error = f'{output_dir}versions/LOGS/GET-VERSIONS.e'
    script:
        f"{SNAKEMAKE_SCRIPTS}/get_versions.py"

# Compute, for one sample, a CSV of statistics over the renamed fasta files of
# every activated assembler and every step of culebront.quality_step.
rule stats_assembly:
    """ build stats on assembly"""
    threads: get_threads('stats_assembly', 1)
    input:
        # fastq='{fastq}' keeps the sample as a wildcard of this rule
        fasta_list = expand(rules.preparing_fasta_to_quality.output.renamed,fastq='{fastq}',assemblers=culebront.assembly_tools_activated,quality_step=culebront.quality_step)
    output:
        csv = f"{output_dir}FINAL_REPORT/all_stats_{{fastq}}.csv"
    params:
        sample = '{fastq}'
    log:
        output = f"{output_dir}FINAL_REPORT/LOGS/{{fastq}}_STATS-ASSEMBLIES.o",
        error = f"{output_dir}FINAL_REPORT/LOGS/{{fastq}}_STATS-ASSEMBLIES.e"
    message:
        """
        Launching {rule}
        threads : {threads}
        input:
             fasta_list : {input.fasta_list}
        output:
             csv : {output.csv}
        log:
            output : {log.output}
            error: {log.error}
        """
    script:
        f"{SNAKEMAKE_SCRIPTS}/assembly-stats.py"


# Copy one renamed assembly (sample / assembler / quality step) into the
# ASSEMBLIES directory under output_dir. The file name gets the
# add_circular_name suffix when that value is not empty.
rule copy_final_assemblies:
    """ copy a final assembly into the ASSEMBLIES directory"""
    threads: get_threads('copy_final_assemblies', 1)
    input:
        fasta_list = rules.preparing_fasta_to_quality.output.renamed
    output:
        fasta_final = f"{output_dir}ASSEMBLIES/{{fastq}}-{{assemblers}}-{{quality_step}}{f'-{add_circular_name}' if add_circular_name else ''}.fasta"
    params:
        sample = '{fastq}'
    log:
        output = f"{output_dir}ASSEMBLIES/LOGS/{{fastq}}_{{assemblers}}_{{quality_step}}_CP-FINAL-ASSEMBLIES.o",
        error = f"{output_dir}ASSEMBLIES/LOGS/{{fastq}}_{{assemblers}}_{{quality_step}}_CP-FINAL-ASSEMBLIES.e"
    message:
        """
        Launching {rule}
        threads : {threads}
        input:
             fasta_list : {input.fasta_list}
        output:
             fasta_final : {output.fasta_final}
        log:
            output : {log.output}
            error: {log.error}
        """
    # the declared log files are now written by the copy command
    shell:
        """
        cp {input.fasta_list} {output.fasta_final} 1>{log.output} 2>{log.error}
        """

def output_final(wildcards):
    """
    Build the dictionary of every file required by the final report.

    The core entries (renamed fasta files, benchmark timings, dag, versions,
    assembly stats, copied assemblies) are always present. The remaining
    entries are added only for the tools switched on in the QUALITY and MSA
    sections of the config.

    Args:
        wildcards: wildcards of the calling rule (not read here).

    Returns:
        dict mapping an input name to a path or a list of paths.
    """
    samples = FASTQ
    assemblers = culebront.assembly_tools_activated
    quality_switch = config['QUALITY']
    # keyword sets shared by the expand() calls below
    per_quality_step = dict(fastq=samples, assemblers=assemblers, quality_step=culebront.quality_step)
    per_last_step = dict(fastq=samples, assemblers=assemblers, quality_step=culebront.last_steps_list)

    final_inputs = {}
    final_inputs["fasta_list"] = expand(rules.preparing_fasta_to_quality.output.renamed, **per_quality_step)
    final_inputs["bench_list"] = expand(rules.run_benchmark_time.output.stat_time, fastq=samples)
    final_inputs["dag"] = rules.rule_graph.output.dag
    final_inputs["versions"] = rules.run_get_versions.output.csv
    final_inputs["stats_assembly"] = expand(rules.stats_assembly.output.csv, fastq=samples)
    final_inputs["all_assemblies"] = expand(rules.copy_final_assemblies.output.fasta_final, **per_quality_step)
    # final_inputs["report_snakemake"] = rules.run_report_snakemake.output.report_snakemake

    if quality_switch['QUAST']:
        final_inputs["quast_files"] = expand(rules.run_quast.output.report_file, fastq=samples)
        # version files are requested for the first sample only
        final_inputs["quast_version"] = expand(rules.run_quast.output.version, fastq=samples[0])
    if quality_switch['BUSCO']:
        final_inputs["busco_stats"] = expand(rules.run_busco_stats.output.stat, fastq=samples)
        final_inputs["busco_version"] = expand(rules.run_busco_version.output.busco_version, fastq=samples[0])
    if quality_switch['BLOBTOOLS']:
        final_inputs["blob_files"] = expand(rules.run_blobtools.output.table, **per_last_step)
    if quality_switch['ASSEMBLYTICS']:
        final_inputs["assemblytics_files"] = expand(rules.run_assemblytics.output.summary, **per_last_step)
    if quality_switch['KAT']:
        final_inputs["kat_files"] = expand(rules.run_KAT.output.gcp, **per_last_step)
    if quality_switch['MERQURY']:
        final_inputs["merqury_files"] = expand(rules.run_merqury.output.qv_global, **per_last_step)
        final_inputs["merqury_stats"] = expand(rules.run_merqury.output.comp_stats, **per_last_step)
    if quality_switch['FLAGSTATS']:
        final_inputs["flagstats_stats"] = expand(rules.run_flagstats_stats.output.stat, fastq=samples)
    if config['MSA']['MAUVE']:
        final_inputs["mauve_files"] = expand(f"{output_dir}{{fastq}}/AGGREGATED_QC/MAUVE_ALIGN/candidate_assemblies.xmfa", fastq=samples)
    return final_inputs

# Render the final CulebrONT HTML report from Report.Rmd.
rule run_report:
    """
    printing  CulebrONT report
    """
    threads: get_threads('run_report', 1)
    # each key of the dictionary returned by output_final becomes a named input
    input:
        unpack(output_final)
    # values handed to the Rmd script
    params:
        samples_list = FASTQ,
        txt_config = culebront.export_use_yaml,
        out_dir_report = directory(f"{output_dir}FINAL_REPORT"),
        quality_tools_list = culebront.quality_tools_activated,
        logo = culebrONT.logo
    # this is the file requested by rule final
    output:
        report = f"{output_dir}FINAL_REPORT/CulebrONT_report.html"
    log:
        output = f"{output_dir}FINAL_REPORT/LOGS/REPORT.o",
        error = f"{output_dir}FINAL_REPORT/LOGS/REPORT.e"
    message:
        """
        print final CulebrONT_report ...
        """
    singularity:
        tools_config['SINGULARITY']['REPORT']
    envmodules:
        tools_config['ENVMODULE']['R']
    script:
        f"{SNAKEMAKE_SCRIPTS}/Report.Rmd"
