import pandas as pd
import pathlib


# Directory holding the notebook templates that every rule's `params.nb` points into.
# NOTE(review): 'REPLACE_TEMPLATE_DIR' looks like a placeholder that is substituted
# with a real path before the workflow runs — confirm against whatever generates this file.
# Trailing slashes are stripped so f"{template_dir}/<notebook>" never yields "//".
template_dir = 'REPLACE_TEMPLATE_DIR'.rstrip('/')


# Group directories are the entries of the current working directory whose name
# starts with 'c'. Only directories are kept: a plain file that happens to start
# with 'c' cannot contain the `{group_dir}/config/*.yaml` inputs the rules need
# and would otherwise become an unsatisfiable target. Sorting makes the target
# order independent of filesystem listing order.
group_names = sorted(
    p.name
    for p in pathlib.Path().glob('*')
    if p.name.startswith('c') and p.is_dir()
)


def get_cell_numbers(wildcards):
    """Return the number of rows in a group's ``select_cells.txt``.

    Parameters
    ----------
    wildcards
        Object exposing a ``group_dir`` attribute (the wildcards of the
        rule being evaluated).

    Returns
    -------
    int
        Row count of ``{group_dir}/select_cells.txt`` as parsed by pandas
        with ``header=None``, so the first line is counted as data.
    """
    cell_table = pd.read_csv(f'{wildcards.group_dir}/select_cells.txt', header=None)
    return len(cell_table)


# Default target: request the `finish` flag file (the output of rule r07) for
# every discovered group directory.
rule all:
    input:
        expand(
            '{group_dir}/finish',
            group_dir=group_names,
        )

# Step 01: execute the 01-BasicFeatureFiltering notebook template for one group.
# papermill takes its parameters from {group_dir}/config/01.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r01:
    input:
        config='{group_dir}/config/01.yaml',
    params:
        nb=f"{template_dir}/01-BasicFeatureFiltering.ipynb",
    output:
        features="{group_dir}/FeatureList.BasicFilter.txt"
    resources:
        # 0.1 MB per row of select_cells.txt, capped at 50000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(50000, round(get_cell_numbers(wildcards) * 0.1))
    log:
        nb="{group_dir}/01-BasicFeatureFiltering.ipynb",
        log="{group_dir}/log/01-BasicFeatureFiltering.log"
    threads: 1
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 02: execute the 02-HighlyVariableFeatureSelection notebook template for
# one group, after r01 has produced FeatureList.BasicFilter.txt.
# papermill takes its parameters from {group_dir}/config/02.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared outputs are not named in the shell command, so they are
# presumably written by the notebook itself — confirm against the template.
rule r02:
    input:
        config='{group_dir}/config/02.yaml',
        features="{group_dir}/FeatureList.BasicFilter.txt"
    params:
        nb=f"{template_dir}/02-HighlyVariableFeatureSelection.ipynb",
    output:
        mch_hvf='{group_dir}/mch_hvf.hdf',
        mcg_hvf='{group_dir}/mcg_hvf.hdf',
    resources:
        # 0.2 MB per row of select_cells.txt, capped at 50000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(50000, round(get_cell_numbers(wildcards) * 0.2))
    log:
        nb="{group_dir}/02-HighlyVariableFeatureSelection.ipynb",
        log="{group_dir}/log/02-HighlyVariableFeatureSelection.log"
    threads: 1
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 03a: execute the 03a-GeneratemCHAdata notebook template for one group,
# after r02 has produced mch_hvf.hdf.
# papermill takes its parameters from {group_dir}/config/03a.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r03a:
    input:
        config='{group_dir}/config/03a.yaml',
        mch_hvf='{group_dir}/mch_hvf.hdf',
    params:
        nb=f"{template_dir}/03a-GeneratemCHAdata.ipynb",
    output:
        # temp(): Snakemake deletes this file once no pending job needs it.
        temp('{group_dir}/mCH.HVF.h5ad')
    resources:
        # 0.4 MB per row of select_cells.txt, capped at 50000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(50000, round(get_cell_numbers(wildcards) * 0.4))
    log:
        nb="{group_dir}/03a-GeneratemCHAdata.ipynb",
        log="{group_dir}/log/03a-GeneratemCHAdata.log"
    threads: 4
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 03b: execute the 03b-GeneratemCGAdata notebook template for one group,
# after r02 has produced mcg_hvf.hdf.
# papermill takes its parameters from {group_dir}/config/03b.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r03b:
    input:
        config='{group_dir}/config/03b.yaml',
        mcg_hvf='{group_dir}/mcg_hvf.hdf',
    params:
        nb=f"{template_dir}/03b-GeneratemCGAdata.ipynb",
    output:
        # temp(): Snakemake deletes this file once no pending job needs it.
        temp('{group_dir}/mCG.HVF.h5ad')
    resources:
        # 0.4 MB per row of select_cells.txt, capped at 100000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(100000, round(get_cell_numbers(wildcards) * 0.4))
    log:
        nb="{group_dir}/03b-GeneratemCGAdata.ipynb",
        log="{group_dir}/log/03b-GeneratemCGAdata.log"
    threads: 4
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 04a: execute the 04a-PreclusteringAndClusterEnrichedFeatures-mCH notebook
# template for one group, after r03a has produced mCH.HVF.h5ad.
# papermill takes its parameters from {group_dir}/config/04a.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r04a:
    input:
        config='{group_dir}/config/04a.yaml',
        adata='{group_dir}/mCH.HVF.h5ad'
    params:
        nb=f"{template_dir}/04a-PreclusteringAndClusterEnrichedFeatures-mCH.ipynb",
    output:
        # temp(): Snakemake deletes this file once no pending job needs it.
        adata=temp('{group_dir}/mCH.CEF.h5ad')
    resources:
        # 0.3 MB per row of select_cells.txt, capped at 100000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        # NOTE(review): the mCH step is capped at 100000 here while the mCG
        # step (r04b) uses 50000, the reverse of r03a/r03b — confirm intended.
        mem_mb=lambda wildcards: min(100000, round(get_cell_numbers(wildcards) * 0.3))
    log:
        nb="{group_dir}/04a-PreclusteringAndClusterEnrichedFeatures-mCH.ipynb",
        log="{group_dir}/log/04a-PreclusteringAndClusterEnrichedFeatures-mCH.log"
    threads: 4
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 04b: execute the 04b-PreclusteringAndClusterEnrichedFeatures-mCG notebook
# template for one group, after r03b has produced mCG.HVF.h5ad.
# papermill takes its parameters from {group_dir}/config/04b.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r04b:
    input:
        config='{group_dir}/config/04b.yaml',
        adata='{group_dir}/mCG.HVF.h5ad'
    params:
        nb=f"{template_dir}/04b-PreclusteringAndClusterEnrichedFeatures-mCG.ipynb",
    output:
        # temp(): Snakemake deletes this file once no pending job needs it.
        adata=temp('{group_dir}/mCG.CEF.h5ad')
    resources:
        # 0.3 MB per row of select_cells.txt, capped at 50000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(50000, round(get_cell_numbers(wildcards) * 0.3))
    log:
        nb="{group_dir}/04b-PreclusteringAndClusterEnrichedFeatures-mCG.ipynb",
        log="{group_dir}/log/04b-PreclusteringAndClusterEnrichedFeatures-mCG.log"
    threads: 4
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 05a: execute the 05a-Decomposition-mCH notebook template for one group,
# after r04a has produced mCH.CEF.h5ad.
# papermill takes its parameters from {group_dir}/config/05a.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r05a:
    input:
        config='{group_dir}/config/05a.yaml',
        adata='{group_dir}/mCH.CEF.h5ad'
    params:
        nb=f"{template_dir}/05a-Decomposition-mCH.ipynb",
    output:
        # temp(): Snakemake deletes this file once no pending job needs it.
        temp('{group_dir}/mch_adata.with_coords.h5ad')
    resources:
        # 0.3 MB per row of select_cells.txt, capped at 50000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(50000, round(get_cell_numbers(wildcards) * 0.3))
    log:
        nb="{group_dir}/05a-Decomposition-mCH.ipynb",
        log="{group_dir}/log/05a-Decomposition-mCH.log"
    threads: 12
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 05b: execute the 05b-Decomposition-mCG notebook template for one group,
# after r04b has produced mCG.CEF.h5ad.
# papermill takes its parameters from {group_dir}/config/05b.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r05b:
    input:
        config='{group_dir}/config/05b.yaml',
        adata='{group_dir}/mCG.CEF.h5ad'
    params:
        nb=f"{template_dir}/05b-Decomposition-mCG.ipynb",
    output:
        # temp(): Snakemake deletes this file once no pending job needs it.
        temp('{group_dir}/mcg_adata.with_coords.h5ad')
    resources:
        # 0.3 MB per row of select_cells.txt, capped at 50000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(50000, round(get_cell_numbers(wildcards) * 0.3))
    log:
        nb="{group_dir}/05b-Decomposition-mCG.ipynb",
        log="{group_dir}/log/05b-Decomposition-mCG.log"
    threads: 12
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 05c: execute the 05c-Embedding notebook template for one group, after
# r05a and r05b have produced the mCH and mCG `*.with_coords.h5ad` files.
# papermill takes its parameters from {group_dir}/config/05c.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r05c:
    input:
        config='{group_dir}/config/05c.yaml',
        mch_adata='{group_dir}/mch_adata.with_coords.h5ad',
        mcg_adata='{group_dir}/mcg_adata.with_coords.h5ad'
    params:
        nb=f"{template_dir}/05c-Embedding.ipynb",
    output:
        '{group_dir}/adata.with_coords.h5ad'
    resources:
        # 0.1 MB per row of select_cells.txt, capped at 100000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(100000, round(get_cell_numbers(wildcards) * 0.1))
    log:
        nb="{group_dir}/05c-Embedding.ipynb",
        log="{group_dir}/log/05c-Embedding.log"
    threads: 12
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 06: execute the 06-Clustering notebook template for one group, after
# r05c has produced adata.with_coords.h5ad.
# papermill takes its parameters from {group_dir}/config/06.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The declared output is not named in the shell command, so it is presumably
# written by the notebook itself — confirm against the template.
rule r06:
    input:
        config='{group_dir}/config/06.yaml',
        adata='{group_dir}/adata.with_coords.h5ad'
    params:
        nb=f"{template_dir}/06-Clustering.ipynb",
    output:
        '{group_dir}/ConcensusClustering.model.lib'
    resources:
        # 0.1 MB per row of select_cells.txt, capped at 100000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(100000, round(get_cell_numbers(wildcards) * 0.1))
    log:
        nb="{group_dir}/06-Clustering.ipynb",
        log="{group_dir}/log/06-Clustering.log"
    threads: 12
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"

# Step 07: execute the 07-Plot notebook template for one group, after r05c and
# r06 have produced adata.with_coords.h5ad and the clustering model file.
# papermill takes its parameters from {group_dir}/config/07.yaml (-f), runs the
# notebook with the group directory as working directory (--cwd) and writes the
# executed copy to log.nb; stdout and stderr are redirected to log.log.
# The `finish` flag is what rule `all` requests for each group. It is not named
# in the shell command, so it is presumably created by the notebook itself —
# confirm against the template.
rule r07:
    input:
        config='{group_dir}/config/07.yaml',
        adata='{group_dir}/adata.with_coords.h5ad',
        cluster='{group_dir}/ConcensusClustering.model.lib'
    params:
        nb=f"{template_dir}/07-Plot.ipynb",
    output:
        flag="{group_dir}/finish"
    resources:
        # 0.02 MB per row of select_cells.txt, capped at 50000 MB. Rounded
        # explicitly because Snakemake expects integer resource values and the
        # scaled product is a float.
        mem_mb=lambda wildcards: min(50000, round(get_cell_numbers(wildcards) * 0.02))
    log:
        nb="{group_dir}/07-Plot.ipynb",
        log="{group_dir}/log/07-Plot.log"
    threads: 1
    shell:
        "papermill -f {input.config} --cwd {wildcards.group_dir} "
        "{params.nb} {log.nb} "
        "> {log.log} 2>&1"
