#!/usr/bin/env python

import os
import sys
import argparse
import textwrap
from tqdm import tqdm as progressbar


import logging
logger = logging.getLogger('')


class Chess(object):

    def __init__(self):
        """Parse the global command line options and dispatch the sub-command.

        Works on ``sys.argv`` directly.  The leading run of arguments that
        start with ``-`` plus the first argument that does not (the command
        name) is handed to the global parser.  The method with the same name
        as the command is then called with
        ``[program, command, *remaining_args]``.

        Raises:
            SystemExit: after printing the version (``--version`` style flag),
                or with status 1 if the command is missing or unknown.
        """
        parser = chess_parser()

        # Global flags that take a value occupy two argv slots.
        flag_increments = {
            '-l': 2, '--log-file': 2,
        }

        # Advance to the first argument that is not a global option.
        option_ix = 1
        while (option_ix < len(sys.argv) and
               sys.argv[option_ix].startswith('-')):
            option_ix += flag_increments.get(sys.argv[option_ix], 1)

        # parse_args defaults to [1:] for args, but you need to
        # exclude the rest of the args too, or validation will fail
        args = parser.parse_args(sys.argv[1:option_ix+1])

        # configure logger
        # NOTE(review): the parser is defined elsewhere; treat a missing
        # verbosity (None) like 0 so the comparison below cannot fail.
        verbosity = args.verbosity or 0
        if verbosity == 1:
            log_level = logging.WARN
        elif verbosity > 2:
            log_level = logging.DEBUG
        else:
            # verbosity 2 and every other value map to INFO
            log_level = logging.INFO
        logger.setLevel(log_level)

        # Log to stderr unless a log file was requested; both handlers share
        # the same format and level.
        if args.log_file is None:
            handler = logging.StreamHandler()
        else:
            handler = logging.FileHandler(
                os.path.expanduser(args.log_file), mode='a')
        handler.setFormatter(logging.Formatter(
            "%(asctime)s %(levelname)s %(message)s"))
        handler.setLevel(log_level)
        logger.addHandler(handler)

        # get version info
        if args.print_version:
            import chess
            print(chess.__version__)
            # sys.exit instead of the interactive helper exit(), which is
            # only installed by the site module.
            sys.exit()

        # Only public attributes can be commands; this keeps names such as
        # '__init__' from being dispatched.
        command_name = args.command
        if (command_name is None
                or command_name.startswith('_')
                or not hasattr(self, command_name)):
            print('Unrecognized command')
            parser.print_help()
            sys.exit(1)

        # echo parameters back to user
        command_line = " ".join(sys.argv)
        logger.info("Running '{}'".format(command_line))

        # use dispatch pattern to invoke method with same name
        getattr(self, command_name)([sys.argv[0]] + sys.argv[option_ix:])

        # echo parameters back to user
        logger.info("Finished '{}'".format(command_line))

    def sim(self, argv):
        """Run the ``sim`` sub-command: score region pairs by similarity.

        Loads the reference and query contact data, hands the region pairs
        to a pool of worker processes through a pair of managed queues and
        writes one line per pair to the output file (columns ``ID``, ``SN``,
        ``ssim``, ``z_ssim``).  If ``--background-regions`` or
        ``--background-query`` is given, every reference region is also
        compared against the background regions and the columns ``z_bg`` and
        ``p_bg`` are appended.

        Args:
            argv (list): ``[program, command, *command_args]``; only
                ``argv[2:]`` is parsed.

        Raises:
            SystemExit: with status 1 if the reference or query contact
                data cannot be loaded (the loader raised ``ValueError``).
            Exception: an exception object received from a worker through
                the output queue is re-raised here.
        """
        parser = MyParser(
            description='''
            Compare structures between pairs of chromatin contact
            matrices using the structural similarity index.
            If --background-query or --background-regions
            are specified, compute p-value and z-value
            after obtaining a background distribution
            of similarity values of the reference in the pair
            to specfied background regions.

            The input matrices are expected to be balanced,
            e.g. by Knight-Ruiz matrix balancing.
            The contacts are automatically converted internally
            to observed / expected values. In your input files
            are already observed / expected transformed,
            set the --oe-input flag.
            ''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'reference_contacts',
            type=str,
            help='''
            Balanced contact matrix for the reference sample
            in one of the following formats:
                fanc .hic,
                juicer .hic@<resolution>,
                cooler .cool@<resolution> or .mcool@<resolution>,
                sparse format (each line:
                    <row region index> <column region index> <matrix value>).

            If the file is in sparse format, the corresponding regions
            BED file needs to be passed via --reference-regions.
            ''')

        parser.add_argument(
            '--reference-regions', dest="reference_regions",
            type=str,
            help='''BED file (no header) with regions corresponding to
            the number of rows in the provided reference matrix.''')

        parser.add_argument(
            'query_contacts',
            type=str,
            help='''
            Balanced contact matrix for the query sample
            in one of the following formats:
                fanc .hic,
                juicer .hic@<resolution>,
                cooler .cool@<resolution> or .mcool@<resolution>,
                sparse format (each line:
                    <row region index> <column region index> <matrix value>).

            If the file is in sparse format, the corresponding regions
            BED file needs to be passed via --query-regions.
            ''')

        parser.add_argument(
            '--query-regions', dest="query_regions",
            type=str,
            help='''BED file (no header) with regions corresponding to
            the number of rows in the provided query matrix.''')

        parser.add_argument(
            'pairs',
            type=str,
            help='''Region pairs to compare.
            Expected to be in bedpe format with chrom1, start1, ...
            corresponding to reference and chrom2, start2, ... to query.''')

        parser.add_argument(
            'out',
            type=str,
            help='''Path to outfile.''')

        parser.add_argument(
            '--background-regions', dest='background_regions',
            type=str,
            help='''BED file with regions to be used for background calculations.
                    If provided, CHESS will generate Z-scores and P-values for
                    similarities.''')

        parser.add_argument(
            '--background-query', dest='background_query',
            default=False,
            action='store_true',
            help='Use every region of the same size as the '
                 'reference from the query genome as background. '
                 'Useful, for example, as background for inter-species '
                 'comparisons.')

        parser.add_argument(
            '-p', dest='threads',
            type=int,
            default=1,
            help='''Number of cores to use.''')

        parser.add_argument(
            '--keep-unmappable-bins',
            action='store_true',
            default=False,
            help='''Disable deletion of unmappable bins.''')

        parser.add_argument(
            '--mappability-cutoff',
            type=float,
            default=0.1,
            help='''Low pass threshold for fraction of unmappable bins.
                    Matrices with a higher content of unmappable bins
                    will not be considered.
                    Unmappable bins will be deleted from matrices
                    passing the filter.''')

        parser.add_argument(
            '-r', '--relative-windowsize',
            type=float,
            default=1,
            help='''Relative window size value
            for the win_size param in the ssim function.
            Fraction of the matrix size.''')

        parser.add_argument(
            '-a', '--absolute-windowsize',
            type=int,
            default=None,
            help='''Absolute window size value in bins
            for the win_size param in the ssim function.
            Overwrites -r.''')

        parser.add_argument(
            '--oe-input',
            action='store_true',
            default=False,
            help='''Use if input contacts are already observed / expected
            transformed.''')

        parser.add_argument(
            '--limit-background',
            action='store_true',
            default=False,
            help='Restrict background computation to the syntenic / paired '
                 'chromosome as indicated in the pairs file.')

        args = parser.parse_args(argv[2:])

        from chess.helpers import (
            load_pairs_for_chroms, load_regions,
            GenomicRegion, chunks, load_contacts, load_oe_contacts)
        from chess.sim import chunk_comparison_worker
        import multiprocessing as mp
        import numpy as np
        from collections import defaultdict

        reference_regions_file = args.reference_regions
        reference_matrix_file = args.reference_contacts
        query_regions_file = args.query_regions
        query_matrix_file = args.query_contacts
        pairs_file = args.pairs
        threads = args.threads
        is_oe = args.oe_input
        background_regions_file = args.background_regions
        background_query = args.background_query
        output_file = args.out
        limit_background = args.limit_background

        logger.debug('[MAIN]: Parameters:')
        logger.debug(args)

        logger.info('[MAIN]: Loading reference contact data')
        loading_function = load_oe_contacts if is_oe else load_contacts
        try:
            reference_edges, reference_region_trees, reference_regions = loading_function(
                reference_matrix_file, reference_regions_file)
        except ValueError as error:
            logger.error(error)
            logger.error(
                ("Reference contact data could not be loaded. "
                 "Please specify a valid input file. "
                 "Files in sparse format can only be loaded if "
                 "--reference-regions is specified."))
            # This is a failure: report it through a non-zero exit status.
            sys.exit(1)
        try:
            query_edges, query_region_trees, query_regions = loading_function(
                query_matrix_file, query_regions_file)
        except ValueError as error:
            logger.error(error)
            logger.error(
                ("Query contact data could not be loaded. "
                 "Please specify a valid input file. "
                 "Files in sparse format can only be loaded if "
                 "--query-regions is specified."))
            sys.exit(1)

        logger.info('[MAIN]: Loading region pairs')
        # Pairs are checked against the union of the chromosomes found in
        # the reference and the query contact data.
        pairs, dropped = load_pairs_for_chroms(
            pairs_file,
            set().union(
                set(reference_region_trees.keys()),
                set(query_region_trees.keys())))
        if dropped > 0:
            logger.warning(("{} region pairs have been dropped, "
                            "because they involve chromosomes "
                            "that are not present in the provided "
                            "contact data.").format(dropped))

        logger.info("Launching workers")
        m = mp.Manager()
        input_queue = m.Queue()
        output_queue = m.Queue()

        pool = None
        try:
            # Every pool process runs chunk_comparison_worker as its
            # initializer; work is exchanged through the two queues.
            # NOTE(review): the literal 20 is presumably the minimum region
            # size in bins (cf. the "< 20 bins" log message below) - confirm
            # against chess.sim.chunk_comparison_worker.
            pool = mp.Pool(threads,
                           chunk_comparison_worker,
                           (input_queue,
                            output_queue,
                            reference_edges,
                            reference_region_trees,
                            query_edges,
                            query_region_trees,
                            20,
                            args.keep_unmappable_bins,
                            args.absolute_windowsize,
                            args.relative_windowsize,
                            args.mappability_cutoff))

            logger.info("Submitting pairs for comparison")
            submitted_counter = 0
            for chunk in chunks(pairs, threads):
                input_queue.put([chunk, True])
                submitted_counter += 1

            # pair id -> [ssim, SN, z_ssim]; z_ssim is filled in below
            ssim_results = dict()
            all_ssim = []
            for _ in range(submitted_counter):
                out = output_queue.get(block=True)
                if isinstance(out, Exception):
                    raise out
                for pair_ix, ssim, sn in out:
                    ssim_results[pair_ix] = [ssim, sn, np.nan]
                    if not np.isnan(ssim):
                        all_ssim.append(ssim)

            # calculate ssim z-score
            all_ssim_mean = np.nanmean(all_ssim)
            all_ssim_sd = np.nanstd(all_ssim)
            for result in ssim_results.values():
                if np.isfinite(result[0]):
                    result[2] = (result[0] - all_ssim_mean) / all_ssim_sd

            # no background calculations
            if background_regions_file is None and not background_query:
                nan_counter = 0
                with open(output_file, 'w') as o:
                    o.write("ID\tSN\tssim\tz_ssim\n")
                    for pair_id, _, _ in pairs:
                        if pair_id in ssim_results:
                            ssim, sn, z_ssim = ssim_results[pair_id]
                        else:
                            ssim, sn, z_ssim = np.nan, np.nan, np.nan

                        if np.isnan(ssim):
                            nan_counter += 1

                        o.write("{}\t{}\t{}\t{}\n".format(pair_id, sn, ssim, z_ssim))

                if nan_counter > 0:
                    logger.info("Could not compute similarity for {} region pairs. "
                                "This typically happens if the region size is too small "
                                "(< 20 bins of the matrix)".format(nan_counter))

            # background calculations
            else:
                logger.info("[MAIN]: Running background calculations")
                if background_regions_file is not None:
                    background_regions_file = os.path.expanduser(background_regions_file)
                    background_regions, _, _ = load_regions(background_regions_file, ignore_ix=True)
                else:
                    background_regions = None
                    logger.info("Generating background regions from query")

                # The chromosome ends depend only on the query regions, so
                # they are computed once rather than once per pair.
                chromosome_ends = defaultdict(int)
                if background_query:
                    for r in query_regions:
                        chromosome_ends[r.chromosome] = max(chromosome_ends[r.chromosome], r.end)

                # process each reference region separately
                ssim_zscores = dict()
                ssim_pvalues = dict()
                for pair_id, reference_region, query_region in progressbar(pairs):
                    if background_query:
                        # generate background regions from query genome:
                        # one window of the query's size starting at every
                        # query region, on both strands
                        background_regions = []
                        query_size = query_region.end - query_region.start
                        for region in query_regions:
                            if limit_background and region.chromosome != query_region.chromosome:
                                continue

                            # forward first, then reverse
                            for strand in ('+', '-'):
                                candidate = GenomicRegion(
                                    chromosome=region.chromosome,
                                    start=region.start,
                                    end=region.start + query_size,
                                    strand=strand)
                                # skip windows running past the chromosome end
                                if candidate.end <= chromosome_ends[candidate.chromosome]:
                                    background_regions.append(candidate)

                    background_pairs = [
                        (bg_id, reference_region, background_region)
                        for bg_id, background_region in enumerate(background_regions)]

                    submitted_counter = 0
                    for chunk in chunks(background_pairs, threads):
                        input_queue.put([chunk, False])
                        submitted_counter += 1

                    background_ssim_results = []
                    for _ in range(submitted_counter):
                        out = output_queue.get(block=True)
                        if isinstance(out, Exception):
                            raise out
                        for pair_ix, ssim, sn in out:
                            background_ssim_results.append(ssim)

                    result = ssim_results.get(pair_id)
                    if result is None or np.isnan(result[0]):
                        ssim_zscores[pair_id] = np.nan
                        ssim_pvalues[pair_id] = np.nan
                        continue

                    original_ssim = result[0]
                    # Compare against an array so the element-wise "<="
                    # works even if the score is a plain Python float.
                    background_values = np.asarray(background_ssim_results, dtype=float)
                    ssim_mean = np.nanmean(background_values)
                    ssim_sd = np.nanstd(background_values)

                    ssim_zscores[pair_id] = (original_ssim - ssim_mean) / ssim_sd
                    ssim_pvalues[pair_id] = (
                        (1 + np.sum(original_ssim <= background_values))
                        / len(background_ssim_results))

                nan_counter = 0
                with open(output_file, 'w') as o:
                    o.write("ID\tSN\tssim\tz_ssim\tz_bg\tp_bg\n")
                    for pair_id, _, _ in pairs:
                        if pair_id in ssim_results:
                            ssim, sn, z_ssim = ssim_results[pair_id]
                            z = ssim_zscores[pair_id]
                            p = ssim_pvalues[pair_id]
                        else:
                            ssim, sn, z_ssim, z, p = np.nan, np.nan, np.nan, np.nan, np.nan

                        if np.isnan(ssim):
                            nan_counter += 1

                        o.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(pair_id, sn, ssim, z_ssim, z, p))

                if nan_counter > 0:
                    logger.info("Could not compute similarity for {}/{} region pairs. "
                                "This typically happens if the region size is too small "
                                "(< 20 bins of the matrix)".format(nan_counter, len(pairs)))
        finally:
            # One None per worker is queued before the pool is terminated.
            for _ in range(threads):
                input_queue.put(None)

            if pool is not None:
                pool.terminate()

    def oe(self, argv):
        """Run the ``oe`` sub-command: write an observed/expected matrix.

        Passes the regions file and the sparse input matrix to
        ``chess.oe.observed_expected`` and writes the returned edges as
        tab-separated ``<source> <sink> <weight>`` lines.  The output is
        gzip-compressed when ``is_gzipped`` says so for the output path.

        Args:
            argv (list): ``[program, command, *command_args]``; only
                ``argv[2:]`` is parsed.
        """
        parser = MyParser(
            description='''
            Convert a sparse Hi-C matrix to observed/expected format.''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        # All three arguments are plain string positionals.
        positionals = (
            ('input_matrix',
             '''Balanced Hi-C matrix in sparse format.
            (each line:
            <row region index> <column region index> <matrix value>)'''),
            ('regions',
             '''BED file (no header) with regions corresponding to
                    the number of rows in the provided reference matrix.'''),
            ('output_matrix',
             '''
            Obs/exp transformed matrix (same as input matrix format)'''),
        )
        for argument_name, help_text in positionals:
            parser.add_argument(argument_name, type=str, help=help_text)

        args = parser.parse_args(argv[2:])

        import gzip
        from chess.oe import observed_expected
        from chess.helpers import is_gzipped

        _regions, edges = observed_expected(args.regions, args.input_matrix)
        destination = args.output_matrix
        row_template = "{}\t{}\t{:.6e}\n"

        if not is_gzipped(destination):
            # plain text output
            with open(destination, 'w') as handle:
                for row_ix, col_ix, value in edges:
                    handle.write(row_template.format(row_ix, col_ix, value))
            return

        # gzip stream is opened in binary mode, so rows are encoded first
        with gzip.open(destination, 'w') as handle:
            for row_ix, col_ix, value in edges:
                encoded = row_template.format(row_ix, col_ix, value).encode('utf-8')
                handle.write(encoded)

    def pairs(self, argv):
        """Run the ``pairs`` sub-command: write sliding-window region pairs.

        Resolves chromosome sizes either through ``pybedtools.chromsizes``
        or from a chrom sizes file, then writes one bedpe-like line per
        window position.  Both halves of each pair hold the same region;
        pair ids are numbered consecutively across all chromosomes.

        Args:
            argv (list): ``[program, command, *command_args]``; only
                ``argv[2:]`` is parsed.

        Raises:
            ValueError: if ``window`` or ``step`` is not positive, or if
                ``--chromosome`` names a chromosome without a size entry.
        """
        parser = MyParser(
            description='''Make window pairs for CHESS genome scan.

            Write all positions of a sliding window of specified
            size with specified step in the specified genome to the outfile
            which can be directly used to run CHESS sim with the --genome-scan
            option.
            ''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'genome',
            type=str,
            help='''UCSC genome identifier (as recognized by pybedtools),
            or path to tab-separated chrom sizes file with columns
            <chromosome name> <chromosome size>.
            Will use the path only if no USCS entry with that name is found,
            or --file-input is specified
            ''')

        parser.add_argument(
            'window',
            type=int,
            help='''Window size in base pairs''')

        parser.add_argument(
            'step',
            type=int,
            help='''Step size in base pairs''')

        parser.add_argument(
            'output',
            type=str,
            help='''Path to output file''')

        parser.add_argument(
            '--file-input',
            action='store_true',
            default=False,
            help='''Will not check for USCS entry of genome input
            with pybedtools if set''')

        parser.add_argument(
            '--chromosome',
            type=str,
            help='''Produce window pairs only for the specified chromosome''')

        args = parser.parse_args(argv[2:])

        # A non-positive step never moves the window forward, which would
        # loop forever while writing to the output file; reject it up front.
        if args.step <= 0:
            raise ValueError('step must be a positive number of base pairs.')
        if args.window <= 0:
            raise ValueError('window must be a positive number of base pairs.')

        import pybedtools as pbt

        def load_chromosome_sizes(path):
            """Load chromosome sizes from file.

            Lines that are empty or contain only whitespace are skipped.

            Args:
                path (str): Path to chromosome sizes file

            Returns:
                dict: The file in dictionary representation.
            """
            chromosome_sizes = {}
            with open(path, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue
                    chromosome, size = fields[:2]
                    chromosome_sizes[chromosome] = int(size)

            return chromosome_sizes

        if args.file_input:
            chromosome_sizes = load_chromosome_sizes(args.genome)
        else:
            try:
                chromosome_sizes = pbt.chromsizes(args.genome)
                if len(chromosome_sizes) == 0:
                    raise ValueError(
                        "Genome not recognised in pybedtools.chromsizes: {}".format(
                            args.genome))
                # pbt has (start, end) as values,
                # where start is usually 0
                chromosome_sizes = {
                    k: v[1] for k, v in chromosome_sizes.items()}
            except (OSError, ValueError):
                logger.info((
                    'No entry found with pybedtools.'
                    ' Trying to read from file.'))
                chromosome_sizes = load_chromosome_sizes(args.genome)

        chromosomes = list(chromosome_sizes)
        if args.chromosome:
            if args.chromosome not in chromosome_sizes:
                raise ValueError((
                    'Specified chromosome has no chromosome size entry.'))
            chromosomes = [args.chromosome]

        with open(args.output, 'w') as f:
            pair_id = 0
            for chromosome in chromosomes:
                chromosome_size = chromosome_sizes[chromosome]
                # Windows start at position 1; the last one written ends at
                # or before the chromosome size.
                for start in range(1, chromosome_size - args.window + 1,
                                   args.step):
                    end = start + args.window
                    f.write(
                        '\t'.join(
                            str(e) for e in [chromosome, start, end,
                                             chromosome, start, end,
                                             pair_id, '.', '+', '+'])
                        + '\n')
                    pair_id += 1

    def background(self, argv):
        """Run the ``background`` sub-command: write background regions.

        The first argument is tried, in this order, as a genome known to
        ``pybedtools.chromsizes``, as a path to an existing chrom sizes file
        and as a region string parsed by ``GenomicRegion.from_string``.
        Every resulting input region is tiled with windows and written as a
        six-column BED file, one line per window and strand.

        Args:
            argv (list): ``[program, command, *command_args]``; only
                ``argv[2:]`` is parsed.

        Raises:
            ValueError: if ``--strand`` is neither ``+`` nor ``-``, or if
                ``window`` or ``step`` is not positive.
        """
        parser = MyParser(
            description='''
            Generate BED file with regions to be used in CHESS background calculations.
            ''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'genome_or_region',
            type=str,
            help='''UCSC genome identifier (as recognized by pybedtools),
                    OR path to tab-separated chrom sizes file with columns
                    <chromosome name> <chromosome size> OR region identifier in
                    the format <chromosome>:<start>-<end>.
                    Will try options int he order listed.
                 ''')

        parser.add_argument(
            'window',
            type=int,
            help='''Window size in base pairs''')

        parser.add_argument(
            'step',
            type=int,
            help='''Step size in base pairs''')

        parser.add_argument(
            'output',
            type=str,
            help='''Path to output file''')

        parser.add_argument(
            '--strand',
            type=str,
            help='''[+/-] .Generate regions on this strand only. Default: both strands''')

        args = parser.parse_args(argv[2:])

        import pybedtools as pbt
        from chess.helpers import read_chromosome_sizes, GenomicRegion

        genome_or_region = args.genome_or_region
        window = args.window
        step = args.step
        output_file = args.output
        strand = args.strand

        if strand not in ['-', '+', None]:
            raise ValueError("--strand must be either - or +")
        # Validate before the output file is opened, and with a clearer
        # message than the one range() gives for a zero step.
        if step <= 0:
            raise ValueError('step must be a positive number of base pairs.')
        if window <= 0:
            raise ValueError('window must be a positive number of base pairs.')

        try:
            chromosome_sizes = pbt.chromsizes(genome_or_region)
            if len(chromosome_sizes) == 0:
                raise ValueError("Genome not recognised in pybedtools.chromsizes: {}".format(genome_or_region))
            # pbt has (start, end) as values,
            # where start is usually 0
            chromosome_sizes = {
                k: v[1] for k, v in chromosome_sizes.items()}

            input_regions = [(chromosome, 1, end) for chromosome, end in chromosome_sizes.items()]
        except (OSError, ValueError):
            # OSError is handled as well, matching the genome lookup of the
            # `pairs` command in this class.
            logger.info((
                'No entry found with pybedtools.'
                ' Trying to read from file.'))
            if os.path.exists(genome_or_region):
                chromosome_sizes = read_chromosome_sizes(genome_or_region)
                input_regions = [(chromosome, 1, end) for chromosome, end in chromosome_sizes.items()]
            else:
                region = GenomicRegion.from_string(genome_or_region)
                input_regions = [(region.chromosome, region.start, region.end)]

        if strand is None:
            strands = ('+', '-')
        else:
            strands = (strand,)

        with open(output_file, 'w') as o:
            for chromosome, start, end in input_regions:
                # the region index restarts at 0 for every input region
                region_ix = 0
                # NOTE(review): range() stops before `end - window`, so a
                # window ending exactly at `end` is not written - confirm
                # that this is intended.
                for s in range(start, end - window, step):
                    for current_strand in strands:
                        o.write("{}\t{}\t{}\t{}\t.\t{}\n".format(
                            chromosome, s, s + window, region_ix,
                            current_strand
                        ))
                        region_ix += 1

    def filter(self, argv):
        """Run the ``filter`` sub-command: filter sim results, write BED files.

        Reads the tab-separated results table with pandas, applies every
        requested column filter in turn and writes the remaining pairs as
        BED lines sorted by chromosome and start.  The output goes to
        ``<output>_REF`` and ``<output>_QRY``, or to ``<output>_genome_scan``
        alone when ``--genome-scan`` is set.

        Args:
            argv (list): ``[program, command, *command_args]``; only
                ``argv[2:]`` is parsed.

        Raises:
            ValueError: if the score column or a filtered column is missing
                from the results file, if a threshold is not a number, or
                if a filter mode is not one of ``geq``, ``leq``, ``g``,
                ``l``.
        """
        parser = MyParser(
            description='''
            Filter results of `chess sim` by p_bg, z_bg, z_ssim, ssim, SN value.

            You can filter by any combination of the `-p``,
            `-z` (z-score), `-s` (ssim) and `-n` (SN) flags.
            Each flag takes two arguments, where the first has to be either
            `geq` (greater-equal), `leq` (less-equal), `l` (less)
            or `g` (greater) and the second the threshold value of the filter.

            Example:

                if you want to extract regions that yielded
                a z-value greater or equal 2 in the comparison, you should run
                `chess filter <chess_results> <chess_pairs> <output_file> -z geq 2`

            If you don't specify any filters, this will simply convert
            the output to BED format.

            This produces two output BED files, one for the REF and one
            for the QRY, unless --genome-scan is specified.
            The <name> column of the BED file (4th) is used for the pair ID.
            ''',
            formatter_class=argparse.RawTextHelpFormatter)
        parser.add_argument(
            'chess_sim_results',
            type=str,
            help='''Tab seperated file (no header) with an <ID> column any of
            <p>, <z-score>, <ssim>, <SN> columns.''')
        parser.add_argument(
            'pairs',
            type=str,
            help='''Region pairs file used as an input to chess sim
            in the run that generated the chess_sim_results.''')
        parser.add_argument(
            'output',
            type=str,
            help='''Path to output file.''')
        parser.add_argument(
            '-p',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <p_bg> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-z',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <z_bg> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-s',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <ssim> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-n',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <SN> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-zs',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <z_ssim> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '--score',
            default='ssim',
            help='''Input file column to use for the <score> column
            in the output BED.''')
        parser.add_argument(
            '--genome-scan',
            action='store_true',
            default=False,
            help='''Write only one output BED file.''')
        args = parser.parse_args(argv[2:])

        import operator
        import pandas as pd
        from chess.helpers import load_pairs

        results = pd.read_csv(args.chess_sim_results, sep='\t')

        if args.score not in results.columns:
            # The original message ran two literals together without a
            # space ("ouputscore"); spelled out correctly here.
            raise ValueError(('{} column chosen for the output '
                              'score not found in input results file.').format(
                              args.score))

        # results column -> (mode, threshold) as given on the command line
        filters = {
            'p_bg': args.p,
            'z_bg': args.z,
            'z_ssim': args.zs,
            'ssim': args.s,
            'SN': args.n
        }

        # filter mode -> comparison applied as column <op> threshold
        comparisons = {
            'geq': operator.ge,
            'leq': operator.le,
            'g': operator.gt,
            'l': operator.lt,
        }

        for column, (mode, value) in filters.items():
            if mode is None or value is None:
                continue
            if column not in results.columns:
                raise ValueError(('{} column not in results file.'.format(
                    column)))
            value = float(value)
            if mode not in comparisons:
                raise ValueError(
                    ('Specified mode {} for {} filter is not valid.'
                     ' mode must be `geq`, `leq`, `l` or `g`.').format(
                                  mode, column))
            results = results[comparisons[mode](results[column], value)]

        pairs = load_pairs(args.pairs)

        beds = {
            'REF': [],
            'QRY': []}

        for index, row in results.iterrows():
            # pairs are looked up by the integer ID rendered as a string
            ID, curr_score = str(int(row.ID)), row[args.score]
            refreg, qryreg = pairs[ID]
            if args.genome_scan:
                p = [[refreg, 'REF']]
            else:
                p = [[refreg, 'REF'], [qryreg, 'QRY']]
            for reg, k in p:
                # NOTE(review): '\n' is joined like a field below, so every
                # line ends with a tab before the newline; kept unchanged
                # for output compatibility.
                curr_line = [reg.chromosome, reg.start, reg.end,
                             ID, curr_score, reg.strand, '\n']
                beds[k].append(curr_line)

        if args.genome_scan:
            beds = {'genome_scan': beds['REF']}

        for k, v in beds.items():
            # sort by chromosome, then start
            v.sort(key=operator.itemgetter(0, 1))
            with open(args.output + '_' + k, 'w') as f:
                for line in v:
                    f.write('\t'.join([str(e) for e in line]))

    def extract(self, argv):
        """Run the ``extract`` sub-command.

        Parses the sub-command options from ``argv[2:]``, loads the
        reference and query contact data with ``load_oe_contacts``, reads
        the region pairs with ``load_pairs_iter`` and passes everything,
        together with the filtering parameters, to ``extract_structures``.

        :param argv: Full argument vector; the first two entries are
                     skipped before the remaining ones are parsed.
        :raises ValueError: Re-raised (after logging an error) when
                            ``load_oe_contacts`` fails for the reference
                            or the query sample.
        """
        parser = MyParser(
            description='''
            Extract the specific regions that are different
            between the regions identified by CHESS.''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'pairs',
            type=str,
            help='''Region pairs that have been identified to
            be significantly different. Expected to be in
            bedpe format with chrom1, start1, ...
            corresponding to reference and chrom2, start2, ... to query.''')

        parser.add_argument(
            'reference_contacts',
            type=str,
            help='''
            Balanced contact matrix for the reference sample
            in one of the following formats:
                fanc .hic,
                juicer .hic@<resolution>,
                cooler .cool@<resolution> or .mcool@<resolution>,
                sparse format (each line:
                    <row region index> <column region index> <matrix value>).

            If the file is in sparse format, the corresponding regions
            BED file needs to be passed via --reference-regions.
            ''')

        parser.add_argument(
            '--reference-regions', dest="reference_regions",
            type=str,
            help='''BED file (no header) with regions corresponding to
            the number of rows in the provided reference matrix.''')

        parser.add_argument(
            'query_contacts',
            type=str,
            help='''
            Balanced contact matrix for the query sample
            in one of the following formats:
                fanc .hic,
                juicer .hic@<resolution>,
                cooler .cool@<resolution> or .mcool@<resolution>,
                sparse format (each line:
                    <row region index> <column region index> <matrix value>).

            If the file is in sparse format, the corresponding regions
            BED file needs to be passed via --query-regions.
            ''')

        parser.add_argument(
            '--query-regions', dest="query_regions",
            type=str,
            help='''BED file (no header) with regions corresponding to
            the number of rows in the provided query matrix.''')

        parser.add_argument(
            'out',
            type=str,
            help='''Path to output directory.''')

        parser.add_argument(
            '--windowsize',
            type=int,
            default=3,
            help='''Window size to average the bins according to their spatial closeness and their radiometric similarity, 
                     by default the windows size is the 3 x 3 bins.Higher values will average bins with larger differences.''')

        parser.add_argument(
            '--sigma-spatial', dest="sigma_spatial",
            type=int,
            default=3,
            help='''Gaussian function of the Euclidean distance between two bins and its standard deviation. 
                    By default is 3. Higher values will average bins with larger differences.''')

        parser.add_argument(
            '--size-medianfilter', dest="size_medianfilter",
            type=int,
            default=9,
            help='''Windows size used to scan and smooth the contained bins. By default it uses windows of 9. 
                    Note that higher values will smooth larger structures.''')

        parser.add_argument(
            '--closing-square', dest="closing_square",
            type=int,
            default=8,
            help='''Size of the square used to remove noise, and fill structures. By default it uses a square of 8 x 8.
                      Note that higher values will enclose larger structures and remove punctuate or more looping structures.''')

        args = parser.parse_args(argv[2:])

        # Project imports are deferred until the arguments have parsed
        # successfully, so usage errors and -h never trigger them.
        from chess.helpers import load_pairs_iter, load_oe_contacts

        from chess.get_structures import extract_structures

        logger.debug('[MAIN]: Parameters:')
        logger.debug(args)

        logger.info('[MAIN]: Loading reference contact data')
        try:
            # The third returned value (the regions) is not needed here.
            reference_edges, reference_region_trees, _ = load_oe_contacts(
                args.reference_contacts, args.reference_regions)
        except ValueError:
            logger.error('Failed to load the reference data.')
            # Propagate the failure: continuing would only crash later with
            # an UnboundLocalError that hides the real cause.
            raise
        try:
            query_edges, query_region_trees, _ = load_oe_contacts(
                args.query_contacts, args.query_regions)
        except ValueError:
            logger.error('Failed to load the query data.')
            raise

        logger.info('[MAIN]: Loading region pairs')

        pairs = list(load_pairs_iter(args.pairs))
        logger.info(
            '[MAIN]: Applying image filtering to identify specific structures')
        extract_structures(
            reference_edges,
            reference_region_trees,
            query_edges,
            query_region_trees,
            pairs, args.out,
            args.windowsize,
            args.sigma_spatial,
            args.size_medianfilter,
            args.closing_square)
        logger.info('[MAIN]: Results collected')

    def crosscorrelate(self, argv):
        """Run the ``crosscorrelate`` sub-command.

        Parses ``argv[2:]`` for the extracted-structures file, the region
        pairs file and the output directory, loads the pairs with
        ``load_pairs_iter`` and hands all three to ``correlate2d``.

        :param argv: Full argument vector; the first two entries are
                     skipped before the remaining ones are parsed.
        """
        cli = MyParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description='''
            2D crosscorrelation of the specific substructures extracted from the
            CHESS significant different regions''')

        # Positional arguments, in the order they appear on the command line.
        cli.add_argument(
            'extracted_file', type=str,
            help='''Output from extract sub-command.''')
        cli.add_argument(
            'pairs', type=str,
            help='''Region pairs that have been identified to 
            be significantly different. Must be in bedpe format with chrom1, start1, ...
            corresponding to reference and chrom2, start2, ... to query.''')
        cli.add_argument(
            'outdir', type=str,
            help='''Path to output directory.''')

        options = cli.parse_args(argv[2:])

        # Deferred until after argument parsing has succeeded.
        from chess.helpers import load_pairs_iter
        from chess.cross_correlation import correlate2d

        region_pairs = list(load_pairs_iter(options.pairs))
        correlate2d(options.extracted_file, options.outdir, region_pairs)


def chess_parser():
    """Build the top-level ``chess`` command line parser.

    The parser knows the global options (``--version``, ``-v/--verbose``,
    ``-l/--log-file``) and a single optional positional ``command``.

    :return: The configured :class:`argparse.ArgumentParser`.
    """
    # Custom usage text listing the available sub-commands; dedent strips
    # the indentation that only exists to keep this source readable.
    command_overview = textwrap.dedent('''\
        chess <command> [options]

        Commands:
            sim             Calculate structural similarity
            oe              Transform a Hi-C matrix to an observed/expected matrix
            pairs           Make window pairs for chess genome scan
            background   Generate background region BED files
            filter          Filter output of chess sim and save as BED file
            extract         Extract specific regions that are significantly different
            crosscorrelate  Get structural clusters from the extracted submatrices

        Run chess <command> -h for help on a specific command.
        ''')

    top_level = argparse.ArgumentParser(
        usage=command_overview,
        description="""
        CHESS: Compare Hi-C Experiments using Structural Similarity""",
    )

    top_level.add_argument(
        '--version',
        action='store_true',
        dest='print_version',
        help='''Print version information''',
    )
    top_level.set_defaults(print_version=False)

    # Each occurrence of -v raises the counter by one.
    top_level.add_argument(
        '--verbose', '-v',
        action='count',
        dest='verbosity',
        default=0,
        help='''Set verbosity level: Can be chained like
        '-vvv' to increase verbosity. Default is to show
        errors, warnings, and info messages (same as '-vv').
        '-v' shows only errors and warnings, '-vvv' shows errors, warnings,
        info, and debug messages in addition.''',
    )

    top_level.add_argument(
        '-l', '--log-file',
        dest='log_file',
        help='''Path to file in which to save log.''',
    )

    # Optional so that e.g. `chess --version` parses without a command.
    top_level.add_argument('command', help='Subcommand to run', nargs='?')

    return top_level


class MyParser(argparse.ArgumentParser):
    """Argument parser that prints the full help text on a usage error."""

    def error(self, message):
        """Report *message* on stderr, print the help and exit with status 2."""
        report = 'error: %s\n' % message
        sys.stderr.write(report)
        self.print_help()
        raise SystemExit(2)


# Script entry point: instantiating Chess is what runs the CLI, since its
# __init__ reads and parses sys.argv.
if __name__ == '__main__':
    Chess()
