#!/usr/bin/env python

import os
import sys
import argparse
import textwrap
from tqdm import tqdm as progressbar


import logging
logger = logging.getLogger('')


class Chess(object):

    def __init__(self):
        """Parse global options, set up logging and run the requested command.

        Reads ``sys.argv`` directly. The leading option flags are parsed with
        the parser returned by ``chess_parser()``; everything after them is
        passed on to the method of this class that is named like the command.

        Exits the process after printing the package version when
        ``args.print_version`` is set, and exits with status 1 (after printing
        the help) when no command is given or no attribute of that name exists.
        """
        parser = chess_parser()

        # number of argv slots taken up by global flags that expect a value
        flag_increments = {
            '-l': 2, '--log-file': 2,
        }

        # skip over the leading global options; afterwards option_ix is the
        # index of the first argument that is not a global option
        option_ix = 1
        while (option_ix < len(sys.argv) and
               sys.argv[option_ix].startswith('-')):
            option_ix += flag_increments.get(sys.argv[option_ix], 1)

        # parse_args defaults to [1:] for args, but you need to
        # exclude the rest of the args too, or validation will fail
        args = parser.parse_args(sys.argv[1:option_ix+1])

        # configure logger
        if args.verbosity == 1:
            log_level = logging.WARN
        elif args.verbosity == 2:
            log_level = logging.INFO
        elif args.verbosity > 2:
            log_level = logging.DEBUG
        else:
            log_level = logging.INFO
        logger.setLevel(log_level)

        # log to the console unless a log file was requested; both handlers
        # share the same format and level
        if args.log_file is None:
            handler = logging.StreamHandler()
        else:
            handler = logging.FileHandler(
                os.path.expanduser(args.log_file), mode='a')
        handler.setFormatter(logging.Formatter(
            "%(asctime)s %(levelname)s %(message)s"))
        handler.setLevel(log_level)
        logger.addHandler(handler)

        # get version info
        if args.print_version:
            import chess
            print(chess.__version__)
            # sys.exit instead of the builtin exit(): the latter is injected
            # by the site module and is not guaranteed to exist
            sys.exit()

        if args.command is None or not hasattr(self, args.command):
            print('Unrecognized command')
            parser.print_help()
            sys.exit(1)

        # echo parameters back to user
        command = " ".join(sys.argv)
        logger.info("Running '{}'".format(command))

        # use dispatch pattern to invoke method with same name;
        # the sub-command sees [program, command, command args...]
        getattr(self, args.command)([sys.argv[0]] + sys.argv[option_ix:])

        # echo parameters back to user
        logger.info("Finished '{}'".format(" ".join(sys.argv)))

    def sim(self, argv):
        """Compare region pairs between a reference and a query Hi-C matrix.

        Parses the ``sim`` command line, loads regions, contacts and region
        pairs, and hands the pairwise comparisons to a pool of worker
        processes through two manager queues. For every pair the ``ssim`` and
        ``sn`` values reported by the workers are collected, and a z-score of
        each finite ``ssim`` relative to all non-NaN ``ssim`` values of the
        run is computed.

        Without ``--background-regions`` / ``--background-query`` the output
        is a tab-separated file with the columns ``ID SN ssim z_ssim``. With
        a background, every reference region is additionally compared to
        each background region and the columns ``z_bg`` and ``p_bg`` are
        appended. Pairs without any usable background comparison get NaN in
        both of these columns.

        :param argv: Argument list; the first two entries are skipped and the
                     remainder is parsed as the ``sim`` arguments.
        :raises Exception: Any exception instance a worker puts on the output
                           queue is re-raised here.
        """
        parser = MyParser(
            description='''
            Compare structures between pairs of Hi-C matrices using the
            structural similarity index. Compute p-value and z-value after
            obtaining a background distribution of similarity values of the
            reference in the pair to the rest of the queries source genome.''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'reference_sparse',
            type=str,
            help='''Balanced or observed / expected Hi-C matrix in sparse format
            for reference sample.
            (each line:
                <row region index> <column region index> <matrix value>)''')

        parser.add_argument(
            'reference_regions',
            type=str,
            help='''BED file (no header) with regions corresponding to
            the number of rows in the provided reference matrix.''')

        parser.add_argument(
            'query_sparse',
            type=str,
            help='''Balanced or observed / expected Hi-C matrix in sparse format
            for query sample.
            (each line:
                <row region index> <column region index> <matrix value>)''')

        parser.add_argument(
            'query_regions',
            type=str,
            help='''BED file (no header) with regions corresponding to
            the number of rows in the provided query matrix.''')

        parser.add_argument(
            'pairs',
            type=str,
            help='''Region pairs to compare.
            Must be in bedpe format with chrom1, start1, ...
            corresponding to reference and chrom2, start2, ... to query.''')

        parser.add_argument(
            'out',
            type=str,
            help='''Path to outfile.''')

        parser.add_argument(
            '--background-regions', dest='background_regions',
            type=str,
            help='''BED file with regions to be used for background calculations.
                            If provided, CHESS will generate Z-scores and P-values for 
                            similarities.''')

        parser.add_argument(
            '--background-query', dest='background_query',
            default=False,
            action='store_true',
            help='Use every region of the same size as the '
                 'reference from the query genome as background. '
                 'Useful, for example, as background for inter-species '
                 'comparisons.')

        parser.add_argument(
            '-p', dest='threads',
            type=int,
            default=1,
            help='''Number of cores to use.''')

        parser.add_argument(
            '--keep-unmappable-bins',
            action='store_true',
            default=False,
            help='''Disable deletion of unmappable bins.''')

        parser.add_argument(
            '--mappability-cutoff',
            type=float,
            default=0.1,
            help='''Low pass threshold for fraction of unmappable bins.
                    Matrices with a higher content of unmappable bins
                    will not be considered.
                    Unmappable bins will be deleted from matrices
                    passing the filter.''')

        parser.add_argument(
            '-r', '--relative-windowsize',
            type=float,
            default=1,
            help='''Relative window size value
            for the win_size param in the ssim function.
            Fraction of the matrix size.''')

        parser.add_argument(
            '-a', '--absolute-windowsize',
            type=int,
            default=None,
            help='''Absolute window size value in bins
            for the win_size param in the ssim function.
            Overwrites -r.''')

        parser.add_argument(
            '--converted-input',
            action='store_true',
            default=False,
            help='''Use if input sparse matrices are already observed / expected
            matrices. Will skip transformation.''')

        parser.add_argument(
            '--limit-background',
            action='store_true',
            default=False,
            help='Restrict background computation to the syntenic / paired '
                 'chromosome as indicated in the pairs file.')

        args = parser.parse_args(argv[2:])

        # heavy dependencies are imported lazily, only when the command runs
        from chess.helpers import load_pairs_iter, load_regions, \
            region_interval_trees, GenomicRegion, \
            edges_dict, edges_from_sparse_matrix, chunks
        from chess.sim import chunk_comparison_worker
        from chess.oe import observed_expected
        import multiprocessing as mp
        import numpy as np
        from collections import defaultdict

        reference_regions_file = args.reference_regions
        reference_matrix_file = args.reference_sparse
        query_regions_file = args.query_regions
        query_matrix_file = args.query_sparse
        pairs_file = args.pairs
        threads = args.threads
        is_converted = args.converted_input
        background_regions_file = args.background_regions
        background_query = args.background_query
        output_file = args.out
        limit_background = args.limit_background

        logger.debug('[MAIN]: Parameters:')
        logger.debug(args)

        logger.info('[MAIN]: Loading and indexing Hi-C regions')
        reference_regions, reference_ix_converter, _ = load_regions(reference_regions_file)
        reference_region_trees = region_interval_trees(reference_regions)
        query_regions, query_ix_converter, _ = load_regions(query_regions_file)
        query_region_trees = region_interval_trees(query_regions)

        logger.info('[MAIN]: Loading Hi-C contacts')
        if not is_converted:
            logger.info('[MAIN]: Converting to observed / expected')
            _, reference_oe = observed_expected(reference_regions_file, reference_matrix_file)
            reference_edges = edges_dict(reference_oe)

            _, query_oe = observed_expected(query_regions_file, query_matrix_file)
            query_edges = edges_dict(query_oe)
        else:
            reference_edges = edges_dict(edges_from_sparse_matrix(reference_matrix_file,
                                                                  reference_ix_converter))
            query_edges = edges_dict(edges_from_sparse_matrix(query_matrix_file,
                                                              query_ix_converter))

        logger.info('[MAIN]: Loading region pairs')
        pairs = list(load_pairs_iter(pairs_file))

        logger.info("Launching workers")
        m = mp.Manager()
        input_queue = m.Queue()
        output_queue = m.Queue()

        pool = None
        try:
            # every pool process runs chunk_comparison_worker as initializer.
            # NOTE(review): the literal 20 is presumably the minimum region
            # size in bins (it matches the log message below) - confirm
            # against chunk_comparison_worker.
            pool = mp.Pool(threads, chunk_comparison_worker,
                           (input_queue, output_queue,
                            reference_edges, reference_region_trees,
                            query_edges, query_region_trees,
                            20, args.keep_unmappable_bins, args.absolute_windowsize,
                            args.relative_windowsize, args.mappability_cutoff))

            logger.info("Submitting pairs for comparison")
            submitted_counter = 0
            for chunk in chunks(pairs, threads):
                input_queue.put([chunk, True])
                submitted_counter += 1

            # one output item is expected per submitted chunk
            ssim_results = dict()
            all_ssim = []
            for _ in range(submitted_counter):
                out = output_queue.get(block=True)
                if isinstance(out, Exception):
                    raise out
                for pair_ix, ssim, sn in out:
                    ssim_results[pair_ix] = [ssim, sn, np.nan]
                    if not np.isnan(ssim):
                        all_ssim.append(ssim)

            # calculate ssim z-score
            all_ssim_mean = np.nanmean(all_ssim)
            all_ssim_sd = np.nanstd(all_ssim)
            for pair_id in list(ssim_results.keys()):
                ssim = ssim_results[pair_id][0]
                if np.isfinite(ssim):
                    ssim_results[pair_id][2] = (ssim - all_ssim_mean) / all_ssim_sd

            # no background calculations
            if background_regions_file is None and not background_query:
                nan_counter = 0
                with open(output_file, 'w') as o:
                    o.write("ID\tSN\tssim\tz_ssim\n")
                    for pair_id, _, _ in pairs:
                        if pair_id in ssim_results:
                            ssim, sn, z_ssim = ssim_results[pair_id]
                        else:
                            ssim, sn, z_ssim = np.nan, np.nan, np.nan

                        if np.isnan(ssim):
                            nan_counter += 1

                        o.write("{}\t{}\t{}\t{}\n".format(pair_id, sn, ssim, z_ssim))

                if nan_counter > 0:
                    logger.info("Could not compute similarity for {} region pairs. "
                                "This typically happens if the region size is too small "
                                "(< 20 bins of the matrix)".format(nan_counter))

            # background calculations
            else:
                logger.info("[MAIN]: Running background calculations")
                if background_regions_file is not None:
                    background_regions_file = os.path.expanduser(background_regions_file)
                    background_regions, _, _ = load_regions(background_regions_file, ignore_ix=True)
                else:
                    background_regions = None
                    logger.info("Generating background regions from query")

                # The largest region end per chromosome depends only on the
                # query regions, so it is computed once and not once per pair.
                chromosome_ends = defaultdict(int)
                if background_query:
                    for r in query_regions:
                        chromosome_ends[r.chromosome] = max(chromosome_ends[r.chromosome], r.end)

                # process each reference region separately
                ssim_zscores = dict()
                ssim_pvalues = dict()
                for pair_id, reference_region, query_region in progressbar(pairs):
                    if background_query:
                        background_regions = []
                        query_size = query_region.end - query_region.start
                        # generate background regions from query genome:
                        # one window of the query's size per region start,
                        # on the forward and then on the reverse strand
                        for region in query_regions:
                            if limit_background and region.chromosome != query_region.chromosome:
                                continue

                            for strand in ('+', '-'):
                                candidate = GenomicRegion(chromosome=region.chromosome,
                                                          start=region.start,
                                                          end=region.start + query_size,
                                                          strand=strand)
                                # drop windows running past the chromosome end
                                if candidate.end <= chromosome_ends[candidate.chromosome]:
                                    background_regions.append(candidate)

                    background_pairs = []
                    for bg_id, background_region in enumerate(background_regions):
                        background_pairs.append((bg_id, reference_region, background_region))

                    submitted_counter = 0
                    for chunk in chunks(background_pairs, threads):
                        input_queue.put([chunk, False])
                        submitted_counter += 1

                    background_ssim_results = []
                    for _ in range(submitted_counter):
                        out = output_queue.get(block=True)
                        if isinstance(out, Exception):
                            raise out
                        for pair_ix, ssim, sn in out:
                            background_ssim_results.append(ssim)

                    try:
                        original_ssim, _, _ = ssim_results[pair_id]
                        # Without a similarity for the pair itself, or without
                        # a single background comparison, neither statistic is
                        # defined; report NaN rather than dividing by zero.
                        if np.isnan(original_ssim) or len(background_ssim_results) == 0:
                            ssim_zscores[pair_id] = np.nan
                            ssim_pvalues[pair_id] = np.nan
                        else:
                            ssim_mean = np.nanmean(background_ssim_results)
                            ssim_sd = np.nanstd(background_ssim_results)

                            z = (original_ssim - ssim_mean) / ssim_sd
                            p = (1 + np.sum(original_ssim <= background_ssim_results)) / len(background_ssim_results)
                            ssim_zscores[pair_id] = z
                            ssim_pvalues[pair_id] = p
                    except KeyError:
                        ssim_zscores[pair_id] = np.nan
                        ssim_pvalues[pair_id] = np.nan

                nan_counter = 0
                with open(output_file, 'w') as o:
                    o.write("ID\tSN\tssim\tz_ssim\tz_bg\tp_bg\n")
                    for pair_id, _, _ in pairs:
                        if pair_id in ssim_results:
                            ssim, sn, z_ssim = ssim_results[pair_id]
                            z = ssim_zscores[pair_id]
                            p = ssim_pvalues[pair_id]
                        else:
                            ssim, sn, z_ssim, z, p = np.nan, np.nan, np.nan, np.nan, np.nan

                        if np.isnan(ssim):
                            nan_counter += 1

                        o.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(pair_id, sn, ssim, z_ssim, z, p))

                if nan_counter > 0:
                    logger.info("Could not compute similarity for {}/{} region pairs. "
                                "This typically happens if the region size is too small "
                                "(< 20 bins of the matrix)".format(nan_counter, len(pairs)))
        finally:
            # one None per worker - presumably the workers' stop signal
            # (confirm against chunk_comparison_worker)
            for _ in range(threads):
                input_queue.put(None)

            if pool is not None:
                pool.terminate()

    def oe(self, argv):
        """Write the observed/expected transform of a sparse Hi-C matrix.

        The edges returned by ``observed_expected`` are written one per line
        as ``<source>\\t<sink>\\t<weight>`` with the weight in ``.6e``
        notation. The output is written through ``gzip`` when
        ``is_gzipped`` reports so for the output path, otherwise as plain
        text.

        :param argv: Argument list; the first two entries are skipped and the
                     remainder is parsed as the ``oe`` arguments.
        """
        parser = MyParser(
            description='''
            Convert a sparse Hi-C matrix to observed/expected format.''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'input_matrix', type=str,
            help='''Balanced Hi-C matrix in sparse format.
            (each line:
            <row region index> <column region index> <matrix value>)''')

        parser.add_argument(
            'regions', type=str,
            help='''BED file (no header) with regions corresponding to
                    the number of rows in the provided reference matrix.''')

        parser.add_argument(
            'output_matrix', type=str,
            help='''
            Obs/exp transformed matrix (same as input matrix format)''')

        args = parser.parse_args(argv[2:])

        import gzip
        from chess.oe import observed_expected
        from chess.helpers import is_gzipped

        # only the edges are needed; the regions returned first are unused
        _, edges = observed_expected(args.regions, args.input_matrix)

        row_template = "{}\t{}\t{:.6e}\n"
        compressed = is_gzipped(args.output_matrix)
        if compressed:
            handle = gzip.open(args.output_matrix, 'w')
        else:
            handle = open(args.output_matrix, 'w')

        with handle as out:
            for source, sink, weight in edges:
                row = row_template.format(source, sink, weight)
                # the gzip handle is opened in binary mode and needs bytes
                out.write(row.encode('utf-8') if compressed else row)

    def pairs(self, argv):
        """Write sliding-window region pairs for a genome to a file.

        For every chromosome (or only ``--chromosome``) a window of size
        ``window`` is moved in increments of ``step``, starting at position 1,
        for as long as the window start does not exceed
        ``chromosome size - window``. Each position is written as one
        tab-separated line
        ``chrom start end chrom start end pair_id . + +``, where ``pair_id``
        is a counter running across all chromosomes.

        Chromosome sizes are taken from ``pybedtools.chromsizes`` unless
        ``--file-input`` is set or that lookup raises ``ValueError``, in which
        case ``genome`` is read as a chromosome sizes file.

        :param argv: Argument list; the first two entries are skipped and the
                     remainder is parsed as the ``pairs`` arguments.
        :raises ValueError: If ``window`` or ``step`` is not positive, or if
                            ``--chromosome`` has no chromosome size entry.
        """
        parser = MyParser(
            description='''Make window pairs for CHESS genome scan.

            Write all positions of a sliding window of specified
            size with specified step in the specified genome to the outfile
            which can be directly used to run CHESS sim with the --genome-scan
            option.
            ''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'genome',
            type=str,
            help='''UCSC genome identifier (as recognized by pybedtools),
            or path to tab-separated chrom sizes file with columns
            <chromosome name> <chromosome size>.
            Will use the path only if no USCS entry with that name is found,
            or --file-input is specified
            ''')

        parser.add_argument(
            'window',
            type=int,
            help='''Window size in base pairs''')

        parser.add_argument(
            'step',
            type=int,
            help='''Step size in base pairs''')

        parser.add_argument(
            'output',
            type=str,
            help='''Path to output file''')

        parser.add_argument(
            '--file-input',
            action='store_true',
            default=False,
            help='''Will not check for USCS entry of genome input
            with pybedtools if set''')

        parser.add_argument(
            '--chromosome',
            type=str,
            help='''Produce window pairs only for the specified chromosome''')

        args = parser.parse_args(argv[2:])

        # A step of zero or below would never advance the window and the
        # write loop below would run (and fill the output file) forever.
        if args.window <= 0:
            raise ValueError("window must be a positive number of base pairs")
        if args.step <= 0:
            raise ValueError("step must be a positive number of base pairs")

        import pybedtools as pbt

        def load_chromosome_sizes(path):
            """Load chromosome sizes from file.
            :param path: Path to chromosome sizes file
            :returns: The file in dictionary representation.
            """
            chromosome_sizes = {}
            with open(path, 'r') as f:
                for line in f:
                    fields = line.strip().split()
                    if not fields:
                        # tolerate blank lines, e.g. at the end of the file
                        continue
                    chromosome, size = fields[:2]
                    chromosome_sizes[chromosome] = int(size)

            return chromosome_sizes

        if args.file_input:
            chromosome_sizes = load_chromosome_sizes(args.genome)
        else:
            try:
                chromosome_sizes = pbt.chromsizes(args.genome)
                if len(chromosome_sizes) == 0:
                    raise ValueError("Genome not recognised in pybedtools.chromsizes: {}".format(args.genome))
                # pbt has (start, end) as values,
                # where start is usually 0
                chromosome_sizes = {
                    k: v[1] for k, v in chromosome_sizes.items()}
            except ValueError:
                logger.info((
                    'No entry found with pybedtools.'
                    ' Trying to read from file.'))
                chromosome_sizes = load_chromosome_sizes(args.genome)

        chromosomes = chromosome_sizes.keys()
        if args.chromosome:
            if args.chromosome not in chromosome_sizes:
                raise ValueError((
                    'Specified chromosome has no chromosome size entry.'))
            chromosomes = [args.chromosome]

        with open(args.output, 'w') as f:
            pair_id = 0
            for chromosome in chromosomes:
                chromosome_size = chromosome_sizes[chromosome]
                start = 1
                while start <= (chromosome_size - args.window):
                    end = start + args.window
                    # reference and query coordinates are identical
                    f.write(
                        '\t'.join(
                            str(e) for e in [chromosome, start, end,
                                             chromosome, start, end,
                                             pair_id, '.', '+', '+'])
                        + '\n')
                    start += args.step
                    pair_id += 1

    def background(self, argv):
        """Write a BED file of sliding-window regions for background use.

        The input regions are resolved in this order: chromosome sizes from
        ``pybedtools.chromsizes``; if that raises ``ValueError``, a chromosome
        sizes file when the argument is an existing path; otherwise a single
        region parsed with ``GenomicRegion.from_string``. Within each input
        region a window of size ``window`` is moved in increments of ``step``
        and written once per requested strand, with a per-region running
        index in the name column.

        :param argv: Argument list; the first two entries are skipped and the
                     remainder is parsed as the ``background`` arguments.
        :raises ValueError: If ``--strand`` is given and is neither ``+``
                            nor ``-``.
        """
        parser = MyParser(
            description='''
            Generate BED file with regions to be used in CHESS background calculations.
            ''',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument(
            'genome_or_region', type=str,
            help='''UCSC genome identifier (as recognized by pybedtools),
                    OR path to tab-separated chrom sizes file with columns
                    <chromosome name> <chromosome size> OR region identifier in
                    the format <chromosome>:<start>-<end>.
                    Will try options int he order listed.
                 ''')

        parser.add_argument(
            'window', type=int,
            help='''Window size in base pairs''')

        parser.add_argument(
            'step', type=int,
            help='''Step size in base pairs''')

        parser.add_argument(
            'output', type=str,
            help='''Path to output file''')

        parser.add_argument(
            '--strand', type=str,
            help='''[+/-] .Generate regions on this strand only. Default: both strands''')

        args = parser.parse_args(argv[2:])

        import pybedtools as pbt
        from chess.helpers import read_chromosome_sizes, GenomicRegion

        source = args.genome_or_region
        requested_strand = args.strand

        if requested_strand is not None and requested_strand not in ('-', '+'):
            raise ValueError("--strand must be either - or +")

        try:
            sizes = pbt.chromsizes(source)
            if len(sizes) == 0:
                raise ValueError("Genome not recognised in pybedtools.chromsizes: {}".format(source))
            # pbt stores (start, end) tuples; only the end is used
            sizes = {name: span[1] for name, span in sizes.items()}
            input_regions = [(name, 1, size) for name, size in sizes.items()]
        except ValueError:
            logger.info((
                'No entry found with pybedtools.'
                ' Trying to read from file.'))
            if os.path.exists(source):
                sizes = read_chromosome_sizes(source)
                input_regions = [(name, 1, size) for name, size in sizes.items()]
            else:
                single = GenomicRegion.from_string(source)
                input_regions = [(single.chromosome, single.start, single.end)]

        # forward strand first, then reverse, unless one strand was requested
        strands = ('+', '-') if requested_strand is None else (requested_strand,)
        bed_row = "{}\t{}\t{}\t{}\t.\t{}\n"

        with open(args.output, 'w') as o:
            for chromosome, start, end in input_regions:
                # the name column restarts at 0 for every input region
                region_ix = 0
                for window_start in range(start, end - args.window, args.step):
                    window_end = window_start + args.window
                    for sign in strands:
                        o.write(bed_row.format(
                            chromosome, window_start, window_end,
                            region_ix, sign))
                        region_ix += 1

    def filter(self, argv):
        """Filter ``chess sim`` results and write the survivors as BED files.

        The results table is read with pandas and reduced by every filter
        whose mode and value were supplied (``-p``, ``-z``, ``-zs``, ``-s``,
        ``-n``). Each remaining row is looked up by its ID in the pairs file
        and written as a BED line
        ``chrom start end ID score strand``, sorted by chromosome and start.
        Output goes to ``<output>_REF`` and ``<output>_QRY``, or only to
        ``<output>_genome_scan`` when ``--genome-scan`` is set.

        :param argv: Argument list; the first two entries are skipped and the
                     remainder is parsed as the ``filter`` arguments.
        :raises ValueError: If the score column or a filtered column is
                            missing from the results, if a threshold is not
                            a number, or if a filter mode is not one of
                            ``geq``, ``leq``, ``g``, ``l``.
        """
        parser = MyParser(
            description='''
            Filter results of `chess sim` by p_bg, z_bg, z_ssim, ssim, SN value.

            You can filter by any combination of the `-p``,
            `-z` (z-score), `-s` (ssim) and `-n` (SN) flags.
            Each flag takes two arguments, where the first has to be either
            `geq` (greater-equal), `leq` (less-equal), `l` (less)
            or `g` (greater) and the second the threshold value of the filter.

            Example:

                if you want to extract regions that yielded
                a z-value greater or equal 2 in the comparison, you should run
                `chess filter <chess_results> <chess_pairs> <output_file> -z geq 2`

            If you don't specify any filters, this will simply convert
            the output to BED format.

            This produces two output BED files, one for the REF and one
            for the QRY, unless --genome-scan is specified.
            The <name> column of the BED file (4th) is used for the pair ID.
            ''',
            formatter_class=argparse.RawTextHelpFormatter)
        parser.add_argument(
            'chess_sim_results',
            type=str,
            help='''Tab seperated file (no header) with an <ID> column any of
            <p>, <z-score>, <ssim>, <SN> columns.''')
        parser.add_argument(
            'pairs',
            type=str,
            help='''Region pairs file used as an input to chess sim
            in the run that generated the chess_sim_results.''')
        parser.add_argument(
            'output',
            type=str,
            help='''Path to output file.''')
        parser.add_argument(
            '-p',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <p_bg> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-z',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <z_bg> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-s',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <ssim> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-n',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <SN> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '-zs',
            nargs=2,
            default=[None, None],
            metavar=('mode', 'value'),
            help='''Filter by values in <z_ssim> column according to mode (geq, leq, l, or g)
            and value.''')
        parser.add_argument(
            '--score',
            default='ssim',
            help='''Input file column to use for the <score> column
            in the output BED.''')
        parser.add_argument(
            '--genome-scan',
            action='store_true',
            default=False,
            help='''Write only one output BED file.''')
        args = parser.parse_args(argv[2:])

        import operator
        import pandas as pd
        from chess.helpers import load_pairs
        from operator import itemgetter

        results = pd.read_csv(args.chess_sim_results, sep='\t')

        if args.score not in results.columns:
            # the two message parts used to be joined without a space
            raise ValueError(('{} column chosen for the output '
                              'score not found in input results file.').format(
                              args.score))

        # results column -> (mode, value) as given on the command line
        filters = {
            'p_bg': args.p,
            'z_bg': args.z,
            'z_ssim': args.zs,
            'ssim': args.s,
            'SN': args.n
        }

        # filter mode -> comparison applied as <column> <op> <value>
        comparisons = {
            'geq': operator.ge,
            'leq': operator.le,
            'g': operator.gt,
            'l': operator.lt,
        }

        for column, (mode, value) in filters.items():
            if mode is None or value is None:
                continue
            if column not in results.columns:
                raise ValueError(('{} column not in results file.'.format(
                    column)))
            value = float(value)
            if mode not in comparisons:
                raise ValueError(
                    ('Specified mode {} for {} filter is not valid.'
                     ' mode must be `geq`, `leq`, `l` or `g`.').format(
                                  mode, column))
            results = results[comparisons[mode](results[column], value)]

        pairs = load_pairs(args.pairs)

        beds = {
            'REF': [],
            'QRY': []}

        for index, row in results.iterrows():
            ID, curr_score = str(int(row.ID)), row[args.score]
            refreg, qryreg = pairs[ID]
            if args.genome_scan:
                p = [[refreg, 'REF']]
            else:
                p = [[refreg, 'REF'], [qryreg, 'QRY']]
            for reg, k in p:
                beds[k].append([reg.chromosome, reg.start, reg.end,
                                ID, curr_score, reg.strand])

        if args.genome_scan:
            beds = {'genome_scan': beds['REF']}

        for k, v in beds.items():
            v.sort(key=itemgetter(0, 1))
            with open(args.output + '_' + k, 'w') as f:
                for line in v:
                    # The newline is appended after joining; joining it as a
                    # field left a trailing tab (an empty 7th column) on
                    # every BED line.
                    f.write('\t'.join(str(e) for e in line) + '\n')

    def extract(self, argv):
        """Run the ``extract`` sub-command.

        Parses the sub-command arguments from ``argv[2:]``, loads and
        indexes the reference and query regions, loads both sparse
        matrices, reads the region pairs and passes everything to
        ``extract_structures`` together with the output path.

        Args:
            argv: Full argument vector; the first two entries are skipped
                before parsing.
        """
        parser = MyParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description='''
            Extract the specific regions that are different 
            between the regions identified by CHESS.''')

        # (name, help) for every positional argument, in command-line order.
        positionals = (
            ('pairs',
             '''Region pairs that have been identified to 
            be significantly different. Must be in bedpe format with chrom1, start1, ...
            corresponding to reference and chrom2, start2, ... to query.'''),
            ('reference_sparse',
             '''Balanced Hi-C matrix in sparse format
                    for reference sample.
                    (each line:
                        <row region index> <column region index> <matrix value>)'''),
            ('reference_regions',
             '''BED file (no header) with regions corresponding to
                    the number of rows in the provided reference matrix.'''),
            ('query_sparse',
             '''Balanced Hi-C matrix in sparse format
                    for query sample.
                    (each line:
                        <row region index> <column region index> <matrix value>)'''),
            ('query_regions',
             '''BED file (no header) with regions corresponding to
                    the number of rows in the provided query matrix.'''),
            ('out',
             '''Path to outfile.'''),
        )
        for arg_name, help_text in positionals:
            parser.add_argument(arg_name, type=str, help=help_text)

        args = parser.parse_args(argv[2:])

        # Imported only after the arguments have been parsed.
        from chess.helpers import load_pairs_iter, load_regions, \
            region_interval_trees, edges_dict
        from chess.get_structures import extract_structures, get_sparse_matrix

        logger.debug('[MAIN]: Parameters:')
        logger.debug(args)

        logger.info('[MAIN]: Loading and indexing Hi-C regions')
        # Only the first element returned by load_regions is used here.
        ref_regions, _, _ = load_regions(args.reference_regions)
        ref_trees = region_interval_trees(ref_regions)
        qry_regions, _, _ = load_regions(args.query_regions)
        qry_trees = region_interval_trees(qry_regions)

        logger.info('[MAIN]: Loading Hi-C contacts')
        _, ref_contacts = get_sparse_matrix(
            args.reference_regions, args.reference_sparse)
        ref_edges = edges_dict(ref_contacts)
        _, qry_contacts = get_sparse_matrix(
            args.query_regions, args.query_sparse)
        qry_edges = edges_dict(qry_contacts)

        logger.info('[MAIN]: Loading region pairs')

        region_pairs = list(load_pairs_iter(args.pairs))
        logger.info(
            '[MAIN]: Applying image filtering to identify specific structures')
        extract_structures(
            ref_edges, ref_trees,
            qry_edges, qry_trees,
            region_pairs, args.out)
        logger.info('[MAIN]: Results collected')

    def crosscorrelate(self, argv):
        """Run the ``crosscorrelate`` sub-command.

        Parses the sub-command arguments from ``argv[2:]``, reads the region
        pairs and calls ``correlate2d`` with the extracted-structures file,
        the output path and the pairs.

        Args:
            argv: Full argument vector; the first two entries are skipped
                before parsing.
        """
        parser = MyParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description='''
            2D crosscorrelation of the specific substructures extracted from the
            CHESS significant different regions''')

        # (name, help) for every positional argument, in command-line order.
        positionals = (
            ('extracted_file',
             '''Output from extract sub-command.'''),
            ('pairs',
             '''Region pairs that have been identified to 
            be significantly different. Must be in bedpe format with chrom1, start1, ...
            corresponding to reference and chrom2, start2, ... to query.'''),
            ('out',
             '''Path to outfile.'''),
        )
        for arg_name, help_text in positionals:
            parser.add_argument(arg_name, type=str, help=help_text)

        args = parser.parse_args(argv[2:])

        # Imported only after the arguments have been parsed.
        from chess.helpers import load_pairs_iter
        from chess.cross_correlation import correlate2d

        region_pairs = list(load_pairs_iter(args.pairs))
        correlate2d(args.extracted_file, args.out, region_pairs)


def chess_parser():
    """Build the top-level ``chess`` argument parser.

    The parser handles the global options (``--version``,
    ``--verbose``/``-v`` and ``-l``/``--log-file``) and a single optional
    positional ``command`` naming the sub-command.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    # Dedent so the usage message is not shifted by the source indentation.
    usage_text = textwrap.dedent('''\
        chess <command> [options]

        Commands:
            sim             Calculate structural similarity
            oe              Transform a Hi-C matrix to an observed/expected matrix
            pairs           Make window pairs for chess genome scan
            background   Generate background region BED files
            filter          Filter output of chess sim and save as BED file
            extract         Extract specific regions that are significantly different
            crosscorrelate  Get structural clusters from the extracted submatrices

        Run chess <command> -h for help on a specific command.
        ''')

    top_parser = argparse.ArgumentParser(
        usage=usage_text,
        description="""
        CHESS: Compare Hi-C Experiments using Structural Similarity""",
    )

    # --version is a plain boolean switch, off by default.
    top_parser.add_argument(
        '--version',
        action='store_true',
        dest='print_version',
        help='''Print version information''',
    )
    top_parser.set_defaults(print_version=False)

    # Each -v occurrence increments the counter stored in ``verbosity``.
    top_parser.add_argument(
        '--verbose', '-v',
        action='count',
        default=0,
        dest='verbosity',
        help='''Set verbosity level: Can be chained like
        '-vvv' to increase verbosity. Default is to show
        errors, warnings, and info messages (same as '-vv').
        '-v' shows only errors and warnings, '-vvv' shows errors, warnings,
        info, and debug messages in addition.''',
    )

    top_parser.add_argument(
        '-l', '--log-file',
        dest='log_file',
        help='''Path to file in which to save log.''',
    )

    # The sub-command is optional at this level (nargs='?').
    top_parser.add_argument('command', nargs='?', help='Subcommand to run')

    return top_parser


class MyParser(argparse.ArgumentParser):
    """ArgumentParser that prints the full help text on a parse error."""

    def error(self, message):
        """Write the error to stderr, print the help, and exit with status 2.

        Args:
            message: Description of the parsing problem.

        Raises:
            SystemExit: Always, with exit code 2.
        """
        error_line = 'error: %s\n' % message
        sys.stderr.write(error_line)
        self.print_help()
        raise SystemExit(2)


# Script entry point: the command line is handled inside Chess.__init__,
# so constructing the object is all that is needed when run directly.
if __name__ == '__main__':
    Chess()
