#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) 2011, A. Murat Eren
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# Please read the COPYING file.

# v.

import re
import os
import sys
import tempfile
import subprocess
import configparser


# The Levenshtein module is a hard requirement: it provides the distance
# functions used when searching for the best overlap between two reads.
try:
    import Levenshtein as l
except ImportError:
    # Catch only import failures so that Ctrl-C or an unrelated error raised
    # while importing is not mistaken for a missing module.
    print('''
    ERROR: This script requires Levenshtein module, which seems to not installed on this machine.

    Here is a fast implementation of Levenshtein distance for Python:

        http://code.google.com/p/pylevenshtein/

''')
    # a missing dependency is a failure; do not report success to the shell
    sys.exit(-1)


import IlluminaUtils.lib.fastqlib as u
import IlluminaUtils.lib.fastalib as f
import IlluminaUtils.rapidmerge as rapidmerge
from IlluminaUtils.utils.runconfiguration import RunConfiguration
from IlluminaUtils.utils.runconfiguration import RunConfigError
from IlluminaUtils.utils.helperfunctions import ConfigError
from IlluminaUtils.utils.helperfunctions import visualize_qual_stats_dict
from IlluminaUtils.utils.helperfunctions import colorize
from IlluminaUtils.utils.helperfunctions import conv_dict
from IlluminaUtils.utils.helperfunctions import reverse_complement
from IlluminaUtils.utils.helperfunctions import big_number_pretty_print


# Keys of `conv_dict`, kept as a set for fast membership tests.
# NOTE(review): presumably the upper-case base letters -- confirm against
# IlluminaUtils.utils.helperfunctions.conv_dict.
bases_upper_case = set(conv_dict.keys())


def NumberOfConflicts(s):
    """Return how many characters of `s` are members of `bases_upper_case`."""
    return sum(1 for n in s if n in bases_upper_case)


def ConflictPositions(s):
    """Return the indices of `s` whose character is in `bases_upper_case`."""
    return [i for i, n in enumerate(s) if n in bases_upper_case]



class MergerStats:
    """Counters collected while merging read pairs, plus plain-text reports.

    All counters start at zero and are incremented by the code that drives the
    merging. `num_mismatches_breakdown` maps a merging result label (for
    instance 'merge passed') to a dictionary of
    {number of mismatches: number of pairs}.
    """

    def __init__(self):
        self.actual_number_of_pairs = 0
        self.pairs_eliminated_due_to_Ns = 0
        self.pairs_eliminated_due_to_P = 0
        self.pairs_eliminated_due_to_max_num_mismatches = 0
        self.pairs_eliminated_due_to_Q30 = 0
        self.pairs_eliminated_due_to_Min_Overlap = 0
        self.prefix_failed_in_pair_1_total = 0
        self.prefix_failed_in_pair_2_total = 0
        self.prefix_failed_in_both_pairs_total = 0
        self.passed_prefix_total = 0
        self.failed_prefix_total = 0
        self.merge_failed_total = 0
        self.merge_passed_total = 0
        self.merging_done_on_complete_overlap = 0
        # 'merge passed' / 0 mismatches is always present so that it can be
        # read before the first pair is recorded.
        self.num_mismatches_breakdown = {'merge passed': {0: 0}}
        self.total_number_of_mismatches = 0
        self.num_mismatches_recovered_from_read_1 = 0
        self.num_mismatches_recovered_from_read_2 = 0
        self.num_mismatches_replaced_with_N = 0
        self.num_Q30_fails_in_read_1 = 0
        self.num_Q30_fails_in_read_2 = 0
        self.num_Q30_fails_in_both = 0

    def s_line(self, label, value, padding = 55):
        """Return one report line: label, a space, dots, a tab, value, newline.

        The number of dots is `padding - len(label)`, so labels longer than
        `padding` get no dots at all.
        """
        return '%s%s\t%s\n' % (label, ' ' + '.' * (padding - len(label)), value)

    def write_stats(self, output_file_path, merger_object):
        """Write the full merging report to `output_file_path`.

        The report lists every counter, the mismatch breakdown of the pairs
        that merged successfully, and the settings read from `merger_object`
        (p_value, max_num_mismatches, min_overlap_size, min_qual_score,
        ignore_Ns, enforce_Q30_check, marker_gene_stringent and
        retain_only_overlap).
        """
        counters = [
            ('Number of pairs analyzed', self.actual_number_of_pairs),
            ('Prefix failed in read 1', self.prefix_failed_in_pair_1_total),
            ('Prefix failed in read 2', self.prefix_failed_in_pair_2_total),
            ('Prefix failed in both', self.prefix_failed_in_both_pairs_total),
            ('Passed prefix total', self.passed_prefix_total),
            ('Failed prefix total', self.failed_prefix_total),
            ('Merged total', self.merge_passed_total),
            ('Complete overlap forced total', self.merging_done_on_complete_overlap),
            ('Merge failed total', self.merge_failed_total),
            ('Merge discarded due to P', self.pairs_eliminated_due_to_P),
            ('Merge discarded due to max num mismatches', self.pairs_eliminated_due_to_max_num_mismatches),
            ('Merge discarded due to Ns', self.pairs_eliminated_due_to_Ns),
            ('Merge discarded due to Q30', self.pairs_eliminated_due_to_Q30),
            ('Pairs discarded due to min expected overlap', self.pairs_eliminated_due_to_Min_Overlap),
            ('Num mismatches found in merged reads', self.total_number_of_mismatches),
            ('Mismatches recovered from read 1', self.num_mismatches_recovered_from_read_1),
            ('Mismatches recovered from read 2', self.num_mismatches_recovered_from_read_2),
            ('Mismatches replaced with N', self.num_mismatches_replaced_with_N),
        ]

        # the context manager closes the file even if a write fails half way
        with open(output_file_path, 'w') as stats:
            for label, value in counters:
                stats.write(self.s_line(label, '%d' % value))

            stats.write('\n\nMismatches breakdown in final merged reads:\n\n')

            passed = self.num_mismatches_breakdown['merge passed']
            for i in sorted(passed):
                stats.write('%d\t%d\n' % (i, passed[i]))

            stats.write('\n\n')
            stats.write(self.s_line('Command line', '%s' % ' '.join(sys.argv)))
            stats.write(self.s_line('Work directory', '%s' % os.getcwd()))
            stats.write(self.s_line('"p" value', '%f' % merger_object.p_value))
            # a negative max_num_mismatches means the criterion is not in use
            stats.write(self.s_line('Max num mismatches at the overlapped region', '%s' % (str(merger_object.max_num_mismatches) if merger_object.max_num_mismatches >= 0 else 'None')))
            stats.write(self.s_line('Min overlap size', '%s' % merger_object.min_overlap_size))
            stats.write(self.s_line('Min Q-score for mismatches', '%s' % merger_object.min_qual_score))
            stats.write(self.s_line('Ns ignored?', '%s' % merger_object.ignore_Ns))
            stats.write(self.s_line('Q30 enforced?', '%s' % merger_object.enforce_Q30_check))
            stats.write(self.s_line('Run with stringent flag on?', '%s' % merger_object.marker_gene_stringent))
            stats.write(self.s_line('Requested only the overlapping part to be retained?', '%s' % merger_object.retain_only_overlap))


    def write_rapidmerge_stats(self, output_file_path, merger_object):
        """Write the shorter report to `output_file_path`.

        Same layout as `write_stats`, limited to the prefix, merge and N
        counters, followed by the settings read from `merger_object`
        (max_num_mismatches, min_overlap_size, marker_gene_stringent and
        retain_only_overlap).
        """
        counters = [
            ('Number of pairs analyzed', self.actual_number_of_pairs),
            ('Prefix failed in read 1', self.prefix_failed_in_pair_1_total),
            ('Prefix failed in read 2', self.prefix_failed_in_pair_2_total),
            ('Prefix failed in both', self.prefix_failed_in_both_pairs_total),
            ('Passed prefix total', self.passed_prefix_total),
            ('Failed prefix total', self.failed_prefix_total),
            ('Merged total', self.merge_passed_total),
            ('Complete overlap total', self.merging_done_on_complete_overlap),
            ('Merge discarded due to Ns', self.pairs_eliminated_due_to_Ns),
        ]

        with open(output_file_path, 'w') as stats:
            for label, value in counters:
                stats.write(self.s_line(label, '%d' % value))

            stats.write('\n\n')
            stats.write(self.s_line('Command line', '%s' % ' '.join(sys.argv)))
            stats.write(self.s_line('Work directory', '%s' % os.getcwd()))
            stats.write(self.s_line('Max num mismatches at the overlapped region', '%s' % (str(merger_object.max_num_mismatches) if merger_object.max_num_mismatches >= 0 else 'None')))
            stats.write(self.s_line('Min overlap size', '%s' % merger_object.min_overlap_size))
            stats.write(self.s_line('Run with stringent flag on?', '%s' % merger_object.marker_gene_stringent))
            stats.write(self.s_line('Requested only the overlapping part to be retained?', '%s' % merger_object.retain_only_overlap))


    def write_mismatches_breakdown_table(self, output_file_path):
        """Write a TAB-delimited table of pair counts per number of mismatches.

        The header row holds one column per merging result category. Each
        following row holds a number of mismatches, from 0 up to and including
        the largest number recorded in any category, and the number of pairs
        in every category (0 when a category has no such pairs).
        """
        categories = list(self.num_mismatches_breakdown.keys())

        # largest number of mismatches seen in any category; empty categories
        # are skipped so that max() never sees an empty sequence
        largest = max((max(counts) for counts in self.num_mismatches_breakdown.values() if counts),
                      default = 0)

        with open(output_file_path, 'w') as table:
            table.write('%s\t%s\n' % ('num_mismatch', '\t'.join(categories)))

            # `largest + 1`: the row for the largest number of mismatches
            # belongs in the table too
            for i in range(0, largest + 1):
                mismatch_counts = [self.num_mismatches_breakdown[category].get(i, 0) for category in categories]
                table.write('%d\t%s\n' % (i, '\t'.join([str(x) for x in mismatch_counts])))


    def record_num_mismatches(self, number_of_mismatches, merging_result = 'merged'):
        """Count one more pair with `number_of_mismatches` under `merging_result`.

        The category and the per-mismatch counter are created on first use.
        """
        counts = self.num_mismatches_breakdown.setdefault(merging_result, {})
        counts[number_of_mismatches] = counts.get(number_of_mismatches, 0) + 1


    def process_recovery_dict(self, recovery_dict):
        """Add the mismatch resolution counts of one merged pair to the totals.

        `recovery_dict` must have the keys 'read_1', 'read_2' and 'none'; the
        sum of all of its values is added to `total_number_of_mismatches`.
        """
        self.total_number_of_mismatches += sum(recovery_dict.values())
        self.num_mismatches_recovered_from_read_1 += recovery_dict['read_1']
        self.num_mismatches_recovered_from_read_2 += recovery_dict['read_2']
        self.num_mismatches_replaced_with_N += recovery_dict['none']


class Merger:
    def __init__(self, config):
        """Set up a merger with default settings for the given run configuration.

        `config` must provide `pair_1_prefix` and `pair_2_prefix`; when one of
        them is set it is compiled as a regular expression, otherwise the
        matching `pair_N_prefix_compiled` attribute is None. Every other
        attribute is a default the caller may override before calling `run()`.
        """
        self.config = config
        self.output_file_prefix = None
        self.compute_qual_dicts = False
        self.p_value = 0.3
        # -1 turns the "max number of mismatches" criterion off
        self.max_num_mismatches = -1
        self.min_overlap_size = 15
        self.min_qual_score = 15
        self.debug = False
        self.ignore_Ns = False
        self.ignore_deflines = False
        self.enforce_Q30_check = False
        self.retain_only_overlap = False
        self.marker_gene_stringent = False
        self.skip_suffix_trimming = False
        self.report_r1_prefix = False
        self.report_r2_prefix = False
        # -1 selects the single-threaded code path in run()
        self.num_threads = -1
        # run() reads this attribute on the multithreaded path; give it a
        # default so that it exists even when the caller never sets it.
        # 'hamming' matches the distance used by the single-threaded merging.
        self.distance_metric = 'hamming'
        self.project_name = 'Unknown_project'

        self.pair_1_prefix_compiled = re.compile(self.config.pair_1_prefix) if self.config.pair_1_prefix else None
        self.pair_2_prefix_compiled = re.compile(self.config.pair_2_prefix) if self.config.pair_2_prefix else None

        # used for visualization:
        self.mean_quals_per_mismatch = {}

        self.stats = MergerStats()


    def run(self):
        """Merge every pair of input files listed in the configuration.

        Calls `init()` first. When `num_threads` is not -1 the work for the
        first pair of files is handed over to `rapidmerge.FASTQMerger` and the
        method returns. Otherwise each pair of reads goes through prefix
        matching, overlap detection and the max-mismatch / P / N / Q30
        criteria, and is written to the merged output or to one of the failure
        outputs. Once all files are processed the output files are closed, the
        stats and mismatch breakdown files are written, and, if
        `compute_qual_dicts` is set, the quality score summaries are
        visualized.

        Exits the process with a non-zero status when the prefix patterns are
        not anchored (multithreaded path), when an input file cannot be opened,
        or when deflines cannot be parsed and `ignore_deflines` is not set.
        """
        #sys.stderr.write('POK: "Percent of all pairs that did present the prefix defined by the user\n"')
        #sys.stderr.write('ZM : "Percent of all pairs that merged with 0 mismatch\n"')

        self.init()

        for index in range(0, len(self.config.pair_1)):

            # Faster, multithreaded merging
            if self.num_threads != -1:
                # Back out certain variables for `rapidmerge.py`.
                r1_prefix_pattern = ('' if self.pair_1_prefix_compiled is None
                                     else self.pair_1_prefix_compiled.pattern)

                r2_prefix_pattern = ('' if self.pair_2_prefix_compiled is None
                                     else self.pair_2_prefix_compiled.pattern)

                if (r1_prefix_pattern and r1_prefix_pattern[0] != '^') or (r2_prefix_pattern and r2_prefix_pattern[0] != '^'):
                    print("Prefix patterns must start with ^ as they occur at the beginning of the read.")
                    sys.exit(-1)

                rapidmerger = rapidmerge.FASTQMerger(
                    self.config.pair_1[index],
                    self.config.pair_2[index],
                    input1_is_gzipped=self.config.pair_1[index].endswith('.gz'),
                    input2_is_gzipped=self.config.pair_2[index].endswith('.gz'),
                    ignore_deflines=self.ignore_deflines,
                    output_dir=self.config.output_directory,
                    output_file_name=self.output_file_prefix,
                    r1_prefix_pattern=r1_prefix_pattern,
                    r2_prefix_pattern=r2_prefix_pattern,
                    report_r1_prefix=self.report_r1_prefix,
                    report_r2_prefix=self.report_r2_prefix,
                    max_p=self.p_value,
                    max_num_mismatches=self.max_num_mismatches,
                    min_overlap_size=self.min_overlap_size + 1,
                    min_qual_score=self.min_qual_score,
                    partial_overlap_only=not self.marker_gene_stringent,
                    retain_overlap_only=self.retain_only_overlap,
                    skip_suffix_trimming=self.skip_suffix_trimming,
                    ignore_Ns=self.ignore_Ns,
                    enforce_Q30_check=self.enforce_Q30_check,
                    dataset_index=index + 1,
                    total_dataset_count=len(self.config.pair_1),
                    project_name=self.project_name,
                    num_cores=self.num_threads)
                if self.max_num_mismatches == 0:
                    rapidmerger.run('exact')
                elif self.distance_metric == 'hamming':
                    rapidmerger.run('hamming')
                elif self.distance_metric == 'levenshtein':
                    rapidmerger.run('levenshtein')

                # NOTE(review): this returns from inside the loop, i.e. after
                # the first pair of files and without closing the output files
                # opened by init() -- confirm this is intended.
                return


            try:
                self.input_1 = u.FastQSource(self.config.pair_1[index], lazy_init=True)
                self.input_2 = u.FastQSource(self.config.pair_2[index], lazy_init=True)

            except u.FastQLibError as e:
                print("FastQLib is not happy.\n\n\t", e, "\n")
                sys.exit(-1)


            if self.input_1.forced_raw and not self.ignore_deflines:
                print("\nError: Your input FASTQ files do not seem to be generated by CASAVA 1.8. Please use --ignore-deflines parameter.")
                sys.exit(-1)


            while self.input_1.next(raw = self.ignore_deflines) and self.input_2.next(raw = self.ignore_deflines):
                self.stats.actual_number_of_pairs += 1
                if self.input_1.p_available:
                    if self.pair_1_prefix_compiled or self.pair_2_prefix_compiled:
                        self.input_1.print_percentage('[Merging %d of %d] POK: %.1f%% :: ZM: %.1f%%' \
                                    % (index + 1,
                                       len(self.config.pair_1),
                                       self.stats.passed_prefix_total * 100.0 / self.stats.actual_number_of_pairs,
                                       self.stats.num_mismatches_breakdown['merge passed'][0] * 100.0 / (self.stats.passed_prefix_total or 1)
                                       ))
                    else:
                        # only the postfix depends on whether anything has been
                        # merged yet; the progress itself is always reported
                        self.input_1.print_percentage(prefix = '[Merging %d of %d]' % (index + 1, len(self.config.pair_1)),
                                                      postfix = ('-- merged pairs: %s' % big_number_pretty_print(self.stats.merge_passed_total)) \
                                                                    if self.stats.merge_passed_total > 0 else '')


                # taking care of prefixes if there are any..
                pattern_1 = pattern_2 = None

                if self.pair_1_prefix_compiled:
                    pattern_1 = self.pair_1_prefix_compiled.search(self.input_1.entry.sequence)
                    if pattern_1:
                        self.input_1.entry.trim(trim_from = pattern_1.end())

                if self.pair_2_prefix_compiled:
                    pattern_2 = self.pair_2_prefix_compiled.search(self.input_2.entry.sequence)
                    if pattern_2:
                        self.input_2.entry.trim(trim_from = pattern_2.end())

                # a read fails only if a prefix was requested for it and not found
                prefix_failed_in_pair_1 = bool(self.pair_1_prefix_compiled) and not pattern_1
                prefix_failed_in_pair_2 = bool(self.pair_2_prefix_compiled) and not pattern_2

                if prefix_failed_in_pair_1:
                    self.stats.prefix_failed_in_pair_1_total += 1
                if prefix_failed_in_pair_2:
                    self.stats.prefix_failed_in_pair_2_total += 1
                if prefix_failed_in_pair_1 and prefix_failed_in_pair_2:
                    self.stats.prefix_failed_in_both_pairs_total += 1

                if prefix_failed_in_pair_1 or prefix_failed_in_pair_2:
                    self.stats.failed_prefix_total += 1
                    continue

                self.stats.passed_prefix_total += 1


                # make sure both reads are longer than the minimum overlap criterion
                if min(len(self.input_1.entry.sequence), len(self.input_2.entry.sequence)) <= self.min_overlap_size:
                    self.stats.pairs_eliminated_due_to_Min_Overlap += 1
                    continue


                # merging
                merging_done_on_complete_overlap = False
                (beginning, overlap, end), number_of_mismatches, recovery_dict = self.merge_two_sequences_fast()

                merged_sequence = beginning + overlap + end

                # find out about 'p'
                len_overlap = len(overlap)
                p = 1.0 * number_of_mismatches / len_overlap

                if self.marker_gene_stringent:
                    # also try the complete overlap scenario, and keep it if
                    # it yields a smaller 'p'
                    (alt_beginning, alt_overlap, alt_end), alt_number_of_mismatches, alt_recovery_dict =\
                                                             self.merge_two_sequences_fast(complete_overlap = True)

                    if self.retain_only_overlap or not self.skip_suffix_trimming:
                        alt_merged_sequence = alt_overlap
                    else:
                        alt_merged_sequence = alt_beginning + alt_overlap + alt_end

                    # find out about 'p'
                    alt_len_overlap = len(alt_overlap)
                    alt_p = 1.0 * alt_number_of_mismatches / alt_len_overlap

                    if alt_p < p:
                        beginning, overlap, end, number_of_mismatches, recovery_dict, merged_sequence, len_overlap, p =\
                                            alt_beginning, alt_overlap, alt_end, alt_number_of_mismatches,\
                                            alt_recovery_dict, alt_merged_sequence, alt_len_overlap, alt_p
                        merging_done_on_complete_overlap = True
                        self.stats.merging_done_on_complete_overlap += 1

                if self.enforce_Q30_check and not merging_done_on_complete_overlap:
                    if not self.input_1.entry.Q_list:
                        self.input_1.entry.process_Q_list()
                        self.input_2.entry.process_Q_list()

                    # the check runs on each read with its last `len_overlap`
                    # quality scores left out
                    r1_passed_Q30, r1_Q30 = self.passes_minoche_Q30(self.input_1.entry.Q_list[0:-len_overlap])
                    r2_passed_Q30, r2_Q30 = self.passes_minoche_Q30(self.input_2.entry.Q_list[0:-len_overlap])

                    pair_passed_Q30 = r1_passed_Q30 and r2_passed_Q30

                    if not r1_passed_Q30 and r2_passed_Q30:
                        self.stats.num_Q30_fails_in_read_1 += 1
                    elif r1_passed_Q30 and not r2_passed_Q30:
                        self.stats.num_Q30_fails_in_read_2 += 1
                    elif not r1_passed_Q30 and not r2_passed_Q30:
                        self.stats.num_Q30_fails_in_both += 1

                # generate the header line ('o' is the overlap length in decimal)
                header_line = '%s|o:%d|m/o:%f|MR:%s|Q30:%s|CO:%d|mismatches:%d' % \
                                    (self.input_1.entry.header_line,
                                     len_overlap,
                                     p,
                                     'n=%d;r1=%d;r2=%d' % (recovery_dict['none'],
                                                           recovery_dict['read_1'],
                                                           recovery_dict['read_2']),
                                     'n/a' if not (self.enforce_Q30_check and not merging_done_on_complete_overlap) \
                                           else '%s=%d;%s=%d' % ('p' if r1_passed_Q30 else 'f',
                                                                 r1_Q30,
                                                                 'p' if r2_passed_Q30 else 'f',
                                                                 r2_Q30),
                                     1 if merging_done_on_complete_overlap else 0,
                                     number_of_mismatches)

                # append the header line to the project name
                header_line = '%s_%d|%s' % (self.project_name, self.stats.actual_number_of_pairs, header_line.replace('_', '-'))

                # FAIL CASE ~ max_num_mismatches
                if self.max_num_mismatches >= 0 and number_of_mismatches > self.max_num_mismatches:
                    self.failed.write_id(header_line)
                    self.failed.write_seq(merged_sequence, split = False)
                    self.stats.record_num_mismatches(number_of_mismatches, 'merge failed due to max num mismatches')
                    self.stats.merge_failed_total += 1
                    self.stats.pairs_eliminated_due_to_max_num_mismatches += 1
                    continue

                # FAIL CASE ~ m/o
                if p > self.p_value or len_overlap < self.min_overlap_size:
                    self.failed.write_id(header_line)
                    self.failed.write_seq(merged_sequence, split = False)
                    self.stats.record_num_mismatches(number_of_mismatches, 'merge failed due to P value')
                    self.stats.merge_failed_total += 1
                    self.stats.pairs_eliminated_due_to_P += 1
                    continue

                # FAIL CASE ~ N
                if not self.ignore_Ns:
                    if 'N' in merged_sequence or 'n' in merged_sequence:
                        self.withNs.write_id(header_line)
                        self.withNs.write_seq(merged_sequence, split = False)
                        self.stats.record_num_mismatches(number_of_mismatches, 'merge failed due to N')
                        self.stats.merge_failed_total += 1
                        self.stats.pairs_eliminated_due_to_Ns += 1
                        continue

                # FAIL CASE ~ Q30
                if self.enforce_Q30_check and not merging_done_on_complete_overlap:
                    if not pair_passed_Q30:
                        self.Q30fails.write_id(header_line)
                        self.Q30fails.write_seq(merged_sequence, split = False)
                        self.stats.record_num_mismatches(number_of_mismatches, 'merge failed due to Q30')
                        self.stats.merge_failed_total += 1
                        self.stats.pairs_eliminated_due_to_Q30 += 1
                        continue

                # FINALLY
                self.output.write_id(header_line)

                if self.retain_only_overlap:
                    self.output.write_seq(overlap, split = False)
                else:
                    self.output.write_seq(merged_sequence, split = False)
                self.stats.record_num_mismatches(number_of_mismatches, 'merge passed')
                self.stats.merge_passed_total += 1

                # record the info for successfuly merged pair in recovery dict
                self.stats.process_recovery_dict(recovery_dict)

                # pattern_1 / pattern_2 are matches at this point: pairs with
                # a failed prefix were skipped above
                if self.report_r1_prefix and self.pair_1_prefix_compiled:
                    self.read1prefixes.write_id(header_line)
                    self.read1prefixes.write_seq(pattern_1.group(0))
                if self.report_r2_prefix and self.pair_2_prefix_compiled:
                    self.read2prefixes.write_id(header_line)
                    self.read2prefixes.write_seq(pattern_2.group(0))

                # FIXME: this is ridiculous. visualization of the Q-score curves associated
                # part needs to be moved somewhere else.
                if self.compute_qual_dicts:
                    ################ quality dicts associated stuff ####################
                    if not self.input_1.entry.Q_list:
                        q1 = self.input_1.entry.process_Q_list()
                        q2 = self.input_2.entry.process_Q_list()
                    else:
                        q1 = self.input_1.entry.Q_list
                        q2 = self.input_2.entry.Q_list

                    tile_number = self.input_1.entry.tile_number

                    # pairs with 10 or more mismatches share one bin
                    if number_of_mismatches >= 10:
                        ind = 10
                    else:
                        ind = number_of_mismatches

                    if ind not in self.mean_quals_per_mismatch:
                        self.mean_quals_per_mismatch[ind] = {}

                    tiles_dict = self.mean_quals_per_mismatch[ind]

                    if '1' not in tiles_dict:
                        tiles_dict['1'] = {}
                    if '2' not in tiles_dict:
                        tiles_dict['2'] = {}

                    for read_key, quals in (('1', q1), ('2', q2)):
                        if tile_number not in tiles_dict[read_key]:
                            tiles_dict[read_key][tile_number] = {'mean': [0] * len(quals), 'std': [0] * len(quals), 'count': [0] * len(quals)}

                        tile_stats = tiles_dict[read_key][tile_number]

                        # if there is a length variation, grow the lists first:
                        diff = len(quals) - len(tile_stats['mean'])
                        if diff > 0:
                            for attr in ['mean', 'std', 'count']:
                                tile_stats[attr].extend([0] * diff)

                        # 'mean' holds running sums until the final division
                        # done after all input files are processed
                        for i in range(0, len(quals)):
                            tile_stats['mean'][i] += quals[i]
                            tile_stats['count'][i] += 1

                    ################ / quality dicts associated stuff ####################




            print()
            self.input_1.close()
            self.input_2.close()


        self.close_output_files()
        self.stats.write_stats(os.path.join(self.config.output_directory, self.output_file_prefix + '_STATS'), self)
        self.stats.write_mismatches_breakdown_table(os.path.join(self.config.output_directory, self.output_file_prefix + '_MISMATCHES_BREAKDOWN'))


        if self.compute_qual_dicts:
            ################ quality dicts associated stuff ####################
            #finalizing mean_quals_per_mismatch dict:
            for ind in list(self.mean_quals_per_mismatch.keys()):
                for pair in ['1', '2']:
                    for tile in self.mean_quals_per_mismatch[ind][pair]:
                        tile_stats = self.mean_quals_per_mismatch[ind][pair][tile]
                        for i in range(0, len(tile_stats['mean'])):
                            tile_stats['mean'][i] = tile_stats['mean'][i] * 1.0 / tile_stats['count'][i]

            # visualizing quals
            inds = list(self.mean_quals_per_mismatch.keys())
            for position, ind in enumerate(inds, 1):
                sys.stderr.write('\rVisualizing Qual Scores: %d of %d' % (position, len(inds)))
                sys.stderr.flush()
                visualize_qual_stats_dict(self.mean_quals_per_mismatch[ind], os.path.join(self.config.output_directory, self.output_file_prefix + '_QUALS_%d_MISMATCH' % ind),\
                                            title = 'Machine Reported Mean PHRED Scores for Pairs Merged with %d Mismatches' % ind if ind != 1 else \
                                                    'Machine Reported Mean PHRED Scores for Pairs Merged with 1 Mismatch')
            print()
            ################ / quality dicts associated stuff ####################


    def init(self):
        self.project_name = self.config.project_name.strip().replace(' ', '_')

        if not self.output_file_prefix:
            self.output_file_prefix = self.project_name

        P = lambda p: f.FastaOutput(os.path.join(self.config.output_directory, self.output_file_prefix + p))

        self.output = P('_MERGED')
        self.failed = P('_FAILED')
        self.withNs = P('_FAILED_WITH_Ns')

        if self.enforce_Q30_check:
            self.Q30fails = P('_FAILED_Q30')
        else:
            self.Q30fails = None

        self.read1prefixes = P('_MERGED_R1_PREFIX') if self.report_r1_prefix else None
        self.read2prefixes = P('_MERGED_R2_PREFIX') if self.report_r2_prefix else None

        # Restrictions on --num-threads option.
        if self.num_threads != -1:
            if self.compute_qual_dicts:
                print("`--num-threads` cannot be used with `--compute-qual-dicts`.")
                sys.exit(0)


    def close_output_files(self):
        self.output.close()
        self.failed.close()
        self.withNs.close()
        if self.enforce_Q30_check:
            self.Q30fails.close()
        if self.report_r1_prefix:
            self.read1prefixes.close()
        if self.report_r2_prefix:
            self.read2prefixes.close()


    def get_read_id_with_better_base_qual(self, base_index_in_read_1, base_index_in_reversed_read_2):
        """when there is a disagreement regarding a base between pairs at the,
           overlapped region, such as this one:

           read 1: AAAAAAAAAAAAAAAAAAAAAAAAAAAA
           read 2:                AAAATAAAAAAAAAAAAAAAAAAAAAAAA
                                      ^

           there can be two outcome for the merged sequence: we would either use
           the base in read 1, in which case merged sequence would look like this:

                   AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
                                      ^

           or we would use the base in read 2 that would result with this one:

                   AAAAAAAAAAAAAAAAAAATAAAAAAAAAAAAAAAAAAAAAAAA
                                      ^


           this function takes both reads and returns the read id (1 or 2) that has
           a higher quality score assigned by the sequencer for the base that is
           in question. if the higher quality score is still lower than
           min_qual_score parameter, 0 is returned as read id.
           """

        obj_read_1 = self.input_1.entry
        obj_read_2 = self.input_2.entry

        if not obj_read_1.Q_list:
            obj_read_1.process_Q_list()
            obj_read_2.process_Q_list()

        base_qual_in_read_1 = obj_read_1.Q_list[base_index_in_read_1]
        base_qual_in_read_2 = obj_read_2.Q_list[::-1][base_index_in_reversed_read_2]

        # if neither of the bases in question satisfies min_qual_score expectation,
        # return 0. in this case dispute will resolve to an ambiguous base in the
        # merged sequence
        if max([base_qual_in_read_1, base_qual_in_read_2]) < self.min_qual_score:
            return 0

        if base_qual_in_read_1 >= base_qual_in_read_2:
            return 1
        else:
            return 2


    def merge_two_sequences_fast(self, complete_overlap = False):
        """Merge the current read pair by scanning overlaps with Hamming distance.

        By default the tail of read 1 is compared to the head of the reverse
        complement of read 2. When `complete_overlap` is set, the roles are
        swapped: the reverse complement of read 2 comes first and its tail is
        compared to the head of read 1.

        Every overlap length larger than `min_overlap_size` is tried, and the
        one with the smallest Hamming distance is kept (ties go to the longer
        overlap). Mismatches in the chosen overlap are resolved through
        `get_read_id_with_better_base_qual`: the winning base is written in
        upper case, or 'N' is used when neither read wins. Everything else in
        the result is lower case.

        Returns a tuple of ((beg, overlap, end), number of mismatches,
        recovery_dict), where recovery_dict counts the mismatches resolved in
        favor of 'read_1', 'read_2', or 'none' of them.
        """
        if complete_overlap:
            # the reverse complement of read 2 leads, read 1 follows.
            leading_seq = reverse_complement(self.input_2.entry.sequence)
            trailing_seq = self.input_1.entry.sequence
        else:
            leading_seq = self.input_1.entry.sequence
            trailing_seq = reverse_complement(self.input_2.entry.sequence)

        best_distance = sys.maxsize
        overlap_length = 0
        longest_possible_overlap = min(len(leading_seq), len(trailing_seq))
        for candidate_length in range(self.min_overlap_size + 1, longest_possible_overlap + 1):
            distance = l.hamming(leading_seq[-candidate_length:], trailing_seq[:candidate_length])
            # '<=' rather than '<': among equally good overlaps the longest wins.
            if distance <= best_distance:
                best_distance, overlap_length = distance, candidate_length

        overlap_starts_in_seq_1 = len(leading_seq) - overlap_length

        beg = leading_seq[:overlap_starts_in_seq_1].lower()
        overlap_1 = leading_seq[overlap_starts_in_seq_1:].lower()
        overlap_2 = trailing_seq[:overlap_length].lower()
        end = trailing_seq[overlap_length:].lower()

        recovery_dict = {'none': 0, 'read_1': 0, 'read_2': 0}
        merged_bases = []
        for offset, (base_1, base_2) in enumerate(zip(overlap_1, overlap_2)):
            if base_1 == base_2:
                merged_bases.append(base_1)
                continue

            # mismatch: figure out which read has the more reliable base call.
            if complete_overlap:
                # overlap_2 holds the bases of read 1 in this orientation.
                winner = self.get_read_id_with_better_base_qual(offset, overlap_starts_in_seq_1 + offset)
                base_from_read_1, base_from_read_2 = base_2, base_1
            else:
                winner = self.get_read_id_with_better_base_qual(overlap_starts_in_seq_1 + offset, offset)
                base_from_read_1, base_from_read_2 = base_1, base_2

            if winner == 0:
                merged_bases.append('N')
                recovery_dict['none'] += 1
            elif winner == 1:
                merged_bases.append(base_from_read_1.upper())
                recovery_dict['read_1'] += 1
            elif winner == 2:
                merged_bases.append(base_from_read_2.upper())
                recovery_dict['read_2'] += 1

        overlap = ''.join(merged_bases)

        if self.debug:
            print('--')
            print(self.input_1.entry.header_line)
            print(beg + colorize(overlap) + end)

        return ((beg, overlap, end), sum(recovery_dict.values()), recovery_dict)


    def merge_two_sequences(self):
        """Merge the current read pair with the external `merger` program.

        Read 1 and the reverse complement of read 2 are written to temporary
        FASTA files, `merger` is run on them with `-gapopen 25`, and the merged
        sequence is read back through `f.SequenceSource`. Temporary files are
        removed whether or not the merge succeeds.

        Returns a tuple of three items, the same shape that
        `merge_two_sequences_fast` returns:

            - (beg, overlap, end): the merged sequence split at the overlap
              boundaries, which are inferred from the lengths of the two input
              sequences and the merged sequence,
            - the value of `NumberOfConflicts` for the merged sequence,
            - a recovery dictionary with all of its counts set to 0.

        If `merger` exits with a non-zero status, an error message is printed
        and the program terminates with exit code -2.
        """
        seq1 = self.input_1.entry.sequence
        seq2 = reverse_complement(self.input_2.entry.sequence)

        temp_paths = []
        try:
            # text mode is requested explicitly: NamedTemporaryFile opens files
            # in binary mode by default, and writing str to it fails in Python 3.
            with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_seq_1:
                temp_seq_1_path = temp_seq_1.name
                temp_paths.append(temp_seq_1_path)
                temp_seq_1.write('>1\n' + seq1 + '\n')

            with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_seq_2:
                temp_seq_2_path = temp_seq_2.name
                temp_paths.append(temp_seq_2_path)
                temp_seq_2.write('>2\n' + seq2 + '\n')

            # only the path of the output file is needed; its handle is closed
            # right away so `merger` is the only writer.
            with tempfile.NamedTemporaryFile(delete=False) as temp_merge:
                temp_merge_path = temp_merge.name
                temp_paths.append(temp_merge_path)

            merger_process = ['merger', temp_seq_1_path, temp_seq_2_path, '-outseq', temp_merge_path, '-gapopen', '25', '-outfile', '/dev/null']
            if subprocess.call(merger_process, stderr=subprocess.DEVNULL) != 0:
                print('Something went wrong while merging these: \n\n%s\n--\n%s\n\n' % (seq1, seq2))
                sys.exit(-2)

            merged = f.SequenceSource(temp_merge_path)
            next(merged)
            merged.close()
        finally:
            for temp_path in temp_paths:
                try:
                    os.remove(temp_path)
                except OSError:
                    # best-effort cleanup: a leftover temporary file must not
                    # hide the result or the error that is being propagated.
                    pass

        # these are needed for the return value, so they are computed whether
        # debug output is requested or not.
        overlap_length = len(seq1) + len(seq2) - len(merged.seq)
        overlap_beg = len(seq1) - overlap_length
        overlap_end = overlap_beg + overlap_length

        if self.debug:
            print()
            print(''.join([merged.seq[:overlap_beg], colorize(merged.seq[overlap_beg:overlap_end]), merged.seq[overlap_end:]]))

        return ((merged.seq[:overlap_beg], merged.seq[overlap_beg:overlap_end], merged.seq[overlap_end:]), NumberOfConflicts(merged.seq), {'none': 0, 'read_1': 0, 'read_2': 0})

    def passes_minoche_Q30(self, base_qualities):
        # This algorithm is from Minoche et al. It calculates the length of the
        # read, returnes true only if two thirds of bases in the first half of
        # the read have Q-scores over Q30.

        half_length = len(base_qualities) / 2
        Q30 = len([True for _q in base_qualities[0:int(half_length)] if _q > 30])
        if Q30 < (0.66 * half_length):
            return (False, Q30)
        return (True, Q30)

if __name__ == '__main__':
    # Command line entry point: parse the arguments, load the run
    # configuration, copy every option onto a Merger instance, and run it.
    import argparse

    parser = argparse.ArgumentParser(description='Merge Overlapping Paired-End Illumina Reads')
    parser.add_argument('user_config', metavar = 'CONFIG_FILE',
                                        help = 'User configuration to run')
    parser.add_argument('-o', '--output_file_prefix', metavar = 'OUTPUT_FILE_PREFIX', default = None,
                                        help = 'Output file prefix (which will be used as a prefix\
                                                for files that appear in output directory)')
    parser.add_argument('--min-overlap-size', type = int, default = 15, metavar = 'INT',
                                        help = 'Overlap size must exceed this value. Default is %(default)d.')
    parser.add_argument('--max-num-mismatches', type = int, default = -1, metavar = 'INT',
                                        help = 'Maximum number of mismatches at the overlapped region to\
                                                retain the pair. The default behavior relies on `-P`\
                                                parameter and does not pay attention to the number of\
                                                mismatches at the overlapped region. \
                                                As of now, this parameter must be set to 0 \
                                                for forward and reverse reads of unequal length.')
    parser.add_argument('--min-qual-score', type = int, default = 15, metavar = 'INT',
                                        help = 'Minimum Q-score for a base to overwrite a mismatch\
                                                at the overlapped region. If there is a mismatch at\
                                                the overlapped region, the base with higher quality\
                                                is being used in the final sequence. Alternatively,\
                                                if the Q-score of the base with higher quality is\
                                                lower than the Q-score declared with this parameter,\
                                                that base is being marked as an ambiguous base, which\
                                                may result in the elimination of the merged sequence\
                                                depending on the --ignore-Ns parameter. The default\
                                                value is %(default)d.')
    parser.add_argument('-P', type = float, default = 0.3, metavar = 'FLOAT',
                                        help = 'Any merged sequence with P above the declared value\
                                                is discarded and stored in a separate file. P is computed\
                                                by dividing the number of mismatches at the overlapped region by\
                                                the length of the overlapped region. For instance, if the length\
                                                of the overlapped region is 30 nt long, and there are\
                                                3 mismatches in the overlapped region, P would be\
                                                3 / 30 = 0.1. The default value for P is %(default)f, however\
                                                it must be adjusted based on the expected overlap in a given\
                                                study with respect to desired stringency. Stringency can also be\
                                                adjusted using `--max-num-mismatches` parameter, or can be done\
                                                post-merging, using the program `filter-merged-reads`.')
    parser.add_argument('--enforce-Q30-check', action = 'store_true', default = False,
                                        help = 'By default, quality filtering is being done based only\
                                                on the mismatches found in the overlapped region, and\
                                                the beginning and the end of merged reads are not being\
                                                checked. However a final control can be enforced using\
                                                this flag. This flag turns on the Q30 check, as it was\
                                                explained by Minoche et al. in their 2012 paper. Briefly,\
                                                Q30-check eliminates pairs if the 66%% of bases in the\
                                                first half of each read do not have Q-scores over Q30.\
                                                Note that Q30 is applied only to the parts of reads that\
                                                did not overlap. If either of reads fail Q30 check,\
                                                merged sequence is discarded.')
    parser.add_argument('--compute-qual-dicts', action = 'store_true', default = False,
                                        help = 'When set, qual dicts will be computed. May take a\
                                                very long time for datasets with more than a\
                                                million pairs.')
    parser.add_argument('--retain-only-overlap', action = 'store_true', default = False,
                                        help = 'When set, merger will only return the parts of reads that do\
                                                overlap, and parts of reads that do not overlap will be\
                                                trimmed.')
    parser.add_argument('--debug', action = 'store_true', default = False,
                                        help = 'When set, debug messages will be printed')
    parser.add_argument('--ignore-deflines', action = 'store_true', default = False,
                                        help = 'If FASTQ files are not CASAVA outputs, parsing the header info\
                                                may go wrong. This flag tells the software to skip parsing\
                                                deflines.')
    parser.add_argument('--ignore-Ns', action = 'store_true', default = False,
                                        help = 'Merged sequences are being eliminated if they have\
                                                any ambiguous bases. If this parameter is set\
                                                merged pairs with Ns stay in the merged pairs bin.')
    parser.add_argument('--marker-gene-stringent', action = 'store_true', default = False,
                                        help = 'Finds the best merge going beyond expected "partial\
                                                overlap" case for amplicons. Please read the description\
                                                at url https://github.com/meren/illumina-utils/issues/1 for details.')
    parser.add_argument('--skip-suffix-trimming', action = 'store_true', default = False,
                                        help = 'Some datasets contain both partially and completely overlapping reads.\
                                                `--marker-gene-stringent` would be used in that case.\
                                                Completely overlapping reads can contain\
                                                unwanted adapter sequence at the ends of the reads\
                                                (read 1 adapter at the end of read 2\
                                                and read 2 adapter at the end of read 1).\
                                                By default, these adapter suffixes are trimmed from merged reads,\
                                                and non-overlapping parts of the insert\
                                                are left untouched in partially overlapping reads\
                                                (`--retain-only-overlap` trims partially overlapping merges).\
                                                Setting this flag leaves adapter suffixes untrimmed.')
    parser.add_argument('--report-r1-prefix', action = 'store_true', default = False,
                                        help = 'If there is a prefix to read 1 specified in the config file,\
                                                these sequences will be reported for merged reads\
                                                when this flag is set.\
                                                This can be useful if the prefix sequence has variable bases,\
                                                such as in a unique molecular identifier (UMI).')
    parser.add_argument('--report-r2-prefix', action = 'store_true', default = False,
                                        help = 'If there is a prefix to read 2 specified in the config file,\
                                                these sequences will be reported for merged reads\
                                                when this flag is set.\
                                                This can be useful if the prefix sequence has variable bases,\
                                                such as in a unique molecular identifier (UMI).')
    parser.add_argument('--num-threads', type = int, default = -1, metavar = 'INT',
                                        help = 'This triggers a faster algorithm for merging reads\
                                                with no mismatches in the overlapping region.\
                                                This can only be used in combination with `--max-num-mismatches 0`.\
                                                Specify the number of CPU cores that you wish to use.')
    parser.add_argument('--distance-metric', choices = ['hamming', 'levenshtein'], default='hamming')


    args = parser.parse_args()

    user_config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it managed to parse, so an empty list means that the
    # config file is missing or unreadable.
    if not user_config.read(args.user_config):
        print('Config file "%s" could not be read.' % args.user_config)
        sys.exit(-1)

    try:
        config = RunConfiguration(user_config)
    except RunConfigError as e:
        print(e)
        # a configuration error is a failure; report it through a non-zero
        # exit status so that calling scripts can detect it.
        sys.exit(-1)

    merger = Merger(config)
    merger.output_file_prefix = args.output_file_prefix
    merger.p_value = args.P
    merger.max_num_mismatches = args.max_num_mismatches
    merger.min_overlap_size = args.min_overlap_size
    merger.min_qual_score = args.min_qual_score
    merger.enforce_Q30_check = args.enforce_Q30_check
    merger.compute_qual_dicts = args.compute_qual_dicts
    merger.debug = args.debug
    merger.ignore_Ns = args.ignore_Ns
    merger.ignore_deflines = args.ignore_deflines
    merger.retain_only_overlap = args.retain_only_overlap
    merger.marker_gene_stringent = args.marker_gene_stringent
    merger.skip_suffix_trimming = args.skip_suffix_trimming
    merger.report_r1_prefix = args.report_r1_prefix
    merger.report_r2_prefix = args.report_r2_prefix
    merger.num_threads = args.num_threads
    merger.distance_metric = args.distance_metric

    # the return value of run() is used as the exit status of the program.
    sys.exit(merger.run())
