#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2015, A. Murat Eren
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation; either version 2 of the License, or (at your option)
# any later version.
#
# Please read the COPYING file.

import re
import os
import sys

import IlluminaUtils.lib.fastqlib as u
from IlluminaUtils.utils.helperfunctions import big_number_pretty_print


def get_file_read_objects(r1_path, r2_path):
    """Open the R1 and R2 FASTQ files and return their reader objects.

    Args:
        r1_path: path of the R1 FASTQ file.
        r2_path: path of the R2 FASTQ file.

    Returns:
        A ``(r1, r2)`` tuple of ``FastQSource`` objects, each with its
        ``percent_step`` attribute set to 1000.

    If ``FastQSource`` raises ``FastQLibError`` for either file, the error is
    printed and the process exits with a non-zero status.
    """
    try:
        r1 = u.FastQSource(r1_path, lazy_init=True)
        r2 = u.FastQSource(r2_path, lazy_init=True)
    except u.FastQLibError as e:
        print("FastQLib is not happy.\n\n\t", e, "\n")
        # A bare sys.exit() would report success (status 0) to the calling
        # shell even though nothing could be read.
        sys.exit(1)

    r1.percent_step = 1000
    r2.percent_step = 1000

    return r1, r2


def compile_identifier(identifier_code, r1_object, tested):
    """Turn the user-supplied identifier code into a callable.

    Args:
        identifier_code: Python source that evaluates to a one-argument
            callable (typically a lambda) mapping a header line to an
            identifier.
        r1_object: R1 reader object; only used to demonstrate the identifier
            when ``tested`` is false.
        tested: when false, ``show_test_results`` is called with the compiled
            identifier (that function prints a demonstration and exits).

    Returns:
        The callable obtained by evaluating ``identifier_code``.

    The process exits with a non-zero status when the code cannot be
    evaluated, or when it evaluates to something that is not callable.
    """
    # NOTE(security): eval() runs arbitrary Python. This is acceptable only
    # because the code comes from the person invoking the script; never feed
    # it text from an untrusted source.
    try:
        identifier_function = eval(identifier_code)
    except Exception as e:
        # Not only SyntaxError: an expression such as "header.split()[0]"
        # (no lambda) parses fine but fails with NameError when evaluated.
        print("Your identifier code did not compile :(")
        print("\n\t%s: %s\n" % (type(e).__name__, e))
        sys.exit(1)

    if not callable(identifier_function):
        # Catch this here rather than deep inside the matching loop.
        print("Your identifier code does not evaluate to a function :(")
        sys.exit(1)

    if not tested:
        show_test_results(r1_object, identifier_code, identifier_function)

    return identifier_function


def show_test_results(r1_object, identifier_code, identifier_function):
    """Demonstrate the identifier on one R1 entry, then exit.

    Advances ``r1_object`` by one entry, prints that entry's header line,
    prints what ``identifier_function`` returns for it (labelled with the
    ``identifier_code`` text), and finally terminates the process via
    ``sys.exit()``.
    """
    r1_object.next(raw=True)

    intro = "Testing your stuff.\n\n"
    print(intro)

    header_title = "This is what your header looks like:\n\n"
    print(header_title)
    header = r1_object.entry.header_line
    print("\t%s\n\n" % header)

    parsed_title = "This is how your identifier ('''%s''') is parses it:\n\n" % identifier_code
    print(parsed_title)
    parsed = identifier_function(header)
    print("\t%s\n\n" % parsed)

    sys.exit()


def match_sequential(r1_path, r2_path, identifier_code="lambda header: header.split()[0]", tested=False):
    """Write the R1/R2 entries whose integer identifiers pair up.

    This is useful if the identifier code resolves a sequential number per
    pair. Both files are walked in a single pass: when the two current
    identifiers are equal both entries are stored, otherwise the reader
    holding the smaller identifier is advanced.

    Args:
        r1_path: path of the R1 FASTQ file.
        r2_path: path of the R2 FASTQ file.
        identifier_code: Python source for a one-argument callable that maps
            a header line to something ``int()`` accepts.
        tested: forwarded to ``compile_identifier``.

    Output goes to ``<r1_path>_MATCHED`` and ``<r2_path>_MATCHED``.
    """
    r1, r2 = get_file_read_objects(r1_path, r2_path)
    identifier_function = compile_identifier(identifier_code, r1, tested)

    r1_matched_path = r1_path + '_MATCHED'
    r2_matched_path = r2_path + '_MATCHED'
    r1_matched_output = u.FastQOutput(r1_matched_path)
    r2_matched_output = u.FastQOutput(r2_matched_path)

    try:
        r1.next(raw=True)
        r2.next(raw=True)
        r1_id = r2_id = None
        r1_dirty = r2_dirty = True

        # Checking validity up front also covers an empty input file, where
        # there is no header to parse at all.
        while r1.entry.is_valid and r2.entry.is_valid:
            # Re-parse an identifier only for a reader that has moved.
            if r1_dirty:
                r1_id = int(identifier_function(r1.entry.header_line))
                r1_dirty = False
            if r2_dirty:
                r2_id = int(identifier_function(r2.entry.header_line))
                r2_dirty = False

            if r1.p_available:
                r1.print_percentage()

            if r1_id == r2_id:
                r1_matched_output.store_entry(r1.entry)
                r2_matched_output.store_entry(r2.entry)
                r1.next(raw=True)
                r2.next(raw=True)
                # Both readers moved, so both cached ids are stale. Without
                # this the ids would stay equal forever and every following
                # entry would be stored as a "match" without being compared.
                r1_dirty = r2_dirty = True
            elif r1_id < r2_id:
                r1.next(raw=True)
                r1_dirty = True
            else:
                r2.next(raw=True)
                r2_dirty = True

        sys.stderr.write('\r')
    finally:
        # Close the outputs even if parsing an identifier raised.
        r1_matched_output.close()
        r2_matched_output.close()

    print('')
    print('R1 file that matches to R2: %s' % r1_matched_path)
    print('R2 file that matches to R1: %s' % r2_matched_path)


def _collect_identifiers(source, identifier_function):
    """Read every entry of `source` and return the set of its identifiers."""
    identifiers = set()
    while source.next(raw=True):
        if source.p_available:
            source.print_percentage()

        identifiers.add(identifier_function(source.entry.header_line))
    sys.stderr.write('\r')

    return identifiers


def _store_entries_with_ids(source, output, identifier_function, wanted_ids):
    """Copy the entries of `source` whose identifier is in `wanted_ids` to `output`."""
    while source.next(raw=True):
        if source.p_available:
            source.print_percentage()

        if identifier_function(source.entry.header_line) in wanted_ids:
            output.store_entry(source.entry)
    sys.stderr.write('\r')


def match_non_sequential(r1_path, r2_path, identifier_code="lambda header: header.split()[0]", tested=False):
    """Write the R1/R2 entries whose identifiers occur in both files.

    Each file is read twice: a first pass collects the identifiers, a second
    pass (after ``reset()``) stores the entries whose identifier is present
    in both files. Entries keep the order of their own input file.

    Args:
        r1_path: path of the R1 FASTQ file.
        r2_path: path of the R2 FASTQ file.
        identifier_code: Python source for a one-argument callable that maps
            a header line to a hashable identifier.
        tested: forwarded to ``compile_identifier``.

    Output goes to ``<r1_path>_MATCHED`` and ``<r2_path>_MATCHED``.
    """
    r1, r2 = get_file_read_objects(r1_path, r2_path)
    identifier_function = compile_identifier(identifier_code, r1, tested)

    r1_matched_path = r1_path + '_MATCHED'
    r2_matched_path = r2_path + '_MATCHED'
    r1_matched_output = u.FastQOutput(r1_matched_path)
    r2_matched_output = u.FastQOutput(r2_matched_path)

    try:
        r1_ids = _collect_identifiers(r1, identifier_function)
        print('%d ids found for R1                                ' % len(r1_ids))

        r2_ids = _collect_identifiers(r2, identifier_function)
        print('%d ids found for R2                                ' % len(r2_ids))

        # Set membership is O(1); testing header lines against a list made
        # the second pass quadratic. Filtering by identifier (not by one
        # remembered header line per id) also drops every entry of an
        # unmatched id, even when that id occurs more than once.
        shared_ids = r1_ids & r2_ids

        print('')
        print('%d ids in R1 are not matching to R2' % (len(r1_ids) - len(shared_ids)))
        print('%d ids in R2 are not matching to R1' % (len(r2_ids) - len(shared_ids)))

        r1.reset()
        r2.reset()

        _store_entries_with_ids(r1, r1_matched_output, identifier_function, shared_ids)
        _store_entries_with_ids(r2, r2_matched_output, identifier_function, shared_ids)
    finally:
        # Close the outputs even if one of the passes raised.
        r1_matched_output.close()
        r2_matched_output.close()

    print('')
    print('R1 file that matches to R2: %s' % r1_matched_path)
    print('R2 file that matches to R1: %s' % r2_matched_path)


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the chosen matcher.
    import argparse

    parser = argparse.ArgumentParser(description='Recover matching ids in two FASTQ files')
    parser.add_argument('--r1', metavar = 'FILE_PATH', required=True, help = "R1")
    parser.add_argument('--r2', metavar = 'FILE_PATH', required=True, help = "R2")
    parser.add_argument('--identifier-code', metavar = 'PYTHON CODE', default='lambda defline: defline.split()[0]', help = "Lambda function to parse the header. Default: '''%(default)s'''.")
    parser.add_argument('--identifier-tested', default = False, action="store_true", help = "Use this flag to indicate that you tested your identifier.")
    parser.add_argument('--sequential', default = False, action="store_true", help = "Your identifier code parses an integer value that can link pairs, and is incremental throughout the file.")

    # --r1 and --r2 are both required, so argparse already rejects a call
    # that omits either file; no extra check is needed here.
    args = parser.parse_args()

    matcher = match_sequential if args.sequential else match_non_sequential

    # The matchers return None, so a completed run exits with status 0.
    sys.exit(matcher(args.r1, args.r2, args.identifier_code, args.identifier_tested))
