from typing import List
from pathlib import Path

import pandas as pd
from snakemake.utils import min_version

# Configuration

# Require Snakemake 6.0 or newer.
min_version("6.0")

# Load the workflow settings; they are read below through the `config` mapping.
configfile: "config/config.yaml"

# Sample sheet: a tab-separated table indexed by its "Sample" column (the column
# is also kept as a regular column). The "Assemblies", "Reads1" and "Reads2"
# columns are tested for missing values below to decide how a sample is analysed.
# Sample names are forced to strings so that purely numeric names still match
# the string-valued {sample} wildcard used in the `.loc[...]` input lookups.
sample_files = pd.read_csv(
    config["samples"],
    sep="\t",
    dtype={"Sample": str},
).set_index("Sample", drop=False)

has_assembly = sample_files["Assemblies"].notna()
has_reads1 = sample_files["Reads1"].notna()
has_reads2 = sample_files["Reads2"].notna()

samples_assemblies = sample_files.index[has_assembly].tolist()
samples_assemblies_set = set(samples_assemblies)

# If both reads and assemblies are set, prefer the assemblies for analysis.
# The lists are filtered in sample-sheet order (dict.fromkeys drops repeated
# names while keeping first-seen order) instead of going through set
# arithmetic, whose iteration order is not stable between invocations.
samples_paired_reads = [
    name
    for name in dict.fromkeys(sample_files.index[has_reads1 & has_reads2])
    if name not in samples_assemblies_set
]
samples_single_reads = [
    name
    for name in dict.fromkeys(sample_files.index[has_reads1 & ~has_reads2])
    if name not in samples_assemblies_set
]

# Absolute path, so the reference location does not depend on the directory a
# job happens to run in.
reference_abs = Path(config["reference"]).absolute()

# Workflow

# Default target: always requests the file-of-filenames and, only when
# `include_mlst` is enabled in the config, the merged MLST table as well.
rule all:
    input:
        "gdi-input.fofn",
        ["mlst.tsv"] if config["include_mlst"] else [],


# Creates reference/reference.fasta from the configured reference file by
# running scripts/prepare-reference.py (script not shown here).
rule prepare_reference:
    threads: 1
    conda:
        "envs/gdi.yaml"
    input:
        reference=reference_abs,
    output:
        reference="reference/reference.fasta",
    log:
        "logs/prepare-reference.log",
    script:
        "scripts/prepare-reference.py"


# Indexes the prepared reference: `samtools faidx` first, then a minimap2
# index (.mmi), which is the declared output of this rule.
#
# Redirection order matters: `> file 2>&1` sends both stdout and stderr to the
# log. The reverse order (`2>&1 1>file`) duplicates stderr onto the *original*
# stdout before stdout is redirected, so diagnostics never reach the log file.
rule index_reference:
    input:
        reference="reference/reference.fasta",
    output:
        reference_mm2_index="reference/reference.fasta.mmi",
    conda:
        "envs/minimap2.yaml"
    threads: 1
    log:
        faidx="logs/index_reference.faidx.log",
        mm2_index="logs/index_reference.mm2_index.log",
    shell:
        "samtools faidx {input.reference} > {log.faidx} 2>&1 && "
        "minimap2 -t {threads} -d {output.reference_mm2_index} {input.reference} > {log.mm2_index} 2>&1"


# Writes reference/reference-genome-bedtools.tsv from the prepared reference
# by running scripts/fasta-to-genome-bedtools.py (script not shown here). The
# file is later passed to `bedtools complement -g` in rule assembly_mask.
rule reference_genome_bedtools:
    threads: 1
    conda:
        "envs/main.yaml"
    input:
        reference="reference/reference.fasta",
    output:
        genome_bedtools="reference/reference-genome-bedtools.tsv",
    log:
        "logs/reference_genome_bedtools.log",
    script:
        "scripts/fasta-to-genome-bedtools.py"


# Produces snpeff_db/db.conf from the configured reference file by running
# scripts/prepare-snpeff-database.py (script not shown here). The *_snpeff
# rules take that file as an input.
rule prepare_snpeff_database:
    threads: 1
    conda:
        "envs/snpeff.yaml"
    input:
        reference=reference_abs,
    output:
        snpeff_db_conf="snpeff_db/db.conf",
    log:
        "logs/prepare-snpeff-database.log",
    script:
        "scripts/prepare-snpeff-database.py"


# Aligns a sample's assembly (the "Assemblies" column of the sample sheet)
# against the minimap2 reference index with the asm5 preset and pipes the SAM
# stream into `samtools sort`, which writes the BAM output. Each stage logs
# its stderr to its own file.
rule assembly_align:
    threads: 1
    conda:
        "envs/minimap2.yaml"
    input:
        reference_mm2_index="reference/reference.fasta.mmi",
        sample=lambda wildcards: sample_files.loc[wildcards.sample]["Assemblies"],
    output:
        "assemblies/align/{sample}.bam",
    log:
        mm2="logs/assembly_align.{sample}.minimap2.log",
        samsort="logs/assembly_align.{sample}.samsort.log",
    shell:
        "minimap2 -t {threads} -a -x asm5 {input.reference_mm2_index} {input.sample}"
        " 2> {log.mm2}"
        " | samtools sort --threads {threads} --output-fmt BAM --write-index"
        " -o {output} 2> {log.samsort}"


# Calls variants for an aligned assembly with a bcftools pipeline:
# mpileup -> call (haploid, multiallelic caller) -> norm -> view (QUAL filter)
# -> fill-tags (adds the TYPE tag). Every stage logs its stderr separately.
# The result is consumed by rules assembly_variant and assembly_mask.
rule assembly_variant_all:
    input:
        reference="reference/reference.fasta",
        bam="assemblies/align/{sample}.bam",
    output:
        vcf="assemblies/variant_all/{sample}.vcf.gz",
    conda:
        "envs/bcftools.yaml"
    threads: 1
    params:
        # Records are kept only when QUAL is strictly greater than this value.
        qual=8,
        # --indel-bias for bcftools mpileup was raised from its default of
        # 1.00 because a small number of indels in SARS-CoV-2 were otherwise
        # being missed.
        indel_bias=5,
    log:
        mpileup="logs/assembly_variant_all.{sample}.mpileup.log",
        call="logs/assembly_variant_all.{sample}.call.log",
        norm="logs/assembly_variant_all.{sample}.norm.log",
        filter="logs/assembly_variant_all.{sample}.view.log",
        tags="logs/assembly_variant_all.{sample}.tags.log",
    shell:
        "bcftools mpileup -a DP --min-ireads 1 --indel-bias {params.indel_bias} --threads {threads} -Ou -f {input.reference} {input.bam} 2> {log.mpileup} | "
        "bcftools call --threads {threads} --ploidy 1 -Ou --multiallelic-caller 2> {log.call} | "
        "bcftools norm -f {input.reference} --threads {threads} --old-rec-tag -Ou 2> {log.norm} | "
        "bcftools view --include 'QUAL>{params.qual}' -Ou 2> {log.filter} | "
        "bcftools plugin fill-tags -Oz -o {output.vcf} -- -t TYPE 2> {log.tags}"


# Filters the full assembly call set down to actual variants by excluding
# every record whose TYPE tag equals "REF".
rule assembly_variant:
    threads: 1
    conda:
        "envs/bcftools.yaml"
    input:
        "assemblies/variant_all/{sample}.vcf.gz",
    output:
        "assemblies/variant/{sample}.vcf.gz",
    log:
        "logs/assembly_variant.{sample}.log",
    shell:
        "bcftools view --exclude 'TYPE==\"REF\"'"
        " -Oz -o {output} {input} 2> {log}"


# Runs snippy on a paired-end sample ("Reads1"/"Reads2" columns of the sample
# sheet) and copies snps.vcf.gz out of the snippy directory to the per-sample
# variant path. Coverage, quality and subsampling settings come from the
# config. Both stdout and stderr of snippy go to the log.
rule reads_paired_snippy:
    threads: 1
    conda:
        "envs/snippy.yaml"
    params:
        mincov=config["reads_mincov"],
        minqual=config["reads_minqual"],
        subsample=config["reads_subsample"],
    input:
        reference="reference/reference.fasta",
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]["Reads1"],
        reads2=lambda wildcards: sample_files.loc[wildcards.sample]["Reads2"],
    output:
        snippy_dir=directory("reads/snippy/paired/{sample}"),
        vcf="reads/paired/variant/{sample}.vcf.gz",
    log:
        snippy="logs/snippy_variant.{sample}.log",
    shell:
        "snippy --cpus {threads} --outdir {output.snippy_dir} --ref {input.reference} "
        "--mincov {params.mincov} --minqual {params.minqual} --subsample {params.subsample} "
        "--R1 {input.reads1} --R2 {input.reads2} "
        "1> {log.snippy} 2>&1 && "
        "cp {output.snippy_dir}/snps.vcf.gz {output.vcf}"


# Writes reads/single/mask/{sample}.bed.gz from a single-end sample's snippy
# output directory by running scripts/create-snippy-mask.py (script not shown
# here). Rule reads_paired_mask_snippy reuses this rule with different paths.
rule reads_single_mask_snippy:
    conda:
        "envs/gdi.yaml"
    input:
        snippy_dir="reads/snippy/single/{sample}",
    output:
        mask="reads/single/mask/{sample}.bed.gz",
    log:
        "logs/reads_mask.{sample}.log",
    script:
        "scripts/create-snippy-mask.py"


# Paired-end variant of reads_single_mask_snippy: same rule, only the snippy
# directory and the mask destination point at the "paired" trees.
use rule reads_single_mask_snippy as reads_paired_mask_snippy with:
    output:
        mask="reads/paired/mask/{sample}.bed.gz",
    input:
        snippy_dir="reads/snippy/paired/{sample}",


# Runs snippy on a single-end sample (the "Reads1" column of the sample sheet,
# passed with --se) and copies snps.vcf.gz out of the snippy directory to the
# per-sample variant path. Coverage, quality and subsampling settings come
# from the config. Both stdout and stderr of snippy go to the log.
rule reads_single_snippy:
    threads: 1
    conda:
        "envs/snippy.yaml"
    params:
        mincov=config["reads_mincov"],
        minqual=config["reads_minqual"],
        subsample=config["reads_subsample"],
    input:
        reference="reference/reference.fasta",
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]["Reads1"],
    output:
        snippy_dir=directory("reads/snippy/single/{sample}"),
        vcf="reads/single/variant/{sample}.vcf.gz",
    log:
        snippy="logs/snippy_variant.{sample}.log",
    shell:
        "snippy --cpus {threads} --outdir {output.snippy_dir} --ref {input.reference} "
        "--mincov {params.mincov} --minqual {params.minqual} --subsample {params.subsample} "
        "--se {input.reads1} "
        "1> {log.snippy} 2>&1 && "
        "cp {output.snippy_dir}/snps.vcf.gz {output.vcf}"


# Produces the variant-snpeff VCF for an assembly sample from its variant VCF
# and snpeff_db/db.conf by running scripts/variant-snpeff.py (script not shown
# here). The two reads_variant_*_snpeff rules reuse this rule with other paths.
rule assembly_variant_snpeff:
    threads: 1
    conda:
        "envs/snpeff.yaml"
    input:
        snpeff_db_conf="snpeff_db/db.conf",
        vcf="assemblies/variant/{sample}.vcf.gz",
    output:
        snpeff_vcf="assemblies/variant-snpeff/{sample}.vcf.gz",
    log:
        "logs/assembly_variant_snpeff.{sample}.log",
    script:
        "scripts/variant-snpeff.py"


# Same as assembly_variant_snpeff, applied to the paired-end read variants.
use rule assembly_variant_snpeff as reads_variant_paired_snpeff with:
    output:
        snpeff_vcf="reads/paired/variant-snpeff/{sample}.vcf.gz",
    input:
        snpeff_db_conf="snpeff_db/db.conf",
        vcf="reads/paired/variant/{sample}.vcf.gz",


# Same as assembly_variant_snpeff, applied to the single-end read variants.
use rule assembly_variant_snpeff as reads_variant_single_snpeff with:
    output:
        snpeff_vcf="reads/single/variant-snpeff/{sample}.vcf.gz",
    input:
        snpeff_db_conf="snpeff_db/db.conf",
        vcf="reads/single/variant/{sample}.vcf.gz",


# Builds the mask for an assembly sample: the intervals covered by records of
# the full call set are merged, and their complement with respect to the
# reference genome file is gzipped into the mask.
#
# The whole pipeline is grouped in a subshell so that stderr of every stage is
# collected in the declared log; previously the log was declared but never
# written and all diagnostics were lost.
rule assembly_mask:
    input:
        reference="reference/reference.fasta",
        reference_bedtools="reference/reference-genome-bedtools.tsv",
        vcf_all="assemblies/variant_all/{sample}.vcf.gz",
    output:
        mask="assemblies/mask/{sample}.bed.gz",
    conda:
        "envs/bedtools.yaml"
    log:
        "logs/assembly_consensus_mask.{sample}.consensus.log",
    shell:
        "(bedtools merge -i {input.vcf_all} | bedtools sort | bedtools merge | "
        "bedtools complement -i - -g {input.reference_bedtools} | "
        "gzip > {output.mask}) 2> {log}"


# Sketches a sample's input file(s) with `sourmash sketch dna`, naming the
# signature after the sample. The signature is written to a temporary file,
# gzipped, and then moved to the final output path. The two reads_*_sourmash_sketch
# rules reuse this rule; `{input}` expands to all of their read files.
rule assembly_sourmash_sketch:
    threads: 1
    conda:
        "envs/sourmash.yaml"
    params:
        sourmash_params=config["sourmash_params"],
    input:
        assembly=lambda wildcards: sample_files.loc[wildcards.sample]["Assemblies"],
    output:
        "assemblies/sketch/{sample}.sig.gz",
    log:
        sketch="logs/sourmash_sketch.{sample}.sketch.log",
    shell:
        "sourmash sketch dna -p {params.sourmash_params}"
        " --name {wildcards.sample} --output {output}.tmp {input}"
        " 2>{log.sketch}"
        " && gzip {output}.tmp"
        " && mv {output}.tmp.gz {output}"


# Same as assembly_sourmash_sketch, but sketches both read files of a
# paired-end sample.
use rule assembly_sourmash_sketch as reads_paired_sourmash_sketch with:
    output:
        "reads/paired/sketch/{sample}.sig.gz",
    input:
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]["Reads1"],
        reads2=lambda wildcards: sample_files.loc[wildcards.sample]["Reads2"],


# Same as assembly_sourmash_sketch, but sketches the read file of a
# single-end sample.
use rule assembly_sourmash_sketch as reads_single_sourmash_sketch with:
    output:
        "reads/single/sketch/{sample}.sig.gz",
    input:
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]["Reads1"],


# Runs `mlst` on a sample's assembly; stdout becomes the per-sample table and
# stderr goes to the log.
rule basic_mlst:
    threads: 1
    conda:
        "envs/mlst.yaml"
    input:
        sample=lambda wildcards: sample_files.loc[wildcards.sample]["Assemblies"],
    output:
        "mlst/{sample}.tsv",
    log:
        mlst="logs/basic_mlst.{sample}.log",
    shell:
        "mlst --threads {threads} --nopath {input.sample}"
        " > {output} 2> {log.mlst}"


def sample_vcfs_input(wildcards) -> List[str]:
    """Return the per-sample VCF paths to be listed in the input fofn.

    The snpEff-annotated VCFs are used when ``include_snpeff`` is enabled in
    the config, the plain variant VCFs otherwise. Paths are ordered as
    assemblies, then paired-end reads, then single-end reads.

    ``wildcards`` is accepted because Snakemake passes it to input functions;
    it is not used.
    """
    subdir = "variant-snpeff" if config["include_snpeff"] else "variant"

    vcfs = expand("assemblies/{subdir}/{sample}.vcf.gz", subdir=subdir, sample=samples_assemblies)
    vcfs += expand("reads/paired/{subdir}/{sample}.vcf.gz", subdir=subdir, sample=samples_paired_reads)
    vcfs += expand("reads/single/{subdir}/{sample}.vcf.gz", subdir=subdir, sample=samples_single_reads)
    return vcfs


def sample_mask_input(wildcards) -> List[str]:
    """Return the per-sample mask paths to be listed in the input fofn.

    Paths are ordered as assemblies, then paired-end reads, then single-end
    reads. ``wildcards`` is accepted because Snakemake passes it to input
    functions; it is not used.
    """
    sources = (
        ("assemblies/mask/{sample}.bed.gz", samples_assemblies),
        ("reads/paired/mask/{sample}.bed.gz", samples_paired_reads),
        ("reads/single/mask/{sample}.bed.gz", samples_single_reads),
    )
    masks = []
    for pattern, names in sources:
        masks.extend(expand(pattern, sample=names))
    return masks


def sample_sketches_input(wildcards) -> List[str]:
    """Return the per-sample sketch paths to be listed in the input fofn.

    An empty list is returned unless ``include_kmer`` is enabled in the
    config. Otherwise paths are ordered as assemblies, then paired-end reads,
    then single-end reads. ``wildcards`` is accepted because Snakemake passes
    it to input functions; it is not used.
    """
    if not config["include_kmer"]:
        return []

    sketches = expand("assemblies/sketch/{sample}.sig.gz", sample=samples_assemblies)
    sketches += expand("reads/paired/sketch/{sample}.sig.gz", sample=samples_paired_reads)
    sketches += expand("reads/single/sketch/{sample}.sig.gz", sample=samples_single_reads)
    return sketches


# Collects every per-sample VCF, mask and (optionally) sketch and hands them
# to scripts/gdi-create-fofn.py (script not shown here), which writes
# gdi-input.fofn.
rule gdi_input_fofn:
    message: "Generating input file to be loaded into index [{output}]."
    conda:
        "envs/main.yaml"
    input:
        sample_vcfs=sample_vcfs_input,
        sample_masks=sample_mask_input,
        sample_sketches=sample_sketches_input,
    output:
        "gdi-input.fofn",
    log:
        "logs/gdi_input_fofn.log",
    script:
        "scripts/gdi-create-fofn.py"


# Concatenates the per-sample MLST tables of all assembly samples into
# mlst.tsv.
#
# NOTE: the input keeps its historical name `sample_vcfs` although it holds
# the mlst/{sample}.tsv files, so that existing references to it keep working.
#
# /dev/null is passed as a first operand so that `cat` still has a file to
# read when there are no assembly samples; with an empty operand list it would
# read from stdin instead. stderr is sent to the declared log, which was
# previously never written.
rule mlst_full_table:
    input:
        sample_vcfs=expand("mlst/{sample}.tsv",sample=samples_assemblies),
    log:
        "logs/mlst_full_table.log"
    conda:
        "envs/main.yaml"
    output:
        "mlst.tsv"
    shell:
        "cat /dev/null {input} > {output} 2> {log}"
