# kate: syntax Python;
"""
"""
import shutil
import sys

from sqt.dna import reverse_complement
from sqt import FastaReader
import igdiscover
from igdiscover.utils import relative_symlink, Config

# Load the pipeline configuration; stop with a readable message if the
# configuration file does not exist.
try:
	config = Config()
except FileNotFoundError as exc:
	sys.exit("Pipeline configuration file {!r} not found. Please create it!".format(exc.filename))

# Echo all settings, sorted by name
print('IgDiscover configuration:')
for key, value in sorted(vars(config).items()):
	# TODO the following check is only necessary for non-YAML configurations
	if not key.startswith('_'):
		print('   {}: {!r}'.format(key, value))

# Prepended to every shell command so that errors are caught early:
# abort on failing commands (-e), on unset variables (-u) and on failures
# anywhere within a pipeline (pipefail).
shell.prefix("set -euo pipefail;")

# Paired-end reads that enter the pipeline and the final pre-processed reads
READS1 = 'reads/1-limited.1.fastq.gz'
READS2 = 'reads/1-limited.2.fastq.gz'
PREPROCESSED_READS = 'reads/sequences.fasta.gz'

# Files requested within every iteration folder (including final/)
ITERATION_TARGETS = [
	'clusterplots/done',
	#'correlationVJ.pdf',
	'errorhistograms.pdf',
	'V_usage.tab',
	'V_usage.pdf',
	'V_dendrogram.pdf',
]
# Files requested additionally within the numbered (non-final) iterations
DISCOVERY_TARGETS = [
	'candidates.tab',
	'new_V_database.fasta',
]
# Files requested additionally within final/
FINAL_TARGETS = [
	#'consensus/correlationVJ.pdf',
	#'consensus/V_usage.tab',
	#'consensus/V_usage.pdf',
]
# Everything the default rule asks for: per-iteration files, files of the
# final round and global statistics.
TARGETS = (
	expand('iteration-{nr:02d}/{path}', nr=range(1, config.iterations+1), path=ITERATION_TARGETS+DISCOVERY_TARGETS)
	+ expand('final/{path}', path=ITERATION_TARGETS+FINAL_TARGETS)
	+ ['stats/readlengths.pdf', 'stats/reads.txt']
)
# Barcode statistics are requested only when barcodes are configured
if config.barcode_length > 0:
	TARGETS.append('stats/barcodes.txt')

# Compress with pigz (parallel gzip) when it is on the PATH, otherwise with gzip
GZIP = 'gzip' if shutil.which('pigz') is None else 'pigz'

rule all:
	"""Pseudo-rule that requests every file listed in TARGETS"""
	input: TARGETS


if config.limit:
	# A limit is configured: pass the reads through "sqt fastxmod --limit".
	# One rule handles compressed input files, the other uncompressed ones.
	rule limit_reads_gz:
		"""Apply the configured limit to a gzip-compressed reads file"""
		input: 'reads.{nr}{ext}.gz'
		output: 'reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		shell:
			'sqt fastxmod -w 0 --limit {config.limit} {input} | {GZIP} > {output}'

	rule limit_reads:
		"""Apply the configured limit to an uncompressed reads file and compress the result"""
		input: 'reads.{nr}{ext}'
		output: 'reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		shell:
			'sqt fastxmod -w 0 --limit {config.limit} {input} | {GZIP} > {output}'

else:
	# No limit is configured: make the input available under the
	# "1-limited" name without modifying its content.
	rule symlink_limited:
		"""Create a relative symlink to an already compressed reads file"""
		input: fastaq='reads.{nr}{ext}.gz'
		output: fastaq='reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		resources: time=1
		run:
			relative_symlink(input.fastaq, output.fastaq, force=True)

	# TODO compressing the input file is an unnecessary step
	rule gzip_limited:
		"""Compress an uncompressed reads file"""
		input: fastaq='reads.{nr}{ext}'
		output: fastaq='reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		shell:
			'{GZIP} < {input} > {output}'


# Exactly one merging rule is defined, depending on the configured program.
# An unknown program name aborts the workflow while it is being parsed.
if config.merge_program == 'flash':
	rule flash_merge:
		"""Merge paired-end reads with FLASH; its stderr is copied to the log"""
		input: 'reads/1-limited.1.fastq.gz', 'reads/1-limited.2.fastq.gz'
		output: 'reads/2-merged.fastq.gz'
		log: 'reads/flash.log'
		threads: 8
		resources: time=60
		shell:
			# -M: maximal overlap (2x300, 420-450bp expected fragment size)
			"flash -t {threads} -c -M {config.flash_maximum_overlap} {input} 2> >(tee {log} >&2) | {GZIP} > {output}"
elif config.merge_program == 'pear':
	rule pear_merge:
		"""
		Merge paired-end reads with PEAR.

		All four FASTQ files written by PEAR are compressed; the assembled
		reads are then moved to the common name for merged reads.
		"""
		input: 'reads/1-limited.1.fastq.gz', 'reads/1-limited.2.fastq.gz'
		output:
			unmerged1='reads/2-pear.unassembled.forward.fastq.gz',
			unmerged2='reads/2-pear.unassembled.reverse.fastq.gz',
			discarded='reads/2-pear.discarded.fastq.gz',
			fastq='reads/2-merged.fastq.gz',
		log: 'reads/2-pear.log'
		threads: 8
		resources: time=60
		shell:
			r"""
			pear -j {threads} -f {input[0]} -r {input[1]} -o reads/2-pear | tee {log}
			for n in unassembled.forward unassembled.reverse assembled discarded; do
				{GZIP} -f reads/2-pear.$n.fastq
			done
			mv reads/2-pear.assembled.fastq.gz {output.fastq}
			"""
else:
	sys.exit("merge_program {config.merge_program!r} given in configuration file not recognized".format(config=config))


rule symlink_merged:
	"""
	Provide reads/1-limited.<ext>.gz (no read number in the name) under
	the name for merged reads by creating a relative symlink.
	"""
	input: fastaq='reads/1-limited.{ext}.gz'
	output: fastaq='reads/2-merged.{ext,(fasta|fastq)}.gz'
	run:
		relative_symlink(input.fastaq, output.fastaq, force=True)


rule merged_read_length_histogram:
	"""Write a read length histogram of the merged reads as text and as a PDF plot"""
	input:
		fastq='reads/2-merged.fastq.gz'
	output:
		txt="stats/merged.readlengths.txt",
		pdf="stats/merged.readlengths.pdf"
	shell:
		"""sqt readlenhisto --bins 100 --left 300 --title "Lengths of merged reads" --plot {output.pdf} {input}  > {output.txt}"""


rule read_length_histogram:
	"""Write a read length histogram of the pre-processed reads as text and as a PDF plot"""
	input:
		fastq=PREPROCESSED_READS
	output:
		txt="stats/readlengths.txt",
		pdf="stats/readlengths.pdf"
	shell:
		"""sqt readlenhisto --bins 100 --left 300 --title "Lengths of pre-processed reads" --plot {output.pdf} {input}  > {output.txt}"""


rule barcode_stats_fastq:
	"""
	Print out the number of distinct random barcodes in the library.

	The first config.barcode_length bases of every merged read are taken
	as the barcode; barcodes containing an N are not counted. The default
	targets request stats/barcodes.txt only if config.barcode_length > 0.

	TODO
	- make sure that a stranded protocol is used
	"""
	output: txt="stats/barcodes.txt"
	input: fastq='reads/2-merged.fastq.gz'
	shell:
		# Use the configured barcode length instead of a fixed value so that
		# the result agrees with the barcode counts in stats/reads.txt
		"""
		{GZIP} -dc {input} | awk 'NR%4==2 {{print substr($1,1,{config.barcode_length})}}' | grep -v N | sort -u | wc -l > {output}
		"""

rule barcode_stats_fasta:
	"""
	Placeholder for FASTA input: only creates an empty output file.

	TODO Implement this
	"""
	input: fasta='reads/1-limited.fasta.gz'
	output: txt="stats/barcodes.txt"
	shell: "touch {output}"


rule reads_stats_fastq:
	"""
	Write the number of input reads, merged sequences and pre-processed
	sequences to stats/reads.txt.

	If barcodes are in use (config.barcode_length > 0), the number of
	distinct barcodes without N is added after each of the three counts.
	The barcode is the first config.barcode_length bases of a sequence.
	"""
	output: txt="stats/reads.txt"
	input:
		reads=READS1,
		merged='reads/2-merged.fastq.gz',
		sequences=PREPROCESSED_READS,
	run:
		has_barcodes = config.barcode_length > 0

		# Raw reads: a FASTQ record consists of four lines
		shell("""
		echo -n "Number of paired-end reads: " > {output}
		{GZIP} -dc {input.reads} | awk 'END {{ print NR/4 }}' >> {output}
		""")
		if has_barcodes:
			shell("""
			echo -n "Number of barcodes (looking at 1st read in pair): " >> {output}
			{GZIP} -dc {input.reads} | awk 'NR % 4 == 2 {{ print substr($1, 1, {config.barcode_length}) }}' | sort -u | grep -v N | wc -l >> {output}
			""")

		# Merged reads
		shell("""
		echo -n "Number of merged sequences: " >> {output}
		{GZIP} -dc {input.merged} | awk 'END {{ print NR/4 }}' >> {output}
		""")
		if has_barcodes:
			shell("""
			echo -n "Number of barcodes in merged sequences: " >> {output}
			{GZIP} -dc {input.merged} | awk 'NR % 4 == 2 {{ print substr($1, 1, {config.barcode_length}) }}' | sort -u | grep -v N | wc -l >> {output}
			""")

		# Pre-processed reads (FASTA): this count does not depend on barcodes,
		# so it is always reported
		shell("""
		echo -n "Number of preprocessed sequences: " >> {output}
		zgrep -c '^>' {input.sequences} >> {output}
		""")
		if has_barcodes:
			shell("""
			echo -n "Number of barcodes in preprocessed sequences: " >> {output}
			zgrep -A 1 '^>' {input.sequences} | awk '!/^>/ && $1 != "--" {{ print substr($1,1,{config.barcode_length}) }}' | sort -u | grep -v N | wc -l >> {output}
			""")


rule reads_stats_fasta:
	"""
	Placeholder for FASTA input: only creates an empty output file.

	TODO implement this
	"""
	input:
		merged='reads/1-limited.fasta.gz'
	output: txt="stats/reads.txt"
	shell: "touch {output}"


# Remove primer sequences

if config.forward_primers:
	# At least one forward primer is to be removed
	rule trim_forward_primers:
		"""
		Remove forward primers with cutadapt and discard reads in which
		none was found.
		"""
		input: fastaq='reads/2-merged.{ext}.gz'
		output: fastaq=temp('reads/3-forward-primer-trimmed.{ext,(fasta|fastq)}.gz')
		log: 'reads/3-forward-primer-trimmed.cutadapt.log'
		resources: time=120
		run:
			primers = config.forward_primers
			# One -g option per primer sequence
			options = ['-g ^{}'.format(seq) for seq in primers]
			if not config.stranded:
				# Unstranded protocol: additionally search for the
				# reverse-complemented primers with -a
				options.extend('-a {}$'.format(reverse_complement(seq)) for seq in primers)
			param = ' '.join(options)
			shell(
			"""
			cutadapt --discard-untrimmed {param} -o {output.fastaq} {input.fastaq} | tee {log}
			""")
else:
	# No trimming, just symlink the file
	rule dont_trim_forward_primers:
		"""Pass the merged reads on unchanged via a relative symlink"""
		input: fastaq='reads/2-merged.{ext}.gz'
		output: fastaq='reads/3-forward-primer-trimmed.{ext,(fasta|fastq)}.gz'
		resources: time=1
		run:
			relative_symlink(input.fastaq, output.fastaq, force=True)


if config.reverse_primers:
	# At least one reverse primer is to be removed
	rule trim_reverse_primers:
		"""
		Remove reverse primers with cutadapt and discard reads in which
		none was found.
		"""
		output: fastaq='reads/4-trimmed.{ext,(fasta|fastq)}.gz'
		# The wildcard constraint belongs to the output only; the input just
		# reuses the wildcard value, as in the forward primer rules
		input: fastaq='reads/3-forward-primer-trimmed.{ext}.gz'
		resources: time=120
		log: 'reads/4-trimmed.cutadapt.log'
		run:
			primers = config.reverse_primers
			# Reverse primers should appear reverse-complemented at the 3' end
			# of the merged read.
			param = ' '.join('-a {}$'.format(reverse_complement(seq)) for seq in primers)
			if not config.stranded:
				# Unstranded protocol: additionally search for the primers
				# themselves with -g
				param += ' ' + ' '.join('-g ^{}'.format(seq) for seq in primers)
			shell(
			"""
			cutadapt --discard-untrimmed {param} -o {output.fastaq} {input.fastaq} | tee {log}
			""")
else:
	# No trimming, just symlink the file
	rule dont_trim_reverse_primers:
		"""Pass the forward-primer-trimmed reads on unchanged via a relative symlink"""
		output: fastaq='reads/4-trimmed.{ext,(fasta|fastq)}.gz'
		# No constraint here: constraints are only meaningful in the output
		input: fastaq='reads/3-forward-primer-trimmed.{ext}.gz'
		resources: time=1
		run:
			relative_symlink(input.fastaq, output.fastaq, force=True)


rule fastqc:
	"""
	Run FastQC on a FASTQ file and unpack its report.

	Leftovers of earlier runs are removed first. The zip archive written
	by FastQC is renamed, extracted into fastqc/, and the extracted
	folder is renamed to drop the "_fastqc" suffix.
	"""
	# NOTE(review): the shell commands expect FastQC's files directly below
	# fastqc/ under the name {file}; confirm that this also holds when
	# {file} contains a directory part.
	input: fastq='{file}.fastq.gz'
	output:
		zip='fastqc/{file}.zip',
		png='fastqc/{file}/Images/per_base_quality.png',
		html='fastqc/{file}_fastqc.html'
	shell:
		r"""
		rm -rf fastqc/{wildcards.file}/ fastqc/{wildcards.file}_fastqc/ && \
		fastqc -o fastqc {input} && \
		mv fastqc/{wildcards.file}_fastqc.zip {output.zip} && \
		unzip -o -d fastqc/ {output.zip} && \
		mv fastqc/{wildcards.file}_fastqc/ fastqc/{wildcards.file}
		"""


rule filter_fasta:
	"""
	* Remove low-quality sequences
	* Discard too short sequences

	The --max-errors option is passed only if maximum_expected_errors is
	set in the configuration.
	"""
	output: fasta=temp("reads/5-filtered.fasta")
	input: fasta="reads/4-trimmed.fasta.gz"
	params:
		# The value must be inserted here: a "{config...}" placeholder
		# inside a params string is not expanded again when the shell
		# command is formatted.
		max_errors=" --max-errors {}".format(config.maximum_expected_errors) if config.maximum_expected_errors is not None else ""
	shell:
		"sqt fastxmod{params.max_errors} --minimum-length {config.minimum_merged_read_length} {input.fasta} > {output.fasta}"


rule fastq_to_fasta:
	"""
	Same as filter_fasta, but also convert from FASTQ to FASTA

	The --max-errors option is passed only if maximum_expected_errors is
	set in the configuration.
	"""
	output: fasta=temp("reads/5-filtered.fasta")
	input: fastq="reads/4-trimmed.fastq.gz"
	params:
		# The value must be inserted here: a "{config...}" placeholder
		# inside a params string is not expanded again when the shell
		# command is formatted.
		max_errors=" --max-errors {}".format(config.maximum_expected_errors) if config.maximum_expected_errors is not None else ""
	shell:
		"sqt fastxmod{params.max_errors} --minimum-length {config.minimum_merged_read_length} --fasta {input.fastq} > {output.fasta}"


# The filtered reads become the final pre-processed reads either by grouping
# (barcodes in use) or by plain dereplication (no barcodes).
if config.barcode_length > 0:
	rule igdiscover_group:
		"""
		Group by barcode and CDR3 (also implicitly removes duplicates).

		Also plots the group sizes; stderr of the command goes to the log.
		"""
		input:
			fasta="reads/5-filtered.fasta"
		output:
			fastagz=PREPROCESSED_READS,
			pdf="stats/groupsizes.pdf"
		log: PREPROCESSED_READS + ".log"
		shell:
			"igdiscover group --barcode-length {config.barcode_length} --trim-g --plot-sizes {output.pdf} {input.fasta} 2> {log} | {GZIP} > {output.fastagz}"

else:
	rule dereplicate:
		"""Collapse identical sequences with VSEARCH and compress the result"""
		# NOTE(review): the compressor runs inside a process substitution;
		# confirm that it has finished writing when the command returns.
		input: fasta="reads/5-filtered.fasta"
		output: fastagz=PREPROCESSED_READS
		shell:
			'vsearch --derep_fulllength {input.fasta} --strand both --sizeout --output >({GZIP} > {output.fastagz})'



rule copy_dj_database:
	"""
	Copy the D or J gene database from database/ into the database/
	subfolder of an iteration folder (timestamps are preserved).
	"""
	input:
		fasta="database/{species}_{gene}.fasta"
	output:
		fasta="{base}/database/{species}_{gene,[DJ]}.fasta"
	shell:
		"cp -p {input} {output}"


rule v_database_iteration_1:
	"""The first iteration starts from the original V gene database: copy it into iteration-01/"""
	input:
		fasta="database/{species}_V.fasta"
	output:
		fasta="iteration-01/database/{species}_V.fasta"
	shell:
		"cp -p {input} {output}"


# From the second iteration on, the V database of an iteration is the
# database that was composed at the end of the previous iteration.
# One anonymous rule is generated per iteration.
for i in range(2, config.iterations + 1):
	rule:
		input:
			fasta='iteration-{nr:02d}/new_V_database.fasta'.format(nr=i-1)
		output:
			fasta='iteration-{nr:02d}/database/{{species}}_V.fasta'.format(nr=i)
		shell:
			"cp -p {input.fasta} {output.fasta}"

# V gene database for the final round
if config.iterations == 0:
	# Copy over the input database (would be nice to avoid this)
	rule copy_database:
		"""Without discovery iterations, the final round uses the original V database"""
		input:
			fasta='database/{species}_V.fasta'
		output:
			fasta='final/database/{species}_V.fasta'
		shell:
			"cp -p {input.fasta} {output.fasta}"
else:
	rule copy_final_v_database:
		"""The final round uses the V database composed in the last iteration"""
		input:
			fasta='iteration-{nr:02d}/new_V_database.fasta'.format(nr=config.iterations)
		output:
			fasta='final/database/{species}_V.fasta'.format(species=config.species)
		shell:
			"cp -p {input.fasta} {output.fasta}"


rule makeblastdb:
	"""
	Create a nucleotide BLAST database from a V, D or J gene FASTA file.

	The rule fails if the FASTA file contains no sequences, if makeblastdb
	exits with an error, or if its log contains a line with "Error: ".
	"""
	output: "{dir}/{species}_{gene,[VDJ]}.nhr"  # and nin nog nsd nsi nsq
	input: fasta="{dir}/{species}_{gene}.fasta"
	params: dbname="{dir}/{species}_{gene}"
	log: '{dir}/{species}_{gene}.log'
	threads: 100  # force to run as single job
	run:
		with FastaReader(input.fasta) as fr:
			sequences = list(fr)
		if not sequences:
			raise ValueError("The FASTA file {} is empty, cannot continue!".format(input.fasta))
		# The log is searched for errors in addition to relying on the exit
		# status. An "if" is required here: "grep ... && {{ ...; false; }} || true"
		# always evaluates to true and would therefore never fail the rule.
		# grep without a match does not abort the script because it is
		# used as a condition.
		shell(r"""
		makeblastdb -parse_seqids -dbtype nucl -in {input.fasta} -out {params.dbname} >& {log}
		if grep 'Error: ' {log}; then
			echo "makeblastdb failed when creating {params.dbname}"
			exit 1
		fi
		""")


rule igdiscover_igblast:
	"""
	Run IgBLAST (through "igdiscover igblast") on the pre-processed reads
	and compress its output.

	The BLAST databases for V, D and J are listed as inputs so that they
	get built first; the command itself receives the database folder.
	"""
	input:
		fastagz=PREPROCESSED_READS,
		db_v="{{dir}}/database/{species}_V.nhr".format(species=config.species),
		db_d="{{dir}}/database/{species}_D.nhr".format(species=config.species),
		db_j="{{dir}}/database/{species}_J.nhr".format(species=config.species)
	output:
		txtgz="{dir}/igblast.txt.gz"
	params:
		database='{dir}/database',
		# Pass --penalty only if a mismatch penalty is configured
		penalty='--penalty {}'.format(config.mismatch_penalty) if config.mismatch_penalty is not None else ''
	threads: 16
	shell:
		"igdiscover igblast --threads {threads} {params.penalty} "
		"--species {config.species} {params.database} {input.fastagz} | "
		"{GZIP} > {output.txtgz}"


rule igdiscover_parse:
	"""
	Run "igdiscover parse" on the IgBLAST output and the pre-processed
	reads and compress the resulting table. The library name followed by
	an underscore is passed to --rename.
	"""
	input:
		txt="{dir}/igblast.txt.gz",
		fasta=PREPROCESSED_READS
	output:
		tabgz="{dir}/assigned.tab.gz",
	shell:
		"igdiscover parse --rename {config.library_name!r}_ {input.txt} {input.fasta} | {GZIP} > {output.tabgz}"


rule igdiscover_filter:
	"""Run "igdiscover filter" on the table of assigned sequences and compress the result"""
	input:
		assigned="{dir}/assigned.tab.gz"
	output:
		filtered="{dir}/filtered.tab.gz"
	shell:
		"igdiscover filter {input} | {GZIP} > {output}"


rule igdiscover_count:
	"""
	Run "igdiscover count" for one gene type (V, D or J) on the filtered
	table: the counts table is captured from stdout, the plot is written
	by the command itself.
	"""
	input:
		reference="{{dir}}/database/{species}_{{gene}}.fasta".format(species=config.species),
		tab="{dir}/filtered.tab.gz"
	output:
		plot="{dir}/{gene,[VDJ]}_usage.pdf",
		counts="{dir}/{gene}_usage.tab"
	shell:
		"igdiscover count --database {input.reference} --gene {wildcards.gene} "
		"{input.tab} {output.plot} > {output.counts}"


rule igdiscover_clusterplot:
	"""
	Run "igdiscover clusterplot" with the clusterplots/ folder as its
	target. The file "done" is touched after the command has succeeded
	and serves as the rule's output.
	"""
	input:
		tab="{dir}/filtered.tab.gz"
	output:
		done="{dir}/clusterplots/done"
	params:
		# Pass --ignore-J only if requested in the configuration
		ignore_j='--ignore-J' if config.ignore_j else '',
		clusterplots="{dir}/clusterplots/"
	shell:
		"igdiscover clusterplot {params.ignore_j} {input.tab} {params.clusterplots} && touch {output.done}"


rule igdiscover_discover:
	"""
	Search the filtered table for potential new V gene sequences and
	write the candidates table.
	"""
	input:
		v_reference="{{dir}}/database/{species}_V.fasta".format(species=config.species),
		tab="{dir}/filtered.tab.gz"
	output:
		tab="{dir}/candidates.tab"
	params:
		# Pass --ignore-J only if requested in the configuration
		ignore_j='--ignore-J' if config.ignore_j else ''
	threads: 4  # TODO this could be increased up to 6
	shell:
		"igdiscover discover -j {threads} --subsample {config.subsample} "
		"--database {input.v_reference} {params.ignore_j} "
		"{input.tab} > {output.tab}"


rule igdiscover_compose:
	"""
	Construct a new V database out of the candidates of an iteration.

	The last configured iteration uses stricter filtering criteria than
	the ones before it.
	"""
	input:
		tab='iteration-{nr}/candidates.tab'
	output:
		fasta='iteration-{nr}/new_V_database.fasta'
	run:
		# The wildcard is a zero-padded string such as "02"
		is_last_iteration = int(wildcards.nr, base=10) == config.iterations
		if is_last_iteration:
			# Apply stricter filtering criteria for final iteration
			cdr3_options = '--unique-CDR3=3 --cluster-size={config.minimum_cluster_size}'.format(config=config)
		else:
			cdr3_options = '--unique-CDR3=2'
		criteria = '--looks-like-V --max-differences=2 ' + cdr3_options
		shell("igdiscover compose {criteria} {input.tab} > {output.fasta}")


rule stats_correlation_V_J:
	"""
	Plot a heatmap showing how often each combination of V_errors and
	J_errors values occurs in the table of assigned (unfiltered) sequences.
	"""
	output:
		pdf="{dir}/correlationVJ.pdf"
	input:
		table="{dir}/assigned.tab.gz"
	run:
		# The backend is selected before pyplot is imported
		import matplotlib
		matplotlib.use('pdf')
		# sns.heatmap will not work properly with the object-oriented interface,
		# so use pyplot
		import matplotlib.pyplot as plt
		import seaborn as sns
		import numpy as np
		import pandas as pd
		from collections import Counter
		table = igdiscover.read_table(input.table)
		# Figure size is given in inches: 29.7 cm x 21 cm
		fig = plt.figure(figsize=(29.7/2.54, 21/2.54))
		# counts[v, j]: number of rows with v V errors and j J errors.
		# Only 0-20 V errors and 0-10 J errors fit into the matrix.
		counts = np.zeros((21, 11), dtype=np.int64)
		counter = Counter(zip(table['V_errors'], table['J_errors']))
		for (v,j), count in counter.items():
			# Combinations that are missing or outside of the matrix are skipped
			if v is not None and v < counts.shape[0] and j is not None and j < counts.shape[1]:
				counts[v,j] = count
		# Transpose so that rows are J errors and columns are V errors,
		# then reverse the order of the rows
		df = pd.DataFrame(counts.T)[::-1]
		df.index.name = 'J errors'
		df.columns.name = 'V errors'
		sns.heatmap(df, annot=True, fmt=',d', cbar=False)
		fig.suptitle('V errors vs. J errors in unfiltered sequences')
		fig.set_tight_layout(True)
		fig.savefig(output.pdf)


rule plot_errorhistograms:
	"""Run "igdiscover errorplot" on the filtered table to create the error histograms PDF"""
	input:
		table='{dir}/filtered.tab.gz'
	output:
		pdf='{dir}/errorhistograms.pdf',
	params:
		# Pass --ignore-J only if requested in the configuration
		ignore_j='--ignore-J' if config.ignore_j else ''
	shell:
		'igdiscover errorplot {params.ignore_j} {input.table} {output.pdf}'


rule dendrogram:
	"""
	Run "igdiscover dendrogram" on the gene database of an iteration
	folder; the original database in database/ is passed to --mark.
	"""
	input:
		fasta='{{dir}}/database/{species}_{{gene}}.fasta'.format(species=config.species)
	output:
		pdf='{dir}/{gene}_dendrogram.pdf'
	shell:
		'igdiscover dendrogram --mark database/{config.species}_{wildcards.gene}.fasta {input.fasta} {output.pdf}'
