import json
import pathlib
import shlex

import pandas as pd

dataset = 'REPLACE_DATASET'
cluster_col = 'REPLACE_CLUSTER_COL'
categorical_key = REPLACE_CATEGORICAL_KEY
integration_group_key = 'REPLACE_INTEGRATION_GROUP_KEY'
integration_plot_key = 'REPLACE_INTEGRATION_PLOT_KEY'
template_dir = 'REPLACE_TEMPLATE_DIR'
mc_type = 'REPLACE_MC_TYPE'

template_dir = template_dir.rstrip('/')
categorical_key_str = "'{" + f'categorical_key:{categorical_key}' + "}'"


group_names = [p.name 
               for p in pathlib.Path().glob('*') 
               if (not p.name.startswith('.')) and p.is_dir() and (p.name != 'Summary') and (p.name != 'Neuron')]
print('group names', group_names)


rule all:
    input:
        expand('{group_dir}/mc_integration_group.csv.gz', group_dir=group_names),
        expand('{group_dir}/rna_integration_group.csv.gz', group_dir=group_names),


rule r01:
    input:
        '{group_dir}/mc_cells.txt',
    params:
        nb=f"{template_dir}/01.select_mc_cef.ipynb",
    output:
        adata="{group_dir}/CEF.csv"
    log:
        nb="{group_dir}/01.select_mc_cef.ipynb",
        log="{group_dir}/log/01.select_mc_cef.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "-p cluster_col {cluster_col} "
        "-p mc_type {mc_type} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"


rule r02a:
    input:
        cef='{group_dir}/CEF.csv',
        cells='{group_dir}/mc_cells.txt',
    params:
        nb=f"{template_dir}/02a.prepare_mc.ipynb",
    output:
        adata=temp("{group_dir}/mc_input.h5ad")
    log:
        nb="{group_dir}/02a.prepare_mc.ipynb",
        log="{group_dir}/log/02a.prepare_mc.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "-p mc_type {mc_type} "
        "> {log.log} 2>&1"


rule r02b:
    input:
        cef='{group_dir}/CEF.csv',
        cells='{group_dir}/rna_cells.txt',
    params:
        nb=f"{template_dir}/02b.prepare_rna.ipynb",
    output:
        adata=temp("{group_dir}/rna_input.h5ad")
    log:
        nb="{group_dir}/02b.prepare_rna.ipynb",
        log="{group_dir}/log/02b.prepare_rna.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "-p dataset {dataset} "
        "-p mc_type {mc_type} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"


rule r03:
    input:
        mc='{group_dir}/mc_input.h5ad',
        atac='{group_dir}/rna_input.h5ad',
    params:
        nb=f"{template_dir}/03.merge_and_pca.ipynb",
    output:
        temp("{group_dir}/mc_pca.h5ad"),
        temp("{group_dir}/rna_pca.h5ad")
    log:
        nb="{group_dir}/03.merge_and_pca.ipynb",
        log="{group_dir}/log/03.merge_and_pca.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"


rule r04:
    input:
        mc='{group_dir}/mc_pca.h5ad',
        atac='{group_dir}/rna_pca.h5ad',
    params:
        nb=f"{template_dir}/04.integration.ipynb",
    output:
        h5ad=temp("{group_dir}/final.h5ad"),
        integrator=directory("{group_dir}/integration")
    log:
        nb="{group_dir}/04.integration.ipynb",
        log="{group_dir}/log/04.integration.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "-p dataset {dataset} "
        "-y {categorical_key_str} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"


rule r05:
    input:
        adata='{group_dir}/final.h5ad',
    params:
        nb=f"{template_dir}/05.embedding.ipynb",
    output:
        h5ad="{group_dir}/final_with_coords.h5ad"
    log:
        nb="{group_dir}/05.embedding.ipynb",
        log="{group_dir}/log/05.embedding.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "-p dataset {dataset} "
        "-p plot_key {integration_plot_key} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"


rule r06:
    input:
        mc='{group_dir}/final_with_coords.h5ad',
    params:
        nb=f"{template_dir}/06.confusion_matrix.ipynb",
    output:
        flag="{group_dir}/finish"
    log:
        nb="{group_dir}/06.confusion_matrix.ipynb",
        log="{group_dir}/log/06.confusion_matrix.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "-p dataset {dataset} "
        "-y {categorical_key_str} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"


rule r07:
    input:
        mc='{group_dir}/final_with_coords.h5ad',
        flag="{group_dir}/finish"
    params:
        nb=f"{template_dir}/07.select_integration_group.ipynb",
    output:
        mc_ig="{group_dir}/mc_integration_group.csv.gz",
        rna_ig="{group_dir}/rna_integration_group.csv.gz",
    log:
        nb="{group_dir}/07.select_integration_group.ipynb",
        log="{group_dir}/log/07.select_integration_group.log"
    threads: 1
    shell:
        "papermill --cwd {wildcards.group_dir} "
        "-p category_key {integration_group_key} "
        "-p plot_key {integration_plot_key} "
        "-p dataset {dataset} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"
