"""The Snakemake file that controls the process, imported by a Python wrapper."""
import os
import glob
import pathlib
import pandas as pd
from join_asvbins.snake_functions import combine_mbstats_barrnap, \
    pullseqs_header_name_from_tab, filter_from_mbstats, set_program_output, \
    CANDIDATE_16S_SEQS_PATH



# Run-time settings, all read from the Snakemake ``config`` mapping.
# ``config['blast']`` is looked up directly, so a missing key raises KeyError;
# every other setting uses ``config.get`` and is None when the key is absent.
search_tool = 'blast' if config['blast'] else 'mmseqs'
candidate_16S_seqs = config.get('candidate_16S_seqs')
# TODO decide if this is necessary
# output_name = config.get('output_name')
bins_path = config.get('bins')
candidate_asv_seqs = config.get('candidate_asv_seqs')
asv_seqs_path = config.get('asv_seqs')
fasta_extention = config.get('fasta_extention')
generic_16s_path = config.get('generic_16S')
allow_empty = config.get('allow_empty')
qiime_out = config.get('qiime_out')
s1_mmseqs_sensitivity = config.get('s1_mmseqs_sensitivity')
s2_mmseqs_sensitivity = config.get('s2_mmseqs_sensitivity')
min_len_with_overlap = config.get("min_len_with_overlap")
min_len_pct_no_overlap = config.get("min_len_pct_no_overlap")
s1_min_pct_id = config.get('s1_min_pct_id')
s2_min_pct_id = config.get('s2_min_pct_id')
s1_min_length = config.get('s1_min_length')
s2_min_length = config.get('s2_min_length')
# The stage-2 length/gap/mismatch limits come from un-prefixed config keys.
s2_min_len_pct = config.get('min_len_pct')
s2_max_gaps = config.get('max_gaps')
s2_max_missmatch = config.get('max_missmatch')
verbosity = config.get('verbosity')
# Fixed relative names of intermediate files produced by the rules below.
LOCALY_COMBINED_BINS = "all_bins_combined"
UNQIIME_ASV_FASTA = "asv_seqs.fa"
from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider

# Remote provider used to fetch the QIIME 2 conda environment file over HTTP.
HTTP = HTTPRemoteProvider()



# Work out where the bin sequences come from (this could be moved elsewhere).
# "NA" marks a slot that does not apply to the current run:
#   * bins_path is a directory -> it becomes bins_folder and the combined
#     file is the locally generated LOCALY_COMBINED_BINS;
#   * bins_path is anything else -> it is used directly as the combined file;
#   * bins_path is None -> both slots stay "NA".
bins_folder = "NA"
path_to_combined_bins = "NA"
if bins_path is not None:
    if os.path.isdir(bins_path):
        bins_folder, path_to_combined_bins = bins_path, LOCALY_COMBINED_BINS
    else:
        path_to_combined_bins = bins_path

# Decide which form the ASV sequences arrive in (this could also be moved).
# A path ending in '.qza' is kept as the QIIME artifact and the FASTA slot
# points at UNQIIME_ASV_FASTA; any other path is used as the FASTA directly.
# "NA" marks a slot that does not apply to the current run.
asv_seqs_qva = 'NA'
asv_seqs_fa = 'NA'
if asv_seqs_path is not None:
    if asv_seqs_path.endswith('.qza'):
        asv_seqs_qva = asv_seqs_path
        asv_seqs_fa = UNQIIME_ASV_FASTA
    else:
        asv_seqs_fa = asv_seqs_path


# Use the default path imported from snake_functions when none was configured.
candidate_16S_seqs = (
    CANDIDATE_16S_SEQS_PATH if candidate_16S_seqs is None else candidate_16S_seqs
)


# Target rule (the first rule in this file): its inputs are whatever
# set_program_output returns for the configured bins path, ASV path and
# qiime_out setting.
rule all:
    input:
        set_program_output(bins_path, asv_seqs_path,  qiime_out)

rule combine_barrnap_with_other: # where other is blast or mmseqs
    input:
        mbstats_fasta_path = f"stage1_asvs_{search_tool}_matches.fna",
        mbstats_stats_path = f"stage1_asvs_{search_tool}.tab",
        barrnap_fasta_path = "barrnap_fasta-16S.fna",
        barrnap_stats_path= "barrnap_16S-gff.gff"
    output:
        out_fasta_path = protected(candidate_16S_seqs),
        out_stats_path = protected("candidate_statistics.tsv")
    run:
        combine_mbstats_barrnap(**input,
                                **output,
                                search_tool=search_tool,
                                allow_empty=allow_empty,
                                min_pct_id=s1_min_pct_id,
                                min_len_with_overlap=min_len_with_overlap,
                                min_len_pct_no_overlap=min_len_pct_no_overlap,
                                min_length=s1_min_length)


rule get_qiime2_environment:
    input:
        HTTP.remote("data.qiime2.org/distro/core/qiime2-2021.11-py38-linux-conda.yml", keep_local=True)
    output:
        temp('conda_qiime2.yml')
    run:
        outputName = os.path.basename(input[0])
        shell("mv {input} {output}")


# Export the ASV QIIME artifact (asv_seqs_qva) into the temporary directory
# "qiime_data" using `qiime tools export`, run inside the conda environment
# described by conda_qiime2.yml.
# The "\\" sequences in the command are Python escapes for a single
# backslash, i.e. shell line continuations.
rule get_asv_fa_folder_from_qiime:
    input:
        asv_seqs_qva
    output:
        temp(directory("qiime_data"))
    conda:
        'conda_qiime2.yml'
    shell:
       """
       qiime tools export \\
            --input-path {input[0]} \\
            --output-path {output}
       """


# Concatenate every *.fasta file found in the exported "qiime_data" directory
# into the single temporary ASV FASTA file (UNQIIME_ASV_FASTA).
rule get_fa_from_qiime_folder:
    input:
       "qiime_data"
    output:
        temp(UNQIIME_ASV_FASTA)
    shell:
       """
       cat {input}/*.fasta > {output}
       """


rule combine_input_fa:
    input:
        bins_folder
    output:
        temp(LOCALY_COMBINED_BINS)
    run:
       input_list = glob.glob(os.path.join(input[0], f"*.{fasta_extention}"))
       if fasta_extention.endswith('gz'):
               shell(f"gzip -cd  {input_list} >> {{output}}")
       else:
               shell(f"cat {' '.join(input_list)} >> {{output}}")


rule mmseqs_stage1_search:
    input:
        path_to_combined_bins,
        generic_16s_path # Query
    output:
        temp(directory("mmseqs_stage1_db")),
        temp("stage1_asvs_mmseqs.tab"),
        temp(directory("temp"))
    threads:
        workflow.cores
    params:
        sensitivity = s2_mmseqs_sensitivity,
        verbosity = verbosity if verbosity <= 3 else 3
    shell:
        """
        mkdir {output[0]}
        mmseqs createdb -v {params.verbosity} {input[0]} {output[0]}/target
        mmseqs createdb -v {params.verbosity} {input[1]} {output[0]}/query
        mmseqs search --search-type 3 \\
               -v {params.verbosity} \\
               -s {params.sensitivity} \\
               --threads {threads} \\
               {output[0]}/query \\
               {output[0]}/target \\
               {output[0]}/mmseqs_out \\
               temp
        mmseqs convertalis \\
               -v {params.verbosity} \\
               --format-output \'query,target,pident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,qlen,tlen\' \\
               {output[0]}/query \\
               {output[0]}/target \\
               {output[0]}/mmseqs_out \\
               {output[1]}
        """


# Stage 1 (blast): build a nucleotide BLAST database from the combined bins
# inside the temporary directory "blast_stage1_db", then run blastn with the
# generic 16S sequences as query.  The result is a tabular (outfmt 6) file
# with the 12 standard columns plus qlen and slen.
rule blast_stage1_search:
    input:
        path_to_combined_bins,
        generic_16s_path # Query
    output:
        temp(directory("blast_stage1_db")),
        temp("stage1_asvs_blast.tab")
    run:
       # Three separate shell calls: create the directory, build the
       # database, run the search.
       shell("mkdir {output[0]}")
       shell("makeblastdb -dbtype nucl -in {input[0]} -out {output[0]}/blast_db")
       shell("blastn -db {output[0]}/blast_db -out {output[1]} -query {input[1]} -outfmt \"6 qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore qlen slen\"")


rule pullseq_header_name:
    input:
        path_to_combined_bins,
        "{level}_asvs_{tool}.tab"
    output:
        temp("{level}_asvs_{tool}_matches.fna")
    run:
       pullseqs_header_name_from_tab(in_fasta_path=input[0],
                                     out_fasta_path=output[0],
                                     tab_file_path=input[1],
                                     header_column='sseqid')


rule mmseqs_stage2_search:
    input:
       candidate_16S_seqs, # Target
       asv_seqs_fa # Query
    output:
       temp(directory("mmseqs_stage2_db")),
       temp("stage2_asvs_mmseqs.tab"),
       temp(directory("temp"))
    threads:
        workflow.cores
    params:
        sensitivity = s1_mmseqs_sensitivity,
        verbosity = verbosity if verbosity <= 3 else 3
    shell:
        # TODO Split out the db creation, into other rules if it makes sense seeing as you may need to limit cores
        """
        mkdir {output[0]}
        mmseqs createdb -v {params.verbosity} \\
                        {input[0]} {output[0]}/target
        mmseqs createdb -v {params.verbosity} \\
                        {input[1]} {output[0]}/query
        mmseqs search --search-type 3 \\
               {output[0]}/query \\
               {output[0]}/target \\
               {output[0]}/mmseqs_out \\
               temp \\
               -s {params.sensitivity} \\
               -v {params.verbosity}
        mmseqs convertalis  \\
               -v {params.verbosity} \\
               --format-output \'query,target,pident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,qlen,tlen\' \\
               {output[0]}/query \\
               {output[0]}/target \\
               {output[0]}/mmseqs_out \\
               {output[1]}
        """


rule stage2_filtering:
    input:
       stats_file_in = f"stage2_asvs_{search_tool}.tab",
       fasta_file_in = candidate_16S_seqs
    output:
        fasta_file_out = protected("match_sequences.fna"),
        stats_file_out = protected("match_statistics.tsv")
    run:
       filter_from_mbstats(
                       **input,
                       **output,
                       min_pct_id=s2_min_pct_id,
                       min_length=s2_min_length,
                       min_len_pct=s2_min_len_pct,
                       max_gaps=s2_max_gaps,
                       max_missmatch=s2_max_missmatch,
                       search_tool=search_tool
                       )


# Stage 2 (blast): build a nucleotide BLAST database from the candidate 16S
# sequences inside the temporary directory "blast_stage2_db", then run blastn
# with the ASV sequences as query.  The result is a tabular (outfmt 6) file
# with the 12 standard columns plus qlen and slen.
# The single "\" before the newline is consumed by Python, so the blastn
# command and its -outfmt option reach the shell as one line.
rule blast_stage2_search:
    input:
        candidate_16S_seqs, # Target
        asv_seqs_fa, # Query
    output:
        temp(directory("blast_stage2_db")),
        temp("stage2_asvs_blast.tab")
    shell:
        """
        mkdir {output[0]}
        makeblastdb -dbtype nucl -in {input[0]} -out {output[0]}/blast_db
        blastn -db {output[0]}/blast_db -out {output[1]} -query {input[1]} \
        -outfmt \"6 qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore qlen slen\"
        """


# Run barrnap on the combined bins and capture its standard output as the
# temporary GFF file "barrnap_rrna.gff".
rule run_barrnap_barrnap:
    input:
        path_to_combined_bins
    output:
        temp("barrnap_rrna.gff")
    threads:
        workflow.cores
    params:
        # Pass --quiet to barrnap unless the configured verbosity is 3 or more.
        verbosity = "--quiet" if verbosity < 3 else ""
    shell:
        "barrnap --threads {threads} {params.verbosity} {input} > {output[0]}"


# Keep only the lines of the barrnap GFF that contain the text "16S".
# NOTE(review): grep exits with status 1 when no line matches, which would
# presumably make this rule fail when barrnap reports no 16S genes. Confirm
# whether that is the intended behaviour given the allow_empty setting.
rule run_barrnap_16s_gtff:
    input:
        "barrnap_rrna.gff"
    output:
        temp("barrnap_16S-gff.gff")
    shell:
        "grep \"16S\" {input} > {output}"


# Cut the regions listed in the 16S GFF out of the combined bins with
# `bedtools getfasta`, writing them to the temporary "barrnap_fasta_raw.fna".
# The second output (<combined bins>.fai) is not named in the command; it is
# presumably the FASTA index bedtools creates next to the -fi file, declared
# here so that it is cleaned up as a temporary file.
rule run_barrnap_fasta_filter:
    input:
        path_to_combined_bins,
        "barrnap_16S-gff.gff"
    output:
        temp("barrnap_fasta_raw.fna"),
        temp(f"{path_to_combined_bins}.fai")
    shell:
        "bedtools getfasta -fi {input[0]} -bed {input[1]} -fo {output[0]}"


# List the sequence identifiers of the raw barrnap FASTA: select the lines
# containing ">" and strip every ">" character from them.
rule run_barrnap_headers:
    input:
        "barrnap_fasta_raw.fna"
    output:
        temp("barrnap_16S-id.txt")
    shell:
        "grep \">\" {input} | sed 's/>//g' > {output}"


# Feed the identifier list to `samtools faidx` (via xargs) to pull those
# records from the raw barrnap FASTA into "barrnap_fasta-16S.fna".
# The second output (barrnap_fasta_raw.fna.fai) is not named in the command;
# it is presumably the index samtools faidx creates, declared here so that it
# is cleaned up as a temporary file.
rule run_barrnap_fasta_trim:
    input:
        "barrnap_fasta_raw.fna",
        "barrnap_16S-id.txt"
    output:
        temp("barrnap_fasta-16S.fna"),
        temp("barrnap_fasta_raw.fna.fai")
    shell:
        "xargs samtools faidx {input[0]} < {input[1]} > {output[0]}"



# Import the matched sequences FASTA into a write-protected QIIME artifact
# ("match_sequences.qza") with `qiime tools import`, run inside the conda
# environment described by conda_qiime2.yml.
# conda_qiime2.yml is also listed as an input, so the rule that produces it
# runs before this one; only input[0] is used by the command.
rule export_fa_to_qiime:
    input:
        "match_sequences.fna",
        'conda_qiime2.yml'
    output:
        protected("match_sequences.qza")
    conda:
        'conda_qiime2.yml'
    shell:
       """
       qiime tools import\\
            --input-path {input[0]} \\
            --output-path {output} \\
            --type SampleData[Sequences]
       """
