import json
import os
import re
import shlex
import shutil
import sys
import textwrap
from collections import OrderedDict
from itertools import islice

import dnaio

import igdiscover
from igdiscover.dna import reverse_complement
from igdiscover.utils import relative_symlink
from igdiscover.readlenhistogram import read_length_histogram
from igdiscover.config import Config


# Load the pipeline configuration; abort with a readable message when the
# configuration file does not exist.
try:
    config = Config.from_default_path()
except FileNotFoundError as missing:
    message = "Pipeline configuration file {!r} not found. Please create it!".format(missing.filename)
    sys.exit(message)

# Prefer pigz (parallel gzip) when it is found on the PATH, else plain gzip
GZIP = 'gzip' if shutil.which('pigz') is None else 'pigz'

# Result of the read pre-processing rules
PREPROCESSED_READS = 'reads/sequences.fasta.gz'

if config.debug:
    def temp(path):
        """Return path unchanged so that intermediate files are kept when debugging."""
        return path

# Targets for each iteration
ITERATION_TARGETS = ['clusterplots/done', 'errorhistograms.pdf', 'v-shm-distributions.pdf']
ITERATION_TARGETS += expand(
    ['expressed_{gene}.tab', 'expressed_{gene}.pdf', 'dendrogram_{gene}.pdf'],
    gene=['V', 'D', 'J'],
)

# Targets for non-final iterations
DISCOVERY_TARGETS = ['candidates.tab', 'new_V_germline.fasta', 'new_V_pregermline.fasta']

# Everything that is needed without running the final iteration
TARGETS = expand(
    'iteration-{nr:02d}/{path}',
    nr=range(1, config.iterations+1),
    path=ITERATION_TARGETS + DISCOVERY_TARGETS,
)
TARGETS.extend([
    'stats/readlengths.pdf',
    'stats/merging-successful',
    'stats/trimming-successful',
    'stats/stats_nofinal.json',
])
if config.iterations >= 1:
    TARGETS.append('iteration-01/new_J.fasta')

# Targets of the final iteration
FINAL_TARGETS = expand('final/{path}', path=ITERATION_TARGETS)
FINAL_TARGETS.append('stats/stats.json')


rule all:
    """Request all per-iteration targets plus the targets of the final iteration"""
    input:
        TARGETS + FINAL_TARGETS
    message: "IgDiscover finished."


rule nofinal:
    """Request only TARGETS, that is, everything except the final/ files and stats/stats.json"""
    input:
        TARGETS


def fasta_head(input_path, output_path, n):
    """Write the first n records read from input_path to output_path.

    Both files are opened with dnaio.open(). If the input has fewer than
    n records, all of them are copied.
    """
    with dnaio.open(input_path) as reader, dnaio.open(output_path, mode="w") as writer:
        for rec in islice(reader, n):
            writer.write(rec)


# Create the 'reads/1-limited.*.gz' files from the input reads.
# The wildcard constraints restrict {nr} to '1.', '2.' or the empty string
# and {ext} to 'fasta' or 'fastq'.
if config.limit:
    rule limit_reads_gz:
        """Keep only the first config.limit records (input file name ends in .gz)"""
        output: 'reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
        input: 'reads.{nr}{ext}.gz'
        run:
            fasta_head(input[0], output[0], config.limit)

    rule limit_reads:
        """Keep only the first config.limit records (input file name has no .gz suffix)"""
        output: 'reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
        input: 'reads.{nr}{ext}'
        run:
            fasta_head(input[0], output[0], config.limit)

else:
    rule symlink_limited:
        """No limit configured: create a relative symlink to the .gz input file"""
        output: fastaq='reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
        input: fastaq='reads.{nr}{ext}.gz'
        resources: time=1
        run:
            relative_symlink(input.fastaq, output.fastaq, force=True)

    # TODO compressing the input file is an unnecessary step
    rule gzip_limited:
        """No limit configured: compress an input file that has no .gz suffix"""
        output: fastaq='reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
        input: fastaq='reads.{nr}{ext}'
        shell:
            '{GZIP} < {input} > {output}'


# After the rules above, we either end up with
#
# 'reads/1-limited.1.fastq.gz' and 'reads/1-limited.2.fastq.gz'
# or
# 'reads/1-limited.fasta.gz'
# or
# 'reads/1-limited.fastq.gz'


if config.merge_program == 'flash':
    rule flash_merge:
        """Use FLASH to merge paired-end reads"""
        output: fastqgz='reads/2-merged.fastq.gz', log='reads/2-flash.log'
        input: 'reads/1-limited.1.fastq.gz', 'reads/1-limited.2.fastq.gz'
        resources: time=60
        threads: 8
        shell:
            # -M: maximal overlap (2x300, 420-450bp expected fragment size)
            # stderr is duplicated by tee into the log file (and still ends up
            # on stderr); stdout is compressed into the merged FASTQ file.
            "time flash -t {threads} -c -M {config.flash_maximum_overlap} {input} 2> "
            ">(tee {output.log} >&2) | {GZIP} > {output.fastqgz}"

    rule parse_flash_stats:
        input: log='reads/2-flash.log'
        output:
            json='stats/reads.json'
        run:
            total_ex = re.compile(r'\[FLASH\]\s*Total pairs:\s*([0-9]+)')
            merged_ex = re.compile(r'\[FLASH\]\s*Combined pairs:\s*([0-9]+)')
            with open(input.log) as f:
                for line in f:
                    match = total_ex.search(line)
                    if match:
                        total = int(match.group(1))
                        continue
                    match = merged_ex.search(line)
                    if match:
                        merged = int(match.group(1))
                        break
                else:
                    sys.exit('Could not parse the FLASH log file')
            d = OrderedDict({'total': total})
            d['merged'] = merged
            d['merging_was_done'] = True
            with open(output.json, 'w') as f:
                json.dump(d, f)


elif config.merge_program == 'pear':

    rule pear_merge:
        """Use pear to merge paired-end reads"""
        # The merging itself is started through 'igdiscover merge'; its
        # standard output is copied into the log file by tee.
        # NOTE(review): 'reads/2-pear.log' is declared both as an output and
        # as the log file of this rule; parse_pear_stats reads it as input.
        output:
            fastq='reads/2-merged.fastq.gz',
            log='reads/2-pear.log'
        input:
            fastq1='reads/1-limited.1.fastq.gz',
            fastq2='reads/1-limited.2.fastq.gz'
        log: 'reads/2-pear.log'

        resources: time=60
        threads: 20
        shell:
            "igdiscover merge -j {threads} {input.fastq1} {input.fastq2} {output.fastq} | tee {log}"

    rule parse_pear_stats:
        input: log='reads/2-pear.log'
        output:
            json='stats/reads.json'
        run:
            expression = re.compile(r"Assembled reads \.*: (?P<merged>[0-9,]*) / (?P<total>[0-9,]*)")
            with open(input.log) as f:
                for line in f:
                    match = expression.search(line)
                    if match:
                        merged = int(match.group('merged').replace(',', ''))
                        total = int(match.group('total').replace(',', ''))
                        break
                else:
                    sys.exit('Could not parse the PEAR log file')
            d = OrderedDict({'total': total})
            d['merged'] = merged
            d['merging_was_done'] = True
            with open(output.json, 'w') as f:
                json.dump(d, f)
else:
    # Merge rules exist only for 'flash' and 'pear'; stop for any other value
    sys.exit("merge_program {config.merge_program!r} given in configuration file not recognized".format(config=config))


# This rule applies only when the input is single-end or already merged
rule symlink_merged:
    """Create 'reads/2-merged.<ext>.gz' as a relative symlink to 'reads/1-limited.<ext>.gz'"""
    output:
        fastaq='reads/2-merged.{ext,(fasta|fastq)}.gz'
    input: fastaq='reads/1-limited.{ext}.gz'
    run:
        # force=True is passed on to relative_symlink
        relative_symlink(input.fastaq, output.fastaq, force=True)


# After the rules above, we end up with
#
# 'reads/2-merged.fasta.gz'
# or
# 'reads/2-merged.fastq.gz'


rule read_stats_single_fasta:
    """Compute statistics if no merging was done (FASTA input)"""
    output: json='stats/reads.json',
    input: fastagz='reads/1-limited.fasta.gz'
    run:
        total = count_sequences(input.fastagz)
        d = OrderedDict({'total': total})
        d['merged'] = total
        d['merging_was_done'] = False
        with open(output.json, 'w') as f:
            json.dump(d, f)


rule read_stats_single_fastq:
    """Compute statistics if no merging was done (FASTQ input)"""
    output: json='stats/reads.json',
    input: fastagz='reads/1-limited.fastq.gz'
    run:
        total = count_sequences(input.fastagz)
        d = OrderedDict({'total': total})
        d['merged'] = total
        d['merging_was_done'] = False
        with open(output.json, 'w') as f:
            json.dump(d, f)


rule check_merging:
    """Ensure the merging succeeded"""
    output: success='stats/merging-successful'
    input:
        json='stats/reads.json'
    run:
        with open(input.json) as f:
            d = json.load(f)
        total = d['total']
        merged = d['merged']
        if total == 0 or merged / total >= 0.3:
            with open(output.success, 'w') as f:
                print('This marker file exists if at least 30% of the input '
                    'reads could be merged.', file=f)
        else:
            sys.exit('Less than 30% of the input reads could be merged. Please '
                'check whether there is an issue with your input data. To skip '
                'this check, create the file "stats/merging-successful" and '
                're-start "igdiscover run".')


rule merged_read_length_histogram:
    output:
        txt="stats/merged.readlengths.txt",
        pdf="stats/merged.readlengths.pdf"
    input:
        fastq='reads/2-merged.fastq.gz'
    run:
        read_length_histogram(input.fastq, tsv_path=output.txt, plot_path=output.pdf, bins=100, left=config.minimum_merged_read_length, title="Lengths of merged reads")


rule read_length_histogram:
    output:
        txt="stats/readlengths.txt",
        pdf="stats/readlengths.pdf"
    input:
        fastq=PREPROCESSED_READS
    run:
        read_length_histogram(input.fastq, tsv_path=output.txt, plot_path=output.pdf, bins=100, left=config.minimum_merged_read_length, title="Lengths of pre-processed reads")


rule reads_stats_fasta:
    """
    Placeholder that only creates an empty 'stats/reads.txt' with touch.

    TODO implement this
    """
    output: txt="stats/reads.txt"
    input:
        merged='reads/1-limited.fasta.gz'
    shell: "touch {output}"


# Remove primer sequences

if config.forward_primers:
    # At least one forward primer is to be removed
    rule trim_forward_primers:
        """Remove forward primers with cutadapt; the report is copied to the log by tee"""
        output: fastaq=temp('reads/3-forward-primer-trimmed.{ext,(fasta|fastq)}.gz')
        # The merging-successful marker is an input only to enforce the check
        input: fastaq='reads/2-merged.{ext}.gz', mergesuccess='stats/merging-successful'
        resources: time=120
        log: 'reads/3-forward-primer-trimmed.{ext}.log'
        params:
            # One ' -g ^SEQ' option per forward primer
            fwd_primers=''.join(' -g ^{}'.format(seq) for seq in config.forward_primers),
            # Unless the library is stranded, also pass the reverse complement
            # of each forward primer as ' -a SEQ$'
            rev_primers=''.join(' -a {}$'.format(reverse_complement(seq)) for seq in config.forward_primers) if not config.stranded else '',
        shell:
            "cutadapt --discard-untrimmed"
            "{params.fwd_primers}"
            "{params.rev_primers}"
            " -o {output.fastaq} {input.fastaq} | tee {log}"
else:
    # No trimming, just symlink the file
    rule dont_trim_forward_primers:
        """Pass the merged reads on unchanged through a relative symlink"""
        output: fastaq='reads/3-forward-primer-trimmed.{ext,(fasta|fastq)}.gz'
        input: fastaq='reads/2-merged.{ext}.gz', mergesuccess='stats/merging-successful'
        resources: time=1
        run:
            relative_symlink(input.fastaq, output.fastaq, force=True)


if config.reverse_primers:
    # At least one reverse primer is to be removed
    rule trim_reverse_primers:
        """Remove reverse primers with cutadapt; the report is copied to the log by tee"""
        output: fastaq='reads/4-trimmed.{ext,(fasta|fastq)}.gz'
        input: fastaq='reads/3-forward-primer-trimmed.{ext}.gz'
        resources: time=120
        log: 'reads/4-trimmed.{ext}.log'
        params:
            # Reverse primers should appear reverse-complemented at the 3' end
            # of the merged read.
            # Note the naming: 'fwd_primers' holds the ' -a SEQ$' options
            # (reverse-complemented primers), 'rev_primers' the ' -g ^SEQ'
            # options that are only added when the library is not stranded.
            fwd_primers=''.join(' -a {}$'.format(reverse_complement(seq)) for seq in config.reverse_primers),
            rev_primers=''.join(' -g ^{}'.format(seq) for seq in config.reverse_primers) if not config.stranded else ''
        shell:
            "cutadapt --discard-untrimmed"
            "{params.fwd_primers}"
            "{params.rev_primers}"
            " -o {output.fastaq} {input.fastaq} | tee {log}"

else:
    # No trimming, just symlink the file
    rule dont_trim_reverse_primers:
        """Pass the forward-primer-trimmed reads on unchanged (hardlink or symlink)"""
        output: fastaq='reads/4-trimmed.{ext,(fasta|fastq)}.gz'
        input: fastaq='reads/3-forward-primer-trimmed.{ext}.gz'
        resources: time=1
        run:
            if config.forward_primers:
                # The target is marked as temp().
                # To avoid a dangling symlink, use a hardlink.
                os.link(input.fastaq, output.fastaq)
            else:
                relative_symlink(input.fastaq, output.fastaq, force=True)


rule trimmed_fasta_stats:
    output: json='stats/trimmed.json',
    input: fastagz='reads/4-trimmed.fasta.gz'
    run:
        with open(output.json, 'w') as f:
            json.dump({'trimmed': count_sequences(input.fastagz)}, f)


rule trimmed_fastq_stats:
    output: json='stats/trimmed.json',
    input: fastqgz='reads/4-trimmed.fastq.gz'
    run:
        with open(output.json, 'w') as f:
            json.dump({'trimmed': count_sequences(input.fastqgz)}, f)


rule check_trimming:
    """Ensure that some reads are left after trimming"""
    output: success='stats/trimming-successful'
    input:
        reads_json='stats/reads.json',
        trimmed_json='stats/trimmed.json'
    run:
        with open(input.reads_json) as f:
            total = json.load(f)['total']
        with open(input.trimmed_json) as f:
            trimmed = json.load(f)['trimmed']
        if total == 0 or trimmed / total >= 0.1:
            with open(output.success, 'w') as f:
                print('This marker file exists if at least 10% of input '
                    'reads contain the required primer sequences.', file=f)
        else:
            print(*textwrap.wrap(
                'Less than 10% of the input reads contain the required primer '
                'sequences. Please check whether you have specified the '
                'correct primer sequences in the configuration file. To skip '
                'this check, create the file "stats/trimming-successful" and '
                're-start "igdiscover run".'), sep='\n')
            sys.exit(1)


def group_cdr3_arg():
    """Return the CDR3 command-line option for 'igdiscover group'.

    The result depends on config.cdr3_location: an empty string if it is
    not set, ' --real-cdr3' if it is 'detect', and otherwise
    ' --pseudo-cdr3=START:END' built from its two elements.
    """
    location = config.cdr3_location
    if not location:
        return ''
    if location == 'detect':
        return ' --real-cdr3'
    return ' --pseudo-cdr3={}:{}'.format(*location)


# One anonymous rule is defined per input format. Both produce
# PREPROCESSED_READS and differ only in the trimmed file they read.
for ext in ('fasta', 'fastq'):
    if config.barcode_length and config.barcode_consensus:
        rule:
            """Group by barcode and CDR3 (also implicitly removes duplicates)"""
            output:
                fastagz=PREPROCESSED_READS,
                pdf="stats/groupsizes.pdf",
                groups="reads/4-groups.tab.gz",
                json="stats/groups.json"
            input:
                # The trimming-successful marker is an input only to enforce the check
                fastaq='reads/4-trimmed.{ext}.gz'.format(ext=ext), success='stats/trimming-successful'
            log: 'reads/4-sequences.fasta.gz.log'
            params:
                race_arg=' --trim-g' if config.race_g else '',
                cdr3_arg=group_cdr3_arg(),
            shell:
                # stderr goes to the log file, stdout is compressed
                "igdiscover group"
                "{params.cdr3_arg}{params.race_arg}"
                " --json={output.json}"
                " --minimum-length={config.minimum_merged_read_length}"
                " --groups-output={output.groups}"
                " --barcode-length={config.barcode_length}"
                " --plot-sizes={output.pdf}"
                " {input.fastaq} 2> {log} | {GZIP} > {output.fastagz}"

    else:
        rule:
            """Collapse identical sequences, remove barcodes"""
            output:
                fastagz=PREPROCESSED_READS,
                json="stats/groups.json"
            input:
                fastaq='reads/4-trimmed.{ext}.gz'.format(ext=ext), success='stats/trimming-successful',
            params:
                # The barcode option is only passed when a barcode length is configured
                barcode_length=' --barcode-length={}'.format(config.barcode_length) if config.barcode_length else '',
                race_arg = ' --trim-g' if config.race_g else '',
            shell:
                "igdiscover dereplicate"
                "{params.barcode_length}{params.race_arg}"
                " --json={output.json}"
                " --minimum-length={config.minimum_merged_read_length}"
                " {input.fastaq} | {GZIP} > {output.fastagz}"


rule copy_d_database:
    """Copy D gene database into the iteration folder"""
    output:
        # {base} is unconstrained, so this rule serves every folder
        fasta="{base}/database/D.fasta"
    input:
        fasta="database/D.fasta"
    shell:
        "cp -p {input} {output}"


rule vj_database_iteration_1:
    """Copy original V or J gene database into the iteration 1 folder"""
    output:
        # {gene} is restricted to 'V' or 'J'
        fasta="iteration-01/database/{gene,[VJ]}.fasta"
    input:
        fasta="database/{gene}.fasta"
    shell:
        "cp -p {input} {output}"


def ensure_fasta_not_empty(path, gene):
    """Exit with an error message if the file at path contains no records.

    path -- sequence file, opened with dnaio.open()
    gene -- gene type ('V', 'J', ...), used only in the error message

    Returns None if at least one record exists. Otherwise, prints an
    explanation to stderr and calls sys.exit(1).
    """
    nothing = object()
    with dnaio.open(path) as fr:
        # Reading a single record is enough to know the file is not empty
        first = next(iter(fr), nothing)
    if first is not nothing:
        return
    print(
        'ERROR: No {gene} genes were discovered in this iteration (file '
        '{path!r} is empty)! Cannot continue.\n'
        'Check whether the starting database is of the correct chain type '
        '(heavy, light lambda, light kappa). It needs to match the type '
        'of sequences you analyze.'.format(gene=gene, path=path), file=sys.stderr)
    sys.exit(1)


# Databases for iterations 2 and up: the V database is the pre-germline
# result of the previous iteration.
for i in range(2, config.iterations + 1):
    rule:
        """Use the pre-germline V sequences of iteration i-1 as V database of iteration i"""
        output:
            fasta='iteration-{nr:02d}/database/V.fasta'.format(nr=i)
        input:
            fasta='iteration-{nr:02d}/new_V_pregermline.fasta'.format(nr=i-1)
        run:
            # Stop early if the previous iteration found no V genes
            ensure_fasta_not_empty(input.fasta, 'V')
            shell("cp -p {input.fasta} {output.fasta}")

    rule:
        # Even with multiple iterations, J genes are discovered only once
        output:
            fasta='iteration-{nr:02d}/database/J.fasta'.format(nr=i)
        input:
            # Discovered J genes are only used if j_discovery['propagate'] is set
            fasta='iteration-01/new_J.fasta' if config.j_discovery['propagate'] else 'database/J.fasta'
        run:
            ensure_fasta_not_empty(input.fasta, 'J')
            shell("cp -p {input.fasta} {output.fasta}")


# Rules for last iteration

if config.iterations == 0:
    # Copy over the input database (would be nice to avoid this)
    rule copy_database:
        """Without discovery iterations, the final V and J databases are the input databases"""
        output:
            fasta='final/database/{gene,[VJ]}.fasta'
        input:
            fasta='database/{gene}.fasta'
        shell:
            "cp -p {input.fasta} {output.fasta}"
else:
    rule copy_final_v_database:
        """Use the germline V sequences of the last iteration as final V database"""
        output:
            fasta='final/database/V.fasta'
        input:
            fasta='iteration-{nr:02d}/new_V_germline.fasta'.format(nr=config.iterations)
        run:
            # Stop if the last iteration found no V genes
            ensure_fasta_not_empty(input.fasta, 'V')
            shell("cp -p {input.fasta} {output.fasta}")

    rule copy_final_j_database:
        """Use the J sequences discovered in iteration 1, or the input J database, as final J database"""
        output:
            fasta='final/database/J.fasta'
        input:
            # Discovered J genes are only used if j_discovery['propagate'] is set
            fasta=('iteration-01/new_J.fasta'
                    if config.j_discovery['propagate'] else 'database/J.fasta')
        run:
            ensure_fasta_not_empty(input.fasta, 'J')
            shell("cp -p {input.fasta} {output.fasta}")


rule igdiscover_igblast:
    """Run 'igdiscover igblastwrap' on the pre-processed reads and compress its output"""
    output:
        tabgz=temp("{dir}/airr.tsv.gz"),
    input:
        fastagz=PREPROCESSED_READS,
        # The database files are inputs so that they exist before the run;
        # the command itself receives the database directory (see params)
        db_v="{dir}/database/V.fasta",
        db_d="{dir}/database/D.fasta",
        db_j="{dir}/database/J.fasta"
    params:
        # Optional arguments are empty strings when not configured
        penalty=' --penalty {}'.format(config.mismatch_penalty) if config.mismatch_penalty is not None else '',
        database='{dir}/database',
        species=' --species={}'.format(config.species) if config.species else '',
        sequence_type=' --sequence-type={}'.format(config.sequence_type),
    threads: 16
    shell:
        "time igdiscover igblastwrap{params.sequence_type}{params.penalty} --threads={threads}"
        "{params.species} {params.database} {input.fastagz} | "
        "{GZIP} > {output.tabgz}"


rule igdiscover_augment:
    output:
        tabgz="{dir}/assigned.tsv.gz",
        json="{dir}/stats/assigned.json"
    input:
        airr_tsv="{dir}/airr.tsv.gz",
        db_v="{dir}/database/V.fasta",
        db_d="{dir}/database/D.fasta",
        db_j="{dir}/database/J.fasta"
    params:
        database='{dir}/database',
        sequence_type=' --sequence-type={}'.format(config.sequence_type),
        rename=' --rename {path!r}_'.format(path=os.path.basename(os.getcwd())) if config.rename else ''
    shell:
        "time igdiscover augment{params.sequence_type}{params.rename}"
        " --stats={output.json} {params.database} {input.airr_tsv}"
        " | "
        "{GZIP} > {output.tabgz}"


rule check_parsing:
    output:
        success="{dir}/stats/parsing-successful"
    input:
        json="{dir}/stats/assigned.json"
    run:
        with open(input.json) as f:
            d = json.load(f)
        n = d['total']
        detected_cdr3s = d['detected_cdr3s']
        if n == 0:
            print('No IgBLAST assignments found, something is wrong.')
            sys.exit(1)
        elif detected_cdr3s / n >= 0.1:
            with open(output.success, 'w') as f:
                print('This marker file exists if a CDR3 sequence could be '
                    'detected for at least 10% of IgBLAST-assigned sequences.',
                    file=f)
        else:
            print(*textwrap.wrap(
                'A CDR3 sequence could be detected in less than 10% of the '
                'IgBLAST-assigned sequences. Possibly there is a problem with '
                'the starting database. To skip this check and continue anyway, '
                'create the file "{}" and re-start "igdiscover run".'.format(
                    output.success)), sep='\n')
            sys.exit(1)



rule igdiscover_filter:
    output:
        filtered="{dir}/filtered.tsv.gz",
        json="{dir}/stats/filtered.json"
    input:
        assigned="{dir}/assigned.tsv.gz",
        success="{dir}/stats/parsing-successful"
    run:
        conf = config.preprocessing_filter
        criteria = ['--v-coverage={}'.format(conf['v_coverage'])]
        criteria += ['--j-coverage={}'.format(conf['j_coverage'])]
        criteria += ['--v-evalue={}'.format(conf['v_evalue'])]
        criteria = ' '.join(criteria)
        shell("igdiscover filter --json={output.json} {criteria} {input.assigned} | {GZIP} > {output.filtered}")


rule igdiscover_exact:
    output:
        exact="{dir}/exact.tab"
    input:
        filtered="{dir}/filtered.tsv.gz"
    shell:
        # extract rows where V_errors == 0
        """zcat {input.filtered} |"""
        """ awk 'NR==1 {{ for(i=1;i<=NF;i++) if ($i == "V_errors") col=i}};NR==1 || $col == 0' > {output}"""


rule igdiscover_count:
    """Run 'igdiscover count' for one gene type (V, D or J) on the filtered table"""
    output:
        plot="{dir}/expressed_{gene,[VDJ]}.pdf",
        counts="{dir}/expressed_{gene}.tab"
    input:
        tab="{dir}/filtered.tsv.gz"
    shell:
        # The table is written to stdout, the plot to the --plot path
        "igdiscover count --gene={wildcards.gene} "
        "--allele-ratio=0.2 "
        "--plot={output.plot} {input.tab} > {output.counts}"


rule igdiscover_clusterplot:
    """Run 'igdiscover clusterplot' and mark completion with the file 'done'"""
    output:
        # Only the marker file is tracked as an output
        done="{dir}/clusterplots/done"
    input:
        tab="{dir}/filtered.tsv.gz"
    params:
        clusterplots="{dir}/clusterplots/",
        ignore_j=' --ignore-J' if config.ignore_j else ''
    shell:
        "igdiscover clusterplot{params.ignore_j} {input.tab} {params.clusterplots} && touch {output.done}"


rule igdiscover_discover:
    """Discover potential new V gene sequences"""
    output:
        tab="{dir}/candidates.tab",
        read_names="{dir}/read_names_map.tab",
    input:
        v_reference="{dir}/database/V.fasta",
        tab="{dir}/filtered.tsv.gz"
    params:
        # Optional arguments are empty strings when not configured
        ignore_j=' --ignore-J' if config.ignore_j else '',
        seed=' --seed={}'.format(config.seed) if config.seed is not None else '',
        exact_copies=' --exact-copies={}'.format(config.exact_copies) if config.exact_copies is not None else ''
    threads: 128
    shell:
        # The candidates table is written to stdout
        "time igdiscover discover -j {threads}{params.seed}{params.ignore_j}{params.exact_copies}"
        " --d-coverage={config.d_coverage}"
        " --read-names={output.read_names}"
        " --subsample={config.subsample} --database={input.v_reference}"
        " {input.tab} > {output.tab}"


def db_whitelist_or_not(wildcards):
    """Input function: return the whitelist database path or an empty list.

    The pre-germline filter settings apply when wildcards.pre is 'pre',
    the germline filter settings otherwise. If the 'whitelist' setting is
    true, 'database/V.fasta' is returned.
    """
    if wildcards.pre == 'pre':
        filterconf = config.pre_germline_filter
    else:
        filterconf = config.germline_filter
    # Use original (non-iteration-specific) database as whitelist
    return 'database/V.fasta' if filterconf['whitelist'] else []


def germlinefilter_criteria(wildcards, input):
    """Params function: build the filter options for 'igdiscover germlinefilter'.

    The pre-germline filter settings apply when wildcards.pre is 'pre',
    the germline filter settings otherwise. Returns the options as a
    single space-separated string.
    """
    # The iteration number is not used, but it must parse as an integer
    int(wildcards.nr, base=10)
    if wildcards.pre == 'pre':
        conf = config.pre_germline_filter
    else:
        conf = config.germline_filter

    # One --whitelist option per whitelist file that is actually present
    criteria = ['--whitelist=' + path for path in (input.db_whitelist, input.whitelist) if path]
    if conf['allow_stop']:
        criteria.append('--allow-stop')
    # if conf['allow_chimeras']:
    #    criteria += ['--allow-chimeras']
    options = (
        ('--unique-cdr3s={}', 'unique_cdr3s'),
        ('--cluster-size={}', 'cluster_size'),
        ('--unique-J={}', 'unique_js'),
        ('--cross-mapping-ratio={}', 'cross_mapping_ratio'),
        ('--clonotype-ratio={}', 'clonotype_ratio'),
        ('--exact-ratio={}', 'exact_ratio'),
        ('--cdr3-shared-ratio={}', 'cdr3_shared_ratio'),
        ('--unique-D-ratio={}', 'unique_d_ratio'),
        ('--unique-D-threshold={}', 'unique_d_threshold'),
    )
    for template, key in options:
        criteria.append(template.format(conf[key]))
    return ' '.join(criteria)


rule igdiscover_germlinefilter:
    """Construct a new database out of the discovered sequences"""
    # {pre} is either 'pre' or empty, so this rule produces both the
    # pre-germline and the germline files.
    output:
        tab='iteration-{nr}/new_V_{pre,(pre|)}germline.tab',
        fasta='iteration-{nr}/new_V_{pre,(pre|)}germline.fasta',
        annotated_tab='iteration-{nr}/annotated_V_{pre,(pre|)}germline.tab',
    input:
        tab='iteration-{nr}/candidates.tab',
        db_whitelist=db_whitelist_or_not,
        # Checked once, when the workflow is loaded
        whitelist='whitelist.fasta' if os.path.exists('whitelist.fasta') else [],
    params:
        criteria=germlinefilter_criteria
    log:
        'iteration-{nr}/new_V_{pre,(pre|)}germline.log'
    shell:
        # stderr is duplicated by tee into the log file; the filtered table
        # is written to stdout
        "igdiscover germlinefilter {params.criteria}"
        " --annotate={output.annotated_tab}"
        " --fasta={output.fasta} {input.tab} "
        " 2> >(tee {log} >&2) "
        " > {output.tab}"


rule igdiscover_discover_j:
    """Discover potential new J gene sequences"""
    # All paths are fixed to iteration 1
    output:
        tab="iteration-01/new_J.tab",
        fasta="iteration-01/new_J.fasta",
    input:
        j_reference="iteration-01/database/J.fasta",
        tab="iteration-01/filtered.tsv.gz",
    params:
        # allele_ratio has no leading space: the shell string puts a space before it
        allele_ratio='--allele-ratio={}'.format(config.j_discovery['allele_ratio']),
        cross_mapping_ratio=' --cross-mapping-ratio={}'.format(config.j_discovery['cross_mapping_ratio'])
    shell:
        "time igdiscover discoverjd {params.allele_ratio}{params.cross_mapping_ratio} "
        "--database={input.j_reference} "
        "--fasta={output.fasta} "
        "{input.tab} > {output.tab}"


rule stats_correlation_V_J:
    """Plot a heatmap of the number of sequences per (V errors, J errors) pair"""
    output:
        pdf="{dir}/correlationVJ.pdf"
    input:
        table="{dir}/assigned.tsv.gz"
    run:
        # The backend must be selected before pyplot is imported
        import matplotlib
        matplotlib.use('pdf')
        # sns.heatmap will not work properly with the object-oriented interface,
        # so use pyplot
        import matplotlib.pyplot as plt
        import seaborn as sns
        import numpy as np
        import pandas as pd
        from collections import Counter
        table = igdiscover.read_table(input.table)
        fig = plt.figure(figsize=(29.7/2.54, 21/2.54))
        # Grid for 0..20 V errors and 0..10 J errors; pairs outside of it
        # (or with a None value) are left out
        counts = np.zeros((21, 11), dtype=np.int64)
        counter = Counter(zip(table['V_errors'], table['J_errors']))
        for (v,j), count in counter.items():
            if v is not None and v < counts.shape[0] and j is not None and j < counts.shape[1]:
                counts[v,j] = count
        # Transpose so that rows are J errors, then reverse the row order
        df = pd.DataFrame(counts.T)[::-1]
        df.index.name = 'J errors'
        df.columns.name = 'V errors'
        sns.heatmap(df, annot=True, fmt=',d', cbar=False)
        fig.suptitle('V errors vs. J errors in unfiltered sequences')
        fig.set_tight_layout(True)
        fig.savefig(output.pdf)


rule plot_errorhistograms:
    """Run 'igdiscover errorplot' on the filtered table, producing two PDFs"""
    output:
        multi_pdf='{dir}/errorhistograms.pdf',
        boxplot_pdf='{dir}/v-shm-distributions.pdf'
    input:
        table='{dir}/filtered.tsv.gz'
    params:
        # Note: despite its name, this passes --max-j-shm=0 when config.ignore_j is set
        ignore_j=' --max-j-shm=0' if config.ignore_j else ''
    shell:
        'igdiscover errorplot{params.ignore_j} --multi={output.multi_pdf} --boxplot={output.boxplot_pdf} {input.table}'


rule dendrogram:
    """Run 'igdiscover dendrogram' on the database of one gene type in a folder"""
    output:
        pdf='{dir}/dendrogram_{gene}.pdf'
    input:
        fasta='{dir}/database/{gene}.fasta'
    shell:
        # --mark receives the original (non-iteration-specific) database
        'igdiscover dendrogram --mark database/{wildcards.gene}.fasta {input.fasta} {output.pdf}'


rule version:
    """Write the IgDiscover version to 'stats/version.txt'"""
    output: txt='stats/version.txt'
    run:
        with open(output.txt, 'w') as f:
            print('IgDiscover version', igdiscover.__version__, file=f)


def get_sequences(path):
    """Return the upper-cased sequences of all records in the file at path as a list."""
    sequences = []
    with dnaio.open(path) as fr:
        for record in fr:
            sequences.append(record.sequence.upper())
    return sequences


def count_sequences(path):
    """Return the number of records in the file at path."""
    with dnaio.open(path) as fr:
        return sum(1 for _ in fr)


rule json_stats_nofinal:
    output: json='stats/stats_nofinal.json'
    input:
        original_db='database/V.fasta',
        v_pre_germline=['iteration-{:02d}/new_V_pregermline.fasta'.format(i+1) for i in range(config.iterations)],
        v_germline=['iteration-{:02d}/new_V_germline.fasta'.format(i+1) for i in range(config.iterations)],
        filtered_stats=['iteration-{:02d}/stats/filtered.json'.format(i+1) for i in range(config.iterations)],
        group_stats='stats/groups.json',
        reads='stats/reads.json',
        trimmed='stats/trimmed.json'
    run:
        d = OrderedDict()
        d['version'] = igdiscover.__version__

        with open(input.reads) as f:
            rp = json.load(f)
        rp['raw_reads'] = rp['total']
        rp['merged'] = rp['merged']
        rp['merging_was_done'] = rp['merging_was_done']
        with open(input.trimmed) as f:
            rp['after_primer_trimming'] = json.load(f)['trimmed']

        with open(input.group_stats) as f:
            rp['grouping'] = json.load(f)
        d['read_preprocessing'] = rp

        prev_sequences = set(get_sequences(input.original_db))
        size = len(prev_sequences)
        iterations = [{'database': {
            'iteration': 0,
            'size': size,
        }}]
        for iteration, (pre_germline_path, germline_path, filtered_json_path) in enumerate(
                zip(input.v_pre_germline, input.v_germline, input.filtered_stats), 1):

            pre_germline_sequences = set(get_sequences(pre_germline_path))
            germline_sequences = set(get_sequences(germline_path))

            gained = len(germline_sequences - prev_sequences)
            lost = len(prev_sequences - germline_sequences)
            gained_pre = len(pre_germline_sequences - prev_sequences)
            lost_pre = len(prev_sequences - pre_germline_sequences)

            iteration_info = OrderedDict()
            with open(filtered_json_path) as f:
                iteration_info['assignment_filtering'] = json.load(f)
            db_info = {'iteration': iteration}
            db_info['size'] = len(germline_sequences)
            db_info['gained'] = gained
            db_info['lost'] = lost
            db_info['size_pre'] = len(pre_germline_sequences)
            db_info['gained_pre'] = gained_pre
            db_info['lost_pre'] = lost_pre
            iteration_info['database'] = db_info
            iterations.append(iteration_info)
            prev_sequences = pre_germline_sequences
        d['iterations'] = iterations

        with open(output.json, 'w') as f:
            json.dump(d, f, indent=2)
            print(file=f)


rule json_stats:
    output: json='stats/stats.json'
    input:
        stats_nofinal='stats/stats_nofinal.json',
        final_stats='final/stats/filtered.json',
    run:
        with open(input.stats_nofinal) as f:
            d = json.load(f)

        with open(input.final_stats) as f:
            d['assignment_filtering'] = json.load(f)

        with open(output.json, 'w') as f:
            json.dump(d, f, indent=2)
            print(file=f)
