#!/usr/bin/env python

import os
import sys
from tqdm import tqdm as progressbar
import chess.commands as commands

import logging
# The empty name selects the root logger; its level and handler are
# configured in Chess.__init__ from the parsed command-line options.
logger = logging.getLogger('')


class Chess(object):

    def __init__(self):
        """Parse global options, configure logging and dispatch a subcommand.

        ``sys.argv`` is split into the leading global options plus the
        subcommand name, which are parsed here, and the remaining
        arguments, which are forwarded to the method named after the
        subcommand as ``[sys.argv[0], <command>, <command args>...]``.

        Raises:
            SystemExit: With status 1 if no subcommand was given or the
                given name is not a public, callable method of this class.
        """
        parser = commands.chess_parser()

        # Global flags that consume the following argv entry as their value
        # and therefore advance the scan by two positions.
        flag_increments = {
            '-l': 2, '--log-file': 2,
        }

        # Advance past all leading options so that option_ix points at the
        # first positional argument (the subcommand name).
        option_ix = 1
        while (option_ix < len(sys.argv) and
               sys.argv[option_ix].startswith('-')):
            option_ix += flag_increments.get(sys.argv[option_ix], 1)

        # parse_args defaults to [1:] for args, but you need to
        # exclude the rest of the args too, or validation will fail
        args = parser.parse_args(sys.argv[1:option_ix+1])

        # configure logger; a missing verbosity is treated like 0 so the
        # comparison below cannot fail on None
        verbosity = args.verbosity or 0
        if verbosity == 1:
            log_level = logging.WARN
        elif verbosity > 2:
            log_level = logging.DEBUG
        else:
            log_level = logging.INFO
        logger.setLevel(log_level)

        # Log to stderr unless a log file was requested.
        if args.log_file is None:
            log_handler = logging.StreamHandler()
        else:
            log_handler = logging.FileHandler(
                os.path.expanduser(args.log_file), mode='a')
        log_handler.setFormatter(logging.Formatter(
            "%(asctime)s %(levelname)s %(message)s"))
        log_handler.setLevel(log_level)
        logger.addHandler(log_handler)

        # Only public callables are valid subcommands; names starting with
        # an underscore (e.g. '__init__') would otherwise pass a plain
        # hasattr() check and fail later with an unrelated error.
        run_command = None
        if args.command is not None and not args.command.startswith('_'):
            run_command = getattr(self, args.command, None)
        if not callable(run_command):
            print('Unrecognized command')
            parser.print_help()
            sys.exit(1)

        # echo parameters back to user
        command = " ".join(sys.argv)
        logger.info("Running '{}'".format(command))

        # Imported here rather than at module level, only to report versions.
        import chess
        import fanc
        logger.info("CHESS version: {}".format(chess.__version__))
        logger.info("FAN-C version: {}".format(fanc.__version__))

        # use dispatch pattern to invoke method with same name
        run_command([sys.argv[0]] + sys.argv[option_ix:])

        # echo parameters back to user
        logger.info("Finished '{}'".format(command))

    def sim(self, argv):
        """Run the ``sim`` subcommand: score region pairs by similarity.

        Loads reference and query contact data, compares every region pair
        from the pairs file in a pool of worker processes and writes a
        tab-separated table with the columns ID, SN, ssim and z_ssim to the
        output file. If a background regions file or the background-query
        option is given, each reference region is additionally compared
        against the background regions and the columns z_bg and p_bg are
        appended.

        Args:
            argv (list): Program name, subcommand name, then the
                subcommand's own arguments; only ``argv[2:]`` is parsed.

        Raises:
            RuntimeError: If the directory of the output file does not exist.
            SystemExit: With status 1 if contact data cannot be loaded, no
                valid region pairs remain, or the largest reference/query
                region spans fewer than 20 bins.
        """
        parser = commands.sim_parser()

        args = parser.parse_args(argv[2:])

        from chess.helpers import load_pairs_for_chroms, load_regions, \
            GenomicRegion, chunks, load_contacts, load_oe_contacts
        from chess.sim import chunk_comparison_worker
        import multiprocessing as mp
        import numpy as np
        from collections import defaultdict

        reference_regions_file = args.reference_regions
        reference_matrix_file = args.reference_contacts
        query_regions_file = args.query_regions
        query_matrix_file = args.query_contacts
        pairs_file = args.pairs
        threads = args.threads
        is_oe = args.oe_input
        background_regions_file = args.background_regions
        background_query = args.background_query
        output_file = args.out
        limit_background = args.limit_background

        output_dir = os.path.dirname(output_file)
        if output_dir == '':
            output_file = os.path.join('.', output_file)
        elif not os.path.exists(output_dir):
            logger.error("Specified output directory does not exist!")
            raise RuntimeError("Specified output path not found.")

        logger.debug('Parameters:')
        logger.debug(args)

        logger.info('Loading reference contact data')
        loading_function = load_oe_contacts if is_oe else load_contacts
        try:
            reference_edges, reference_region_trees, reference_regions = loading_function(
                reference_matrix_file, reference_regions_file)
        except ValueError as error:
            logger.error(error)
            logger.error(
                ("Reference contact data could not be loaded. "
                 "Please specify a valid input file. "
                 "Files in sparse format can only be loaded if "
                 "--reference-regions is specified."))
            # Non-zero status so calling scripts can detect the failure.
            sys.exit(1)
        logger.info('Loading query contact data')
        try:
            query_edges, query_region_trees, query_regions = loading_function(
                query_matrix_file, query_regions_file)
        except ValueError as error:
            logger.error(error)
            logger.error(
                ("Query contact data could not be loaded. "
                 "Please specify a valid input file. "
                 "Files in sparse format can only be loaded if "
                 "--query-regions is specified."))
            sys.exit(1)

        # The bin size is taken to be the size of the largest region (bin)
        # in each contact data set.
        reference_bin_size = max(
            r.end - r.start + 1 for r in reference_regions)
        query_bin_size = max(
            r.end - r.start + 1 for r in query_regions)

        logger.info('Loading region pairs')
        pairs, dropped = load_pairs_for_chroms(
            pairs_file,
            set(reference_region_trees.keys()) | set(query_region_trees.keys()))
        if dropped > 0:
            logger.warning(("{} region pairs have been dropped, "
                            "because they involve chromosomes "
                            "that are not present in the provided "
                            "contact data.").format(dropped))
        if len(pairs) == 0:
            logger.error("No valid region pairs found; aborting.")
            sys.exit(1)

        # Each pair is (pair_id, reference_region, query_region).
        reference_max_region_size = max(
            [p[1].end - p[1].start + 1 for p in pairs])
        query_max_region_size = max(
            [p[2].end - p[2].start + 1 for p in pairs])
        max_reference_region_span = int(
            reference_max_region_size / reference_bin_size)
        max_query_region_span = int(
            query_max_region_size / query_bin_size)

        # Early exit if the regions are too small.
        if max_reference_region_span < 20:
            logger.error(
                "All regions need to span at least 20 bins. "
                "The provided reference regions span at most {} bins. "
                "Please try again with larger regions or a smaller bin size. "
                "The bin size of the input data has been detected to be {} bp."
                .format(max_reference_region_span, reference_bin_size)
                )
            sys.exit(1)
        elif max_query_region_span < 20:
            logger.error(
                "All regions need to span at least 20 bins. "
                "The provided query regions span at most {} bins. "
                "Please try again with larger regions or a smaller bin size. "
                "The bin size of the input data has been detected to be {} bp."
                .format(max_query_region_span, query_bin_size)
                )
            sys.exit(1)

        logger.info("Launching workers")
        m = mp.Manager()
        input_queue = m.Queue()
        output_queue = m.Queue()

        pool = None
        try:
            # chunk_comparison_worker is passed as the pool initializer, so
            # it runs once in every worker process with the arguments below.
            # NOTE(review): the literal 20 presumably is the minimum number
            # of bins per region - confirm against chess.sim.
            pool = mp.Pool(threads,
                           chunk_comparison_worker,
                           (input_queue,
                            output_queue,
                            reference_edges,
                            reference_region_trees,
                            query_edges,
                            query_region_trees,
                            20,
                            args.keep_unmappable_bins,
                            args.absolute_windowsize,
                            args.relative_windowsize,
                            args.mappability_cutoff))

            logger.info("Submitting pairs for comparison")
            submitted_counter = 0
            for chunk in chunks(pairs, threads):
                input_queue.put([chunk, True])
                submitted_counter += 1

            # One result message is expected per submitted chunk; workers
            # report failures by sending the exception object itself.
            ssim_results = dict()
            all_ssim = []
            for _ in range(submitted_counter):
                out = output_queue.get(block=True)
                if isinstance(out, Exception):
                    raise out
                for pair_ix, ssim, sn in out:
                    ssim_results[pair_ix] = [ssim, sn, np.nan]
                    if not np.isnan(ssim):
                        all_ssim.append(ssim)

            # calculate ssim z-score
            all_ssim_mean = np.nanmean(all_ssim)
            all_ssim_sd = np.nanstd(all_ssim)
            for result in ssim_results.values():
                if np.isfinite(result[0]):
                    result[2] = (result[0] - all_ssim_mean) / all_ssim_sd

            # no background calculations
            if background_regions_file is None and not background_query:
                nan_counter = 0
                with open(output_file, 'w') as o:
                    o.write("ID\tSN\tssim\tz_ssim\n")
                    for pair_id, _, _ in pairs:
                        if pair_id in ssim_results:
                            ssim, sn, z_ssim = ssim_results[pair_id]
                        else:
                            ssim, sn, z_ssim = np.nan, np.nan, np.nan

                        if np.isnan(ssim):
                            nan_counter += 1

                        o.write("{}\t{}\t{}\t{}\n".format(pair_id, sn, ssim, z_ssim))

                if nan_counter > 0:
                    logger.info(
                        "Could not compute similarity for {} region pairs. "
                        "This can be due to faulty coordinates, too small "
                        "region sizes or too many unmappable bins".format(
                            nan_counter))

            # background calculations
            else:
                logger.info("Running background calculations")
                if background_regions_file is not None:
                    background_regions_file = os.path.expanduser(background_regions_file)
                    background_regions, _, _ = load_regions(background_regions_file, ignore_ix=True)
                else:
                    background_regions = None
                    logger.info("Generating background regions from query")

                # The largest end coordinate per chromosome depends only on
                # the query regions, so it is computed once instead of once
                # per region pair.
                chromosome_ends = defaultdict(int)
                if background_query:
                    for r in query_regions:
                        chromosome_ends[r.chromosome] = max(
                            chromosome_ends[r.chromosome], r.end)

                # process each reference region separately
                ssim_zscores = dict()
                ssim_pvalues = dict()
                for pair_id, reference_region, query_region in progressbar(pairs):
                    if background_query:
                        background_regions = []
                        query_size = query_region.end - query_region.start
                        # generate background regions from query genome:
                        # one window of the query's size starting at every
                        # query bin, on both strands
                        for region in query_regions:
                            if limit_background and region.chromosome != query_region.chromosome:
                                continue

                            for strand in ('+', '-'):
                                candidate = GenomicRegion(
                                    chromosome=region.chromosome,
                                    start=region.start,
                                    end=region.start + query_size,
                                    strand=strand)
                                # skip windows running past the chromosome end
                                if candidate.end <= chromosome_ends[candidate.chromosome]:
                                    background_regions.append(candidate)

                    background_pairs = [
                        (bg_id, reference_region, background_region)
                        for bg_id, background_region in enumerate(background_regions)]

                    submitted_counter = 0
                    for chunk in chunks(background_pairs, threads):
                        input_queue.put([chunk, False])
                        submitted_counter += 1

                    background_ssim_results = []
                    for _ in range(submitted_counter):
                        out = output_queue.get(block=True)
                        if isinstance(out, Exception):
                            raise out
                        for pair_ix, ssim, sn in out:
                            background_ssim_results.append(ssim)

                    try:
                        original_ssim, _, _ = ssim_results[pair_id]
                    except KeyError:
                        ssim_zscores[pair_id] = np.nan
                        ssim_pvalues[pair_id] = np.nan
                        continue

                    if np.isnan(original_ssim):
                        ssim_zscores[pair_id] = np.nan
                        ssim_pvalues[pair_id] = np.nan
                    else:
                        # Convert explicitly so the element-wise comparison
                        # below does not depend on original_ssim being a
                        # NumPy scalar.
                        background_values = np.asarray(
                            background_ssim_results, dtype=float)
                        ssim_mean = np.nanmean(background_values)
                        ssim_sd = np.nanstd(background_values)

                        z = (original_ssim - ssim_mean) / ssim_sd
                        p = (1 + np.sum(original_ssim <= background_values)) / len(background_values)
                        ssim_zscores[pair_id] = z
                        ssim_pvalues[pair_id] = p

                nan_counter = 0
                with open(output_file, 'w') as o:
                    o.write("ID\tSN\tssim\tz_ssim\tz_bg\tp_bg\n")
                    for pair_id, _, _ in pairs:
                        if pair_id in ssim_results:
                            ssim, sn, z_ssim = ssim_results[pair_id]
                            z = ssim_zscores[pair_id]
                            p = ssim_pvalues[pair_id]
                        else:
                            ssim, sn, z_ssim, z, p = np.nan, np.nan, np.nan, np.nan, np.nan

                        if np.isnan(ssim):
                            nan_counter += 1

                        o.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(pair_id, sn, ssim, z_ssim, z, p))

                if nan_counter > 0:
                    logger.info("Could not compute similarity for {}/{} region pairs. "
                                "This typically happens if the region size is too small "
                                "(< 20 bins of the matrix)".format(nan_counter, len(pairs)))
        finally:
            # One None per worker - presumably the stop signal understood by
            # chunk_comparison_worker (confirm in chess.sim).
            for _ in range(threads):
                input_queue.put(None)

            if pool is not None:
                pool.terminate()

    def oe(self, argv):
        """Run the ``oe`` subcommand: write observed/expected edges to a file.

        The edges returned by ``observed_expected`` are written one per
        line as ``source<TAB>sink<TAB>weight`` with the weight in ``.6e``
        notation. The output goes through ``gzip`` when ``is_gzipped``
        reports true for the output path, and is written as plain text
        otherwise.

        Args:
            argv (list): Program name, subcommand name, then the
                subcommand's own arguments; only ``argv[2:]`` is parsed.
        """
        args = commands.oe_parser().parse_args(argv[2:])

        import gzip
        from chess.oe import observed_expected
        from chess.helpers import is_gzipped

        # Only the edge list is written; the returned regions are not used.
        _, edges = observed_expected(args.regions, args.input_matrix)
        row_template = "{}\t{}\t{:.6e}\n"
        target = args.output_matrix

        if not is_gzipped(target):
            with open(target, 'w') as handle:
                for source, sink, weight in edges:
                    handle.write(row_template.format(source, sink, weight))
            return

        # gzip.open in 'w' mode is a binary stream, hence the explicit encode.
        with gzip.open(target, 'w') as handle:
            for source, sink, weight in edges:
                row = row_template.format(source, sink, weight)
                handle.write(row.encode('utf-8'))

    def pairs(self, argv):
        """Run the ``pairs`` subcommand: write sliding-window region pairs.

        For every selected chromosome, windows of ``args.window`` are
        generated every ``args.step`` positions starting at 1, and each
        window is written as a pair with itself. Output lines are
        tab-separated: chromosome, start, end, chromosome, start, end,
        pair id, '.', '+', '+'.

        Args:
            argv (list): Program name, subcommand name, then the
                subcommand's own arguments; only ``argv[2:]`` is parsed.

        Raises:
            RuntimeError: If the directory of the output file does not exist.
            ValueError: If the step is not positive, or the requested
                chromosome has no entry in the chromosome sizes.
        """
        parser = commands.pairs_parser()
        args = parser.parse_args(argv[2:])
        output_file = args.output

        output_dir = os.path.dirname(output_file)
        if output_dir == '':
            output_file = os.path.join('.', output_file)
        elif not os.path.exists(output_dir):
            logger.error("Specified output directory does not exist!")
            raise RuntimeError("Specified output path not found.")

        # A step of zero or less would never advance the window below and
        # keep writing to the output file forever.
        if args.step is not None and args.step <= 0:
            raise ValueError("--step must be a positive integer.")

        import pybedtools as pbt

        def load_chromosome_sizes(path):
            """Load chromosome sizes from file.

            Args:
                path (str): Path to chromosome sizes file

            Returns:
                dict: The file in dictionary representation.
            """
            chromosome_sizes = {}
            with open(path, 'r') as f:
                for line in f:
                    if line == '\n':
                        continue
                    fields = line.strip().split()
                    chromosome, size = fields[:2]
                    chromosome_sizes[chromosome] = int(size)

            return chromosome_sizes

        if args.file_input:
            chromosome_sizes = load_chromosome_sizes(args.genome)
        else:
            try:
                chromosome_sizes = pbt.chromsizes(args.genome)
                if len(chromosome_sizes) == 0:
                    raise ValueError(
                        "Genome not recognised in pybedtools.chromsizes: {}".format(
                            args.genome))
                # pbt has (start, end) as values,
                # where start is usually 0
                chromosome_sizes = {
                    k: v[1] for k, v in chromosome_sizes.items()}
            except (OSError, ValueError):
                logger.info((
                    'No entry found with pybedtools.'
                    ' Trying to read from file.'))
                chromosome_sizes = load_chromosome_sizes(args.genome)

        chromosomes = chromosome_sizes.keys()
        if args.chromosome:
            if args.chromosome not in chromosome_sizes:
                raise ValueError((
                    'Specified chromosome has no chromosome size entry.'))
            chromosomes = [args.chromosome]

        with open(output_file, 'w') as f:
            # pair ids run consecutively across all chromosomes
            pair_id = 0
            for chromosome in chromosomes:
                chromosome_size = chromosome_sizes[chromosome]
                start = 1
                while start <= (chromosome_size - args.window):
                    end = start + args.window
                    f.write(
                        '\t'.join(
                            str(e) for e in [chromosome, start, end,
                                             chromosome, start, end,
                                             pair_id, '.', '+', '+'])
                        + '\n')
                    start += args.step
                    pair_id += 1

    def background(self, argv):
        """Run the ``background`` subcommand: write sliding-window regions.

        ``args.genome_or_region`` is tried, in order, as a genome name
        known to pybedtools, as a path to a chromosome sizes file, and as
        a region string. Windows of ``args.window`` are then generated
        every ``args.step`` positions over each input region and written
        as tab-separated lines: chromosome, start, end, region index, '.',
        strand. Without ``--strand`` every window is written once per
        strand ('+' first). The region index restarts at 0 for every input
        region.

        Args:
            argv (list): Program name, subcommand name, then the
                subcommand's own arguments; only ``argv[2:]`` is parsed.

        Raises:
            RuntimeError: If the directory of the output file does not exist.
            ValueError: If the strand is neither '-', '+' nor unset.
        """
        parser = commands.background_parser()
        args = parser.parse_args(argv[2:])

        import pybedtools as pbt
        from chess.helpers import read_chromosome_sizes, GenomicRegion

        genome_or_region = args.genome_or_region
        window = args.window
        step = args.step
        output_file = args.output
        strand = args.strand

        output_dir = os.path.dirname(output_file)
        if output_dir == '':
            output_file = os.path.join('.', output_file)
        elif not os.path.exists(output_dir):
            logger.error("Specified output directory does not exist!")
            raise RuntimeError("Specified output path not found.")

        if strand not in ['-', '+', None]:
            raise ValueError("--strand must be either - or +")

        try:
            chromosome_sizes = pbt.chromsizes(genome_or_region)
            if len(chromosome_sizes) == 0:
                raise ValueError("Genome not recognised in pybedtools.chromsizes: {}".format(genome_or_region))
            # pbt has (start, end) as values,
            # where start is usually 0
            chromosome_sizes = {
                k: v[1] for k, v in chromosome_sizes.items()}

            input_regions = [(chromosome, 1, end) for chromosome, end in chromosome_sizes.items()]
        except (OSError, ValueError):
            # OSError is handled as well, matching the pairs subcommand, so
            # a failed pybedtools lookup still falls back to file/region
            # input instead of aborting.
            logger.info((
                'No entry found with pybedtools.'
                ' Trying to read from file.'))
            if os.path.exists(genome_or_region):
                chromosome_sizes = read_chromosome_sizes(genome_or_region)
                input_regions = [(chromosome, 1, end) for chromosome, end in chromosome_sizes.items()]
            else:
                region = GenomicRegion.from_string(genome_or_region)
                input_regions = [(region.chromosome, region.start, region.end)]

        # Strands to emit per window, '+' before '-' when both are wanted.
        strands = ('+', '-') if strand is None else (strand,)

        with open(output_file, 'w') as o:
            for chromosome, start, end in input_regions:
                region_ix = 0
                for s in range(start, end - window, step):
                    for window_strand in strands:
                        o.write("{}\t{}\t{}\t{}\t.\t{}\n".format(
                            chromosome, s, s + window, region_ix, window_strand
                        ))
                        region_ix += 1

    def extract(self, argv):
        """Run the ``extract`` subcommand: extract structures from pairs.

        Loads reference and query contact data with ``load_oe_contacts``,
        reads the region pairs and hands everything to
        ``extract_structures`` together with the filtering parameters from
        the command line. Results are written by ``extract_structures``
        into the output directory.

        Args:
            argv (list): Program name, subcommand name, then the
                subcommand's own arguments; only ``argv[2:]`` is parsed.

        Raises:
            RuntimeError: If the output directory does not exist.
            SystemExit: With status 1 if the reference or query contact
                data cannot be loaded.
        """
        parser = commands.extract_parser()
        args = parser.parse_args(argv[2:])
        from chess.helpers import load_pairs_iter, load_oe_contacts

        from chess.get_structures import extract_structures

        output_dir = args.out
        if not os.path.exists(output_dir):
            logger.error("Specified output directory does not exist!")
            raise RuntimeError("Specified output path not found.")

        reference_regions_file = args.reference_regions
        reference_matrix_file = args.reference_contacts
        query_matrix_file = args.query_contacts
        query_regions_file = args.query_regions
        pairs_file = args.pairs

        logger.debug('Parameters:')
        logger.debug(args)

        # A failed load must stop the command: continuing would only crash
        # further down on the variables that were never assigned.
        logger.info('Loading reference contact data')
        try:
            reference_edges, reference_region_trees, _ = load_oe_contacts(
                reference_matrix_file, reference_regions_file)
        except ValueError as error:
            logger.error(error)
            logger.error('Failed to load the reference data.')
            sys.exit(1)

        logger.info('Loading query contact data')
        try:
            query_edges, query_region_trees, _ = load_oe_contacts(
                query_matrix_file, query_regions_file)
        except ValueError as error:
            logger.error(error)
            logger.error('Failed to load the query data.')
            sys.exit(1)

        logger.info('Loading region pairs')

        pairs = list(load_pairs_iter(pairs_file))
        logger.info(
            'Applying image filtering to identify specific structures')
        extract_structures(
            reference_edges,
            reference_region_trees,
            query_edges,
            query_region_trees,
            pairs, output_dir,
            args.windowsize,
            args.sigma_spatial,
            args.size_medianfilter,
            args.closing_square)
        logger.info('Results collected')

    def crosscorrelate(self, argv):
        """Run the ``crosscorrelate`` subcommand.

        Reads all region pairs from the pairs file into a list and passes
        them, together with the extracted-features file and the output
        directory, to ``correlate2d``.

        Args:
            argv (list): Program name, subcommand name, then the
                subcommand's own arguments; only ``argv[2:]`` is parsed.
        """
        args = commands.crosscorrelate_parser().parse_args(argv[2:])

        from chess.helpers import load_pairs_iter
        from chess.cross_correlation import correlate2d

        extracted_path, pairs_path, target_dir = (
            args.extracted_file, args.pairs, args.outdir)

        # Materialise the iterator so correlate2d receives a list.
        region_pairs = list(load_pairs_iter(pairs_path))
        correlate2d(extracted_path, target_dir, region_pairs)


if __name__ == "__main__":
    # Constructing Chess parses sys.argv and runs the requested subcommand.
    Chess()
