import os

# Final product requested by default: the count matrix built over the 'merged'
# peak set for the configured sample.
TARGET = '{}/merged.peakcounts.h5ad'.format(config['sample'])

rule all:
    input:
        TARGET

# Decompress the fragment file, keep only chromosomes accepted by
# `mira-preprocess filter-chroms` against the genome file, and keep fragments
# whose length (column 3 - column 2) is at most max_fraglen.
# - The fragment file is declared as an input (not a param) so Snakemake
#   verifies it exists and re-runs this rule when it changes.
# - Lines whose first field starts with '#' (header/comment lines) are
#   dropped, not only lines whose first field is exactly "#".
rule filter_short_fragments:
    input:
        fragments = config['fragment_file'],
    output : '{prefix}/fragments.short.bed.gz'
    params:
        max_fraglen = config['max_fraglen'],
        genome_file = config['genome_file'],
    shell:
        "gzip -d -c {input.fragments} | mira-preprocess filter-chroms - {params.genome_file} "
        "| awk '$1 !~ /^#/ && $3 - $2 <= {params.max_fraglen}' "
        "| gzip > {output}"


# Restrict the short fragments to the barcodes listed in one cluster's
# barcode file (expected inside the cluster_cells checkpoint's output
# directory), giving per-cluster fragments for peak calling.
rule filter_barcodes:
    input:
        fragments = rules.filter_short_fragments.output,
        barcodes = '{prefix}/clusters/{name}.barcodes.txt',
    output:
        '{prefix}/{name}.fragments.short.barcodes.bed.gz'
    shell:
        "gzip -d -c {input.fragments} "
        "| mira-preprocess filter-barcodes - {input.barcodes}"
        "| gzip > {output}"


def get_filtered_fragment_files(wildcards):
    """Choose the fragment file that peaks are called on for ``wildcards.name``.

    The name ``'bulk'`` uses all short fragments (filter_short_fragments);
    any other name uses the barcode-filtered fragments (filter_barcodes).
    """
    use_all_fragments = wildcards.name == 'bulk'
    return (rules.filter_short_fragments.output if use_all_fragments
            else rules.filter_barcodes.output)

# Call peaks on the fragment set chosen by get_filtered_fragment_files,
# writing into {prefix}/{name}/peakcall/. The outdir and name params are plain
# strings, which Snakemake formats with the job's wildcards.
rule call_peaks:
    input:
        get_filtered_fragment_files
    output:
        '{prefix}/{name}/peakcall/{name}_summits.bed'
    params:
        outdir = '{prefix}/{name}/peakcall/',
        name = '{name}',
        genome_size = config['genome_size'],
    shell:
        "mira-preprocess call-peaks -i {input} -d {params.outdir} "
        "-n {params.name} -g {params.genome_size}"


# Widen each summit interval by slop_distance on both sides with
# `bedtools slop -b`, using the genome file for chromosome sizes.
rule slop_peaks:
    input:
        summits = rules.call_peaks.output,
    output:
        '{prefix}/{name}/peakcall/{name}_summits.slopped.bed'
    params:
        genome_file = config['genome_file'],
        slop_distance = config['slop_distance'],
    shell:
        "bedtools slop -i {input.summits} -g {params.genome_file} "
        "-b {params.slop_distance} > {output}"


def get_peakset(wildcards):
    """Return the peak file that fragments are counted over for ``wildcards.name``.

    ``'merged'`` selects the merge_peaks output; any other name selects that
    name's slopped summits from slop_peaks. ``rules.merge_peaks`` is defined
    later in the file; it is only looked up when this function is called.
    """
    if wildcards.name != 'merged':
        return rules.slop_peaks.output
    return rules.merge_peaks.output

# Aggregate the raw fragment file over the peak set chosen by get_peakset into
# {prefix}/{name}.peakcounts.h5ad via `mira-preprocess agg-countmatrix`.
# The fragment file is declared as an input (not a param) so Snakemake
# verifies it exists and re-runs the aggregation when it changes.
rule aggregate_peakcounts:
    input :
        peaks = get_peakset,
        fragments = config['fragment_file'],
    output:
        '{prefix}/{name}.peakcounts.h5ad'
    params :
        genome_file = config['genome_file'],
    shell:
        "mira-preprocess agg-countmatrix --fragments {input.fragments} --peaks {input.peaks} "
        "--genome-file {params.genome_file} -o {output}"


# Cluster cells from the bulk count matrix; the output directory is where
# get_cluster_peaks later looks for '{name}.barcodes.txt' files. Declared as a
# checkpoint because the set of clusters is only known after this runs.
# report_dir is the '{prefix}' wildcard string, formatted per job.
checkpoint cluster_cells:
    input:
        '{prefix}/bulk.peakcounts.h5ad'
    output:
        directory('{prefix}/clusters/')
    params:
        resolution = config['leiden_resolution'],
        report_dir = '{prefix}',
        min_peaks = config['min_peaks'],
        max_counts = config['max_counts'],
        min_frip = config['min_frip'],
        min_fragments_in_cluster = config['min_fragments_in_cluster'],
        components = config['num_lsi_components'],
    shell:
        "mira-preprocess cluster-cells -i {input} -o {output} -r {params.report_dir} "
        "--resolution {params.resolution} --min-peaks {params.min_peaks} "
        "--max-counts {params.max_counts} --min-frip {params.min_frip} "
        "--min-fragments-in-cluster {params.min_fragments_in_cluster} "
        "--num-lsi-components {params.components}"


def get_cluster_peaks(wildcards):
    """Return the per-cluster summit files to merge for ``wildcards.prefix``.

    Calls ``checkpoints.cluster_cells.get(...)``, so the cluster names are only
    read once the clustering checkpoint's output directory exists. One cluster
    is assumed per ``{name}.barcodes.txt`` file in that directory.

    The names are sorted because ``glob_wildcards`` walks the filesystem, whose
    listing order is arbitrary. Sorting keeps the merge_peaks input order (and
    therefore its command line) deterministic across runs.
    """
    clusters_dir = checkpoints.cluster_cells.get(**wildcards).output[0]
    cluster_names = sorted(
        glob_wildcards(os.path.join(clusters_dir, '{name}.barcodes.txt')).name
    )
    return expand(
        rules.call_peaks.output,
        prefix = wildcards.prefix,
        name = cluster_names,
    )

# Merge the per-cluster summit files listed by get_cluster_peaks into a single
# peak set with `mira-preprocess merge-peaks`, passing slop_distance and the
# genome file through.
rule merge_peaks:
    input:
        summits = get_cluster_peaks,
    output:
        '{prefix}/final_peakset.bed.merged'
    params:
        genome_file = config['genome_file'],
        slop_distance = config['slop_distance'],
    shell:
        "mira-preprocess merge-peaks -s {input.summits} "
        "-d {params.slop_distance} -g {params.genome_file} -o {output}"
