# ---- begin snakebids boilerplate ----------------------------------------------

from glob import glob
import snakebids
from snakebids import bids


# Load config/snakebids.yml into the global config dict.
configfile: "config/snakebids.yml"


# Call snakebids.generate_inputs and merge the mapping it returns into config.
# Every keyword argument carries the same name as the config entry it is read
# from, so the argument dict is built directly from that list of names.
# NOTE(review): the previous comment stated that this call also writes
# inputs_config.yml -- confirm for the snakebids version in use.
config.update(
    snakebids.generate_inputs(
        **{
            option: config[option]
            for option in (
                "bids_dir",
                "pybids_inputs",
                "derivatives",
                "participant_label",
                "exclude_participant_label",
            )
        }
    )
)


# Workflow-wide wildcard constraints: unpack the constraint mapping that
# snakebids builds from the pybids_inputs config entry, so that the wildcards
# used in the bids() paths below are restricted accordingly.
wildcard_constraints:
    **snakebids.get_wildcard_constraints(config["pybids_inputs"]),


# ---- end snakebids boilerplate ------------------------------------------------

# Name of the entry in config["download_model"] whose "checkpoint" and
# "unettask" settings are used by the "train" target rules further down.
model = config["use_downloaded"]


# Included unconditionally, i.e. for every analysis level.
include: "rules/common.smk"


if config["analysis_level"] == "participant":

    # Target rule for the "participant" analysis level. It is the first rule
    # declared in this file for that level, so it is presumably the default
    # target (assumes rules/common.smk declares no rules -- TODO confirm).
    # Requests one desc-brain mask file per "bold" input.
    rule all_test:
        input:
            nii=expand(
                # Path template under results/, datatype func, desc=brain,
                # suffix mask.nii.gz; the remaining entities are the wildcard
                # placeholders registered for the "bold" input.
                bids(
                    root="results",
                    datatype="func",
                    desc="brain",
                    suffix="mask.nii.gz",
                    **config["input_wildcards"]["bold"]
                ),
                # zip combines the wildcard value lists element-wise instead of
                # taking their cartesian product.
                zip,
                **config["input_zip_lists"]["bold"]
            ),

    include: "rules/inference.smk"


elif config["analysis_level"] == "train":

    # First rule declared for the "train" analysis level, so presumably the
    # default target there (assumes rules/common.smk declares no rules).
    # Requests the checkpoint .model file of folds 0-4 for the configured
    # architecture, task and trainer.
    rule all_train:
        input:
            # No zip here: expand takes the product of all supplied values.
            # checkpoint/unettask come from the download_model entry selected
            # by `model`; arch/trainer come from the nnunet config section.
            expand(
                "resources/trained_models/nnUNet/{arch}/{unettask}/{trainer}__nnUNetPlansv2.1/fold_{fold}/{checkpoint}.model",
                fold=range(5),
                checkpoint=config["download_model"][model]["checkpoint"],
                unettask=config["download_model"][model]["unettask"],
                arch=config["nnunet"]["arch"],
                trainer=config["nnunet"]["trainer"],
            ),

    # Second rule of the "train" level: it is declared after all_train, so it
    # is not the default target and has to be requested by name.
    rule all_model_tar:
        """Target rule to package trained model into a tar file"""
        input:
            # Tar file name encodes arch, unettask, trainer and checkpoint,
            # taken from the same config entries that all_train uses.
            model_tar=expand(
                "resources/trained_model.{arch}.{unettask}.{trainer}.{checkpoint}.tar",
                checkpoint=config["download_model"][model]["checkpoint"],
                unettask=config["download_model"][model]["unettask"],
                arch=config["nnunet"]["arch"],
                trainer=config["nnunet"]["trainer"],
            ),

    include: "rules/training.smk"


elif config["analysis_level"] == "evaluate":

    # Target rule for the "evaluate" analysis level: requests test_dice.csv
    # (a path relative to the working directory). The rule producing it is
    # presumably defined in one of the files included below -- TODO confirm.
    rule all_evaluate:
        input:
            "test_dice.csv",

    include: "rules/inference.smk"
    include: "rules/evaluate.smk"


elif config["analysis_level"] == "evaluate_rutherford":

    # Target rule for the "evaluate_rutherford" analysis level: requests
    # test_rutherfordunet_dice.csv (a path relative to the working directory).
    # The rule producing it is presumably defined in
    # rules/evaluate_rutherford.smk, included below -- TODO confirm.
    rule all_evaluate_rutherford:
        input:
            "test_rutherfordunet_dice.csv",

    include: "rules/evaluate_rutherford.smk"
