"""
Document how all the notebooks for a single experiment are connected.
"""
from snakemake.logging import logger


# Default configuration file of this workflow; it populates the global `config` dict.
configfile: "config/single_dev_dataset/proteinGroups_N50/config.yaml"


# Walltime requested for resource-heavy jobs (see rule train_NAGuideR_model).
MAX_WALLTIME = "24:00:00"
# Thinnode resources sharing: 40 cores and 196 GB RAM (minus 2GB for snakemake)
# JOB_RAM_MB = int(204_800 / 40 * config['THREATS_MQ'])
# NOTE(review): despite the `_MB` suffix this is a size string ("4gb") and it is
# passed as `mem_mb` to train_NAGuideR_model -- presumably the cluster profile
# expects this form; confirm.
JOB_RAM_MB = "4gb"
# Experiment folder taken from the configuration; rule `all` requests its targets below it.
folder_experiment = config["folder_experiment"]
logger.info(f"{folder_experiment = }")


# Local rules are executed in the process (job) running snakemake,
# i.e. they are not submitted as separate cluster jobs.
localrules:
    all,
    comparison,
    transform_NAGuideR_predictions,
    transform_data_to_wide_format,
    create_splits,


rule all:
    """Default target: request the performance summary and the binned test-error figure."""
    input:
        # both files are declared outputs of rule `comparison`
        f"{folder_experiment}/01_2_performance_summary.xlsx",
        f"{folder_experiment}/figures/2_1_test_errors_binned_by_int.pdf",


# Notebook executed by rule `comparison` and the file stem of its summary outputs.
nb = "01_2_performance_plots.ipynb"
nb_stem = "01_2_performance_summary"

# All models to compare: the configured models, followed by the NAGuideR
# methods if any are configured (an empty or missing-value entry adds nothing).
MODELS = [*config["models"], *(config["NAGuideR_methods"] or [])]


rule comparison:
    """Compare the test predictions of all models of one experiment folder.

    Executes the performance notebook with papermill into the experiment
    folder and converts the executed copy to HTML. The summary Excel file and
    the binned test-error figure are declared as outputs of that run.
    """
    input:
        nb=nb,
        # Bind the predictions to the `folder_experiment` wildcard so they come
        # from the same folder that is handed to the notebook and that holds
        # the outputs (instead of always using the configured folder).
        runs=expand(
            "{{folder_experiment}}/preds/pred_test_{model}.csv",
            model=MODELS,
        ),
    output:
        xlsx=f"{{folder_experiment}}/{nb_stem}.xlsx",
        pdf="{folder_experiment}/figures/2_1_test_errors_binned_by_int.pdf",
        nb=f"{{folder_experiment}}/{nb}",
    params:
        meta_data=config["fn_rawfile_metadata"],
        # comma-separated list of model names handed to the notebook
        models=",".join(MODELS),
        # err/out are not referenced in the shell command below.
        # NOTE(review): presumably consumed by the cluster profile for
        # stderr/stdout files -- confirm.
        err=f"{{folder_experiment}}/{nb_stem}.e",
        out=f"{{folder_experiment}}/{nb_stem}.o",
    conda:
        "envs/pimms.yaml"
    shell:
        # `:q` shell-quotes each value so paths with special characters survive
        "papermill {input.nb:q} {output.nb:q}"
        " -p fn_rawfile_metadata {params.meta_data:q}"
        " -r folder_experiment {wildcards.folder_experiment:q}"
        " -r models {params.models:q}"
        " && jupyter nbconvert --to html {output.nb:q}"


##########################################################################################
# train NAGuideR methods
nb_stem = "01_1_transfer_NAGuideR_pred"


rule transform_NAGuideR_predictions:
    """Transfer the dumped NAGuideR predictions into validation and test predictions.

    Takes the `pred_all_<method>.csv` dump of every configured NAGuideR method
    as input and declares a `pred_val_<method>.csv` and `pred_test_<method>.csv`
    per method as output. The notebook is executed with papermill and the
    executed copy is converted to HTML.
    """
    input:
        dumps=expand(
            "{{folder_experiment}}/preds/pred_all_{method}.csv",
            method=config["NAGuideR_methods"],
        ),
        nb=f"{nb_stem}.ipynb",
    output:
        # disabled output: "{{folder_experiment}}/preds/pred_real_na_{method}.csv"
        expand(
            (
                "{{folder_experiment}}/preds/pred_val_{method}.csv",
                "{{folder_experiment}}/preds/pred_test_{method}.csv",
            ),
            method=config["NAGuideR_methods"],
        ),
        # derived from nb_stem so the executed copy always matches the input notebook
        nb=f"{{folder_experiment}}/{nb_stem}.ipynb",
    benchmark:
        f"{{folder_experiment}}/{nb_stem}.tsv"
    params:
        # err/out are not referenced in the shell command below.
        # NOTE(review): presumably consumed by the cluster profile for
        # stderr/stdout files -- confirm.
        err=f"{{folder_experiment}}/{nb_stem}.e",
        out=f"{{folder_experiment}}/{nb_stem}.o",
        folder_experiment="{folder_experiment}",
        # https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html#non-file-parameters-for-rules
        # comma-separated list of all dump files, built once the inputs are known
        dumps_as_str=lambda wildcards, input: ",".join(input.dumps),
    conda:
        "envs/pimms.yaml"
    shell:
        # `:q` shell-quotes each value; the dump list is passed as one argument
        "papermill {input.nb:q} {output.nb:q}"
        " -r folder_experiment {params.folder_experiment:q}"
        " -p dumps {params.dumps_as_str:q}"
        " && jupyter nbconvert --to html {output.nb:q}"


rule train_NAGuideR_model:
    """Run one NAGuideR method on the wide-format data of an experiment.

    The shared training notebook is executed with papermill once per `method`
    wildcard; the executed notebook and `pred_all_<method>.csv` are the
    declared outputs. The executed notebook is converted to HTML afterwards.
    """
    input:
        train_split="{folder_experiment}/data/data_wide_sample_cols.csv",
        nb="01_1_train_NAGuideR_methods.ipynb",
    output:
        dump="{folder_experiment}/preds/pred_all_{method}.csv",
        nb="{folder_experiment}/01_1_train_NAGuideR_{method}.ipynb",
    params:
        folder_experiment="{folder_experiment}",
        method="{method}",
        name="{method}",
        # err/out/name are not referenced in the shell command below.
        # NOTE(review): presumably consumed by the cluster profile -- confirm.
        err="{folder_experiment}/01_1_train_NAGuideR_{method}.e",
        out="{folder_experiment}/01_1_train_NAGuideR_{method}.o",
    threads: 1  # R is single threaded
    resources:
        walltime=MAX_WALLTIME,
        mem_mb=JOB_RAM_MB,
    benchmark:
        "{folder_experiment}/01_1_train_NAGuideR_{method}.tsv"
    # log:
    #     err="{folder_experiment}/01_1_train_NAGuideR_{method}.log",
    #     (would be used in the shell command as: " 2> {log.err}")
    conda:
        "envs/trainRmodels.yaml"
    shell:
        "papermill {input.nb} {output.nb:q}"
        " -r train_split {input.train_split:q} -r method {params.method}"
        " -r folder_experiment {params.folder_experiment:q}"
        " && jupyter nbconvert --to html {output.nb:q}"


nb_stem = "01_0_transform_data_to_wide_format"


rule transform_data_to_wide_format:
    """Derive `data_wide_sample_cols.csv` from the training split of an experiment.

    The notebook is executed with papermill into the experiment folder and the
    executed copy is converted to HTML.
    """
    input:
        train_split="{folder_experiment}/data/train_X.csv",
        nb=f"{nb_stem}.ipynb",
    output:
        train_split="{folder_experiment}/data/data_wide_sample_cols.csv",
        nb=f"{{folder_experiment}}/{nb_stem}.ipynb",
    params:
        # err/out are not referenced in the shell command below.
        # NOTE(review): presumably consumed by the cluster profile -- confirm.
        err=f"{{folder_experiment}}/{nb_stem}.e",
        out=f"{{folder_experiment}}/{nb_stem}.o",
        folder_experiment="{folder_experiment}",
    conda:
        "envs/pimms.yaml"
    shell:
        "papermill {input.nb} {output.nb:q}"
        " -r folder_experiment {params.folder_experiment:q}"
        " && jupyter nbconvert --to html {output.nb:q}"


##########################################################################################
# train models in python
rule train_models:
    """Train one Python model by executing its notebook `01_1_train_<model>.ipynb`.

    The notebook is run with papermill, using the training configuration file
    as parameter file; the executed notebook and `pred_test_<model>.csv` are
    the declared outputs. The executed notebook is converted to HTML afterwards.
    """
    input:
        configfile=config["config_train"],
        train_split="{folder_experiment}/data/train_X.csv",
        nb="01_1_train_{model}.ipynb",
    output:
        pred="{folder_experiment}/preds/pred_test_{model}.csv",
        nb="{folder_experiment}/01_1_train_{model}.ipynb",
    params:
        meta_data=config["fn_rawfile_metadata"],
        folder_experiment="{folder_experiment}",
        name="{model}",
        # err/out/name are not referenced in the shell command below.
        # NOTE(review): presumably consumed by the cluster profile -- confirm.
        err="{folder_experiment}/01_1_train_{model}.e",
        out="{folder_experiment}/01_1_train_{model}.o",
    benchmark:
        "{folder_experiment}/01_1_train_{model}.tsv"
    # log:
    #     err="{folder_experiment}/01_1_train_{model}.log",
    #     (would be used in the shell command as: " 2> {log.err}")
    conda:
        "envs/pimms.yaml"
    shell:
        "papermill {input.nb:q} {output.nb:q} -f {input.configfile:q}"
        " -r folder_experiment {params.folder_experiment:q}"
        " -p fn_rawfile_metadata {params.meta_data:q}"
        " -r model_key {wildcards.model:q}"
        " && jupyter nbconvert --to html {output.nb:q}"


##########################################################################################
# Create Data splits
# separate workflow by level -> provide custom configs
nb_stem = "01_0_split_data"


rule create_splits:
    """Create the data splits of an experiment.

    Executes the split notebook with papermill, using the split configuration
    file as parameter file; the executed notebook and `data/train_X.csv` are
    the declared outputs. The executed notebook is converted to HTML afterwards.
    """
    input:
        nb=f"{nb_stem}.ipynb",
        configfile=config["config_split"],
    output:
        train_split="{folder_experiment}/data/train_X.csv",
        nb=f"{{folder_experiment}}/{nb_stem}.ipynb",
    params:
        folder_experiment="{folder_experiment}",
        meta_data=config["fn_rawfile_metadata"],
        # err/out are not referenced in the shell command below.
        # NOTE(review): presumably consumed by the cluster profile -- confirm.
        err=f"{{folder_experiment}}/{nb_stem}.e",
        out=f"{{folder_experiment}}/{nb_stem}.o",
    conda:
        "envs/pimms.yaml"
    shell:
        # `:q` shell-quotes every path, so an experiment folder containing
        # spaces or other special characters does not break the command
        "papermill {input.nb:q} {output.nb:q}"
        " -f {input.configfile:q}"
        " -r folder_experiment {params.folder_experiment:q}"
        " -p fn_rawfile_metadata {params.meta_data:q}"
        " && jupyter nbconvert --to html {output.nb:q}"
