from typing import List
from pathlib import Path

import pandas as pd
from snakemake.utils import min_version

# Configuration

# Require Snakemake 6.0 or newer before parsing the rest of the workflow.
min_version("6.0")

# Load the workflow settings; the values are read below through `config[...]`.
configfile: "config/config.yaml"

# Sample sheet: a tab-separated table indexed by its "Sample" column (the column
# is also kept as a regular column). "Sample" is read as text so that names that
# look numeric (e.g. "001") are not converted to integers: rule wildcards are
# always strings and are used verbatim in `sample_files.loc[...]` lookups.
sample_files = pd.read_csv(config["samples"], sep="\t", dtype={"Sample": str}).set_index("Sample", drop=False)

# Row masks describing which kinds of input each sample provides.
_has_assembly = sample_files["Assemblies"].notna()
_has_reads1 = sample_files["Reads1"].notna()
_has_reads2 = sample_files["Reads2"].notna()

samples_assemblies = sample_files[_has_assembly].index.tolist()

# If both reads and assemblies are set, prefer the assemblies for analysis.
# `dict.fromkeys` de-duplicates while keeping sample-sheet order, so the sample
# lists (and everything expanded from them) are identical from run to run
# instead of depending on set iteration order.
samples_assemblies_set = set(samples_assemblies)
samples_paired_reads = [
    name
    for name in dict.fromkeys(sample_files[_has_reads1 & _has_reads2].index)
    if name not in samples_assemblies_set
]
samples_single_reads = [
    name
    for name in dict.fromkeys(sample_files[_has_reads1 & ~_has_reads2].index)
    if name not in samples_assemblies_set
]

# Absolute path to the reference genome given in the configuration.
reference_abs = Path(config["reference"]).absolute()

# Workflow

# Default target: the index input listing, plus the combined MLST table when
# `include_mlst` is enabled in the configuration.
rule all:
    input:
        "gdi-input.fofn",
        ["mlst.tsv"] if config["include_mlst"] else [],


# Produce the workflow-local reference FASTA ("reference/reference.fasta") from
# the configured reference file. The conversion itself is done by
# scripts/prepare-reference.py, which is not part of this file.
rule prepare_reference:
    input:
        reference=reference_abs,
    output:
        reference="reference/reference.fasta",
    threads: 1
    log:
        "logs/prepare-reference.log"
    script:
        "scripts/prepare-reference.py"


# Index the prepared reference: `samtools faidx` first, then a minimap2 index
# (.mmi), which is the only declared output of this rule.
#
# Redirection order matters: "1> file 2>&1" sends stdout to the log and then
# points stderr at the same place. The reverse order ("2>&1 1> file") would leave
# stderr attached to the terminal, so error messages would never reach the log.
rule index_reference:
    input:
        reference="reference/reference.fasta",
    output:
        reference_mm2_index="reference/reference.fasta.mmi",
    conda:
        "envs/minimap2.yaml"
    threads: 1
    log:
        faidx="logs/index_reference.faidx.log",
        mm2_index="logs/index_reference.mm2_index.log"
    shell:
        "samtools faidx {input.reference} 1> {log.faidx} 2>&1 && "
        "minimap2 -t {threads} -d {output.reference_mm2_index} {input.reference} 1> {log.mm2_index} 2>&1"


# Create "snpeff_db/db.conf" from the configured reference file. The work is
# delegated to scripts/prepare-snpeff-database.py (not shown here); presumably it
# builds a SnpEff database for the reference -- confirm against the script.
rule prepare_snpeff_database:
    input:
        reference=reference_abs,
    output:
        snpeff_db_conf="snpeff_db/db.conf",
    threads: 1
    log:
        "logs/prepare-snpeff-database.log"
    script:
        "scripts/prepare-snpeff-database.py"


# Align a sample's assembly (the "Assemblies" column of the sample sheet) to the
# indexed reference with minimap2 (asm5 preset) and pipe the SAM output into
# `samtools sort`, which writes a BAM file and an index for it.
rule assembly_align:
    input:
        sample=lambda wildcards: sample_files.loc[wildcards.sample]['Assemblies'],
        reference_mm2_index="reference/reference.fasta.mmi",
    output:
        "assemblies/align/{sample}.bam",
    log:
        mm2="logs/assembly_align.{sample}.minimap2.log",
        samsort="logs/assembly_align.{sample}.samsort.log",
    threads: 1
    conda:
        "envs/minimap2.yaml"
    shell:
        "minimap2 -t {threads} -a -x asm5 {input.reference_mm2_index} {input.sample}"
        " 2> {log.mm2}"
        " | samtools sort --threads {threads} --output-fmt BAM --write-index -o {output}"
        " 2> {log.samsort}"


# Call variants from an assembly alignment: bcftools mpileup -> call (haploid,
# variant sites only) -> norm against the reference, written as a compressed VCF.
#
# `bcftools norm --old-rec-tag` expects the name of an INFO tag as its argument.
# Without one, the next token ("-Oz") is consumed as the tag name and compressed
# output is no longer requested explicitly, so the tag name is given here.
# NOTE(review): "OLD_REC" is a name chosen in review -- adjust if downstream
# tooling expects a different INFO tag.
rule assembly_variant:
    input:
        reference="reference/reference.fasta",
        bam="assemblies/align/{sample}.bam",
    output:
        "assemblies/variant/{sample}.vcf.gz",
    conda:
        "envs/bcftools.yaml"
    threads: 1
    log:
        mpileup="logs/assembly_variant.{sample}.mpileup.log",
        call="logs/assembly_variant.{sample}.call.log",
        norm="logs/assembly_variant.{sample}.norm.log",
    shell:
        "bcftools mpileup -a DP --min-ireads 1 --threads {threads} -Ou -f {input.reference} {input.bam} 2> {log.mpileup} | "
        "bcftools call --threads {threads} --ploidy 1 -Ou -mv 2> {log.call} | "
        "bcftools norm -f {input.reference} --threads {threads} --old-rec-tag OLD_REC -Oz -o {output} 2> {log.norm}"


# Run snippy on a paired-end sample ("Reads1"/"Reads2" columns of the sample
# sheet) against the prepared reference. After snippy finishes, its snps.vcf.gz
# is copied to the variant output and its snps.aligned.fa is gzipped into the
# consensus output. Coverage, quality and subsampling thresholds come from the
# configuration.
rule reads_paired_snippy:
    input:
        reference="reference/reference.fasta",
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]['Reads1'],
        reads2=lambda wildcards: sample_files.loc[wildcards.sample]['Reads2'],
    output:
        snippy_dir=directory("reads/snippy/paired/{sample}"),
        vcf="reads/paired/variant/{sample}.vcf.gz",
        consensus="reads/paired/consensus/{sample}.fasta.gz",
    params:
        mincov=config['reads_mincov'],
        minqual=config['reads_minqual'],
        subsample=config['reads_subsample'],
    log:
        snippy="logs/snippy_variant.{sample}.log",
    threads: 1
    conda:
        "envs/snippy.yaml"
    shell:
        "snippy --cpus {threads} --outdir {output.snippy_dir} --ref {input.reference}"
        " --mincov {params.mincov} --minqual {params.minqual} --subsample {params.subsample}"
        " --R1 {input.reads1} --R2 {input.reads2}"
        " 1> {log.snippy} 2>&1"
        " && cp {output.snippy_dir}/snps.vcf.gz {output.vcf}"
        " && gzip --stdout {output.snippy_dir}/snps.aligned.fa > {output.consensus}"


# Run snippy on a single-end sample (only the "Reads1" column of the sample sheet
# is used, passed via --se) against the prepared reference. After snippy
# finishes, its snps.vcf.gz is copied to the variant output and its
# snps.aligned.fa is gzipped into the consensus output.
rule reads_single_snippy:
    input:
        reference="reference/reference.fasta",
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]['Reads1'],
    output:
        snippy_dir=directory("reads/snippy/single/{sample}"),
        vcf="reads/single/variant/{sample}.vcf.gz",
        consensus="reads/single/consensus/{sample}.fasta.gz",
    params:
        mincov=config['reads_mincov'],
        minqual=config['reads_minqual'],
        subsample=config['reads_subsample'],
    log:
        snippy="logs/snippy_variant.{sample}.log",
    threads: 1
    conda:
        "envs/snippy.yaml"
    shell:
        "snippy --cpus {threads} --outdir {output.snippy_dir} --ref {input.reference}"
        " --mincov {params.mincov} --minqual {params.minqual} --subsample {params.subsample}"
        " --se {input.reads1}"
        " 1> {log.snippy} 2>&1"
        " && cp {output.snippy_dir}/snps.vcf.gz {output.vcf}"
        " && gzip --stdout {output.snippy_dir}/snps.aligned.fa > {output.consensus}"


# Turn an assembly-derived VCF into its "variant-snpeff" counterpart using the
# database configuration in snpeff_db/db.conf. The processing is delegated to
# scripts/variant-snpeff.py (not shown here). This rule is also the template for
# the two read-based `use rule` variants that follow it.
rule assembly_variant_snpeff:
    input:
        snpeff_db_conf="snpeff_db/db.conf",
        vcf="assemblies/variant/{sample}.vcf.gz",
    output:
        snpeff_vcf="assemblies/variant-snpeff/{sample}.vcf.gz",
    threads: 1
    log:
        "logs/assembly_variant_snpeff.{sample}.log",
    script:
        "scripts/variant-snpeff.py"


# Same as assembly_variant_snpeff, but reading the VCF produced from paired-end
# reads and writing under reads/paired/variant-snpeff/. Only input and output
# are overridden here.
use rule assembly_variant_snpeff as reads_variant_paired_snpeff with:
    input:
        snpeff_db_conf="snpeff_db/db.conf",
        vcf="reads/paired/variant/{sample}.vcf.gz",
    output:
        snpeff_vcf="reads/paired/variant-snpeff/{sample}.vcf.gz",


# Same as assembly_variant_snpeff, but reading the VCF produced from single-end
# reads and writing under reads/single/variant-snpeff/. Only input and output
# are overridden here.
use rule assembly_variant_snpeff as reads_variant_single_snpeff with:
    input:
        snpeff_db_conf="snpeff_db/db.conf",
        vcf="reads/single/variant/{sample}.vcf.gz",
    output:
        snpeff_vcf="reads/single/variant-snpeff/{sample}.vcf.gz",


# Build a gzip-compressed consensus FASTA for an assembly sample by running
# `htsbox pileup` on its alignment against the prepared reference and piping the
# result through gzip.
rule assembly_consensus:
    input:
        bam="assemblies/align/{sample}.bam",
        reference="reference/reference.fasta",
    output:
        "assemblies/consensus/{sample}.fasta.gz",
    log:
        consensus="logs/assembly_consensus.{sample}.consensus.log",
    conda:
        "envs/htsbox.yaml"
    shell:
        "htsbox pileup -f {input.reference} -d -F {input.bam} 2> {log.consensus}"
        " | gzip -c - > {output}"


# Create a gzip-compressed sourmash DNA signature named after the sample.
# The sketch is first written to "<output>.tmp", then gzipped and moved into
# place. The shell command uses the whole `{input}` list, so the `use rule`
# variants that override `input` feed all of their files to sourmash.
rule assembly_sourmash_sketch:
    input:
        assembly=lambda wildcards: sample_files.loc[wildcards.sample]['Assemblies'],
    output:
        "assemblies/sketch/{sample}.sig.gz"
    params:
        sourmash_params=config["sourmash_params"],
    log:
        sketch="logs/sourmash_sketch.{sample}.sketch.log"
    threads: 1
    conda:
        "envs/sourmash.yaml"
    shell:
        "sourmash sketch dna -p {params.sourmash_params} --name {wildcards.sample}"
        " --output {output}.tmp {input} 2>{log.sketch}"
        " && gzip {output}.tmp"
        " && mv {output}.tmp.gz {output}"


# Sketch variant for paired-end samples: reuses assembly_sourmash_sketch with
# both read files ("Reads1" and "Reads2") as input and writes under
# reads/paired/sketch/.
use rule assembly_sourmash_sketch as reads_paired_sourmash_sketch with:
    input:
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]['Reads1'],
        reads2=lambda wildcards: sample_files.loc[wildcards.sample]['Reads2'],
    output:
        "reads/paired/sketch/{sample}.sig.gz"


# Sketch variant for single-end samples: reuses assembly_sourmash_sketch with
# the "Reads1" file as input and writes under reads/single/sketch/.
use rule assembly_sourmash_sketch as reads_single_sourmash_sketch with:
    input:
        reads1=lambda wildcards: sample_files.loc[wildcards.sample]['Reads1'],
    output:
        "reads/single/sketch/{sample}.sig.gz"


# Run `mlst` on a sample's assembly and store its standard output as a
# per-sample table; standard error goes to the rule's log.
rule basic_mlst:
    input:
        sample=lambda wildcards: sample_files.loc[wildcards.sample]['Assemblies'],
    output:
        "mlst/{sample}.tsv"
    log:
        mlst="logs/basic_mlst.{sample}.log"
    threads: 1
    conda:
        "envs/mlst.yaml"
    shell:
        "mlst --threads {threads} --nopath {input.sample}"
        " > {output} 2> {log.mlst}"


def sample_vcfs_input(wildcards) -> List[str]:
    """Return the VCF paths required for every sample in the sample sheet.

    When ``include_snpeff`` is enabled in the configuration the paths point at
    the ``variant-snpeff`` directories, otherwise at the plain ``variant``
    directories. Paths are listed for assemblies first, then paired-end reads,
    then single-end reads. ``wildcards`` is accepted because this is used as a
    Snakemake input function; it is not consulted.
    """
    variant_dir = "variant-snpeff" if config["include_snpeff"] else "variant"

    vcfs = []
    for prefix, names in (
        ("assemblies", samples_assemblies),
        ("reads/paired", samples_paired_reads),
        ("reads/single", samples_single_reads),
    ):
        pattern = prefix + "/" + variant_dir + "/{sample}.vcf.gz"
        vcfs.extend(expand(pattern, sample=names))
    return vcfs


def sample_consensus_input(wildcards) -> List[str]:
    """Return the consensus FASTA paths required for every sample.

    Paths are listed for assemblies first, then paired-end reads, then
    single-end reads. ``wildcards`` is accepted because this is used as a
    Snakemake input function; it is not consulted.
    """
    consensus = []
    for pattern, names in (
        ("assemblies/consensus/{sample}.fasta.gz", samples_assemblies),
        ("reads/paired/consensus/{sample}.fasta.gz", samples_paired_reads),
        ("reads/single/consensus/{sample}.fasta.gz", samples_single_reads),
    ):
        consensus.extend(expand(pattern, sample=names))
    return consensus


def sample_sketches_input(wildcards) -> List[str]:
    """Return the sourmash signature paths required for every sample.

    An empty list is returned when ``include_kmer`` is disabled in the
    configuration. Otherwise paths are listed for assemblies first, then
    paired-end reads, then single-end reads. ``wildcards`` is accepted because
    this is used as a Snakemake input function; it is not consulted.
    """
    if not config["include_kmer"]:
        return []

    sketches = []
    for pattern, names in (
        ("assemblies/sketch/{sample}.sig.gz", samples_assemblies),
        ("reads/paired/sketch/{sample}.sig.gz", samples_paired_reads),
        ("reads/single/sketch/{sample}.sig.gz", samples_single_reads),
    ):
        sketches.extend(expand(pattern, sample=names))
    return sketches


# Write "gdi-input.fofn" from the per-sample VCFs, consensus sequences and
# (optionally) sketches. The three input functions decide which files are
# required; the file itself is produced by scripts/gdi-create-fofn.py.
rule gdi_input_fofn:
    input:
        sample_vcfs=sample_vcfs_input,
        sample_consensus=sample_consensus_input,
        sample_sketches=sample_sketches_input,
    output:
        "gdi-input.fofn"
    log:
        "logs/gdi_input_fofn.log"
    conda:
        "envs/main.yaml"
    message: "Generating input file to be loaded into index [{output}]."
    script:
        "scripts/gdi-create-fofn.py"


# Concatenate the per-sample MLST tables of all assembly samples into "mlst.tsv".
# Despite its name, the `sample_vcfs` input holds the mlst/{sample}.tsv files;
# the name is kept so existing references to it keep working.
#
# /dev/null is always passed as the first operand so that `cat` still has a file
# argument when there are no assembly samples; with an empty `{input}` a bare
# `cat` would read from standard input instead of producing an empty table.
# Standard error is sent to the declared log, which was previously never written.
rule mlst_full_table:
    input:
        sample_vcfs=expand("mlst/{sample}.tsv",sample=samples_assemblies),
    log:
        "logs/mlst_full_table.log"
    conda:
        "envs/main.yaml"
    output:
        "mlst.tsv"
    shell:
        "cat /dev/null {input} > {output} 2> {log}"
