#!/usr/bin/env python

# Created on Mon Nov 03 19:32:33 2014

# Author: XiaoTao Wang
# Organization: HuaZhong Agricultural University

## Required modules
from __future__ import division
import argparse, sys, logging, logging.handlers, tadlib

# Version string of the installed tadlib package; reported by -v/--version.
currentVersion = tadlib.__version__

def getargs():
    """Build the CALFEA command-line parser and parse ``sys.argv``.

    When the program is started without any argument, ``-h`` is substituted
    so that the help text is printed (argparse then exits).

    Returns
    -------
    args : argparse.Namespace
        Parsed options with the attributes ``output``, ``path``,
        ``tad_file``, ``pw``, ``ww``, ``top``, ``ratio`` and ``gap``.
    commands : list of str
        The argument list that was actually handed to the parser.

    Raises
    ------
    SystemExit
        Raised by argparse for ``-h``/``-v``, for malformed options, and when
        the mandatory ``-O/--output`` or ``-t/--tad-file`` option is missing.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(usage = '%(prog)s [options]\n\n'
                                     'CALFEA -- CALculate FEAture for TADs',
                                    description = '''The calculation is implemented by identifying
                                    long-range significant interactions for each TAD and looking for
                                    the aggregation patterns thereof. For more details, please refer
                                    to online TADLib documentation.''',
                                    formatter_class = argparse.ArgumentDefaultsHelpFormatter)

    # Version
    parser.add_argument('-v', '--version', action = 'version',
                        version = ' '.join(['%(prog)s', currentVersion]),
                        help = 'Print version number and exit')

    # Output
    # Mandatory: without it the pipeline would only fail much later, when it
    # tries to open ``None`` for writing.
    parser.add_argument('-O', '--output', required = True,
                        help = 'Output file name.')

    ## Argument Groups
    group_1 = parser.add_argument_group(title = 'Relate to the input:')
    group_1.add_argument('-p', '--path', default = '.',
                         help = 'Path to the cool URI')
    # Mandatory for the same reason as --output.
    group_1.add_argument('-t', '--tad-file', required = True,
                         help = 'Path to the TAD file.')

    ## About the algorithm
    group_3 = parser.add_argument_group(title = 'Feature calculation')
    group_3.add_argument('--pw', type = int, default = 2, help = '''Width of the interaction
                         region surrounding the peak. According to experience, we set it
                         to 1 at 20 kb, 2 at 10 kb, and 4 at 5 kb.''')
    group_3.add_argument('--ww', type = int, default = 5, help = '''Width of the donut region.
                         Set it to 3 at 20 kb, 5 at 10 kb, and 7 at 5 kb.''')
    group_3.add_argument('--top', type = float, default = 0.7, help = 'Parameter for noisy '
                         'interaction filtering. By default, 30 percent noisy interactions'
                         ' will be eliminated.')
    group_3.add_argument('--ratio', type = float, default = 0.05, help = 'Specifies the sample'
                         ' ratio of significant interactions for TAD.')
    group_3.add_argument('--gap', type = float, default = 0.2, help = 'Maximum gap ratio.')

    ## Parse the command-line arguments
    # The slice is a copy, so appending '-h' never mutates sys.argv itself.
    commands = sys.argv[1:]
    if not commands:
        commands.append('-h')
    args = parser.parse_args(commands)

    return args, commands

## Pipeline
def pipe():
    """The Main pipeline for CALFEA.

    Parses the command line, configures logging, loads the Hi-C data
    (``args.path``) and the TAD list (``args.tad_file``), and writes one
    tab-separated line per TAD to ``args.output`` with the columns
    ``ChromID``, ``Start``, ``End``, ``AP`` and ``Gap-Ratio``.

    TADs whose gap ratio is not below ``args.gap``, or whose filtered matrix
    is not larger than ``2 * args.ww + 1``, are reported with an ``AP`` of
    ``-1``.

    The output file is managed by a ``with`` block, so it is closed even if
    the calculation for some TAD raises.
    """
    # Parse Arguments
    args, commands = getargs()
    # argparse normally exits on its own for help/version requests; the guard
    # is kept so that nothing below runs in that case.
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    logger = _setup_logging()

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# output file name = {0}'.format(args.output),
               '# Hi-C path = {0}'.format(args.path),
               '# TAD source file = {0}'.format(args.tad_file),
               '# Peak window width = {0}'.format(args.pw),
               '# Donut width = {0}'.format(args.ww),
               '# Noise filtering ratio = {0}'.format((1 - args.top)),
               '# Significant interaction ratio = {0}'.format(args.ratio),
               '# Maximum gap ratio = {0}'.format(args.gap)
               ]
    logger.info('\n' + '\n'.join(arglist))

    # Imported here rather than at module level, so they are only loaded
    # when the pipeline really runs.
    from tadlib.calfea import analyze
    import cooler
    import numpy as np

    logger.info('Read Hi-C data ...')
    workInters = cooler.Cooler(args.path)
    # Load External TAD File, Columns 0,1,2
    logger.info('Read external TAD data ...')
    workTAD = analyze.load_TAD(args.tad_file)

    # Loop-invariant: use balanced values only when the bin table carries a
    # 'weight' column, and build the matrix selector a single time.
    balanced = bool('weight' in workInters.bins().keys())
    selector = workInters.matrix(balance=balanced, sparse=False)

    # Header
    H = ['ChromID', 'Start', 'End', 'AP', 'Gap-Ratio']
    with open(args.output, 'w') as outstream:
        outstream.write('\t'.join(H) + '\n')
        logger.info('Calculate feature for each TAD ...')
        for c in workInters.chromnames:
            logger.info('Chromosome %s', c)
            chromsize = workInters.chromsizes[c]
            for d in workTAD[workTAD['chr'] == c]:
                # Interaction Matrix (the end is clipped to the chromosome size)
                matrix = selector.fetch((c, d['start'], min(d['end'], chromsize)))
                matrix[np.isnan(matrix)] = 0

                # NOTE(review): manipulation() presumably drops vacant
                # rows/columns -- confirm against tadlib.calfea.analyze.
                newM, _ = analyze.manipulation(matrix)
                if len(matrix) > 0:
                    # Ratio of Gaps (Vacant Rows or Columns)
                    gaps = 1 - len(newM) / len(matrix)
                else:
                    gaps = 1.0

                if (gaps < args.gap) and (newM.shape[0] > (args.ww * 2 + 1)):
                    workCore = analyze.Core(matrix)
                    # Extract Long-Range Interactions
                    workCore.longrange(pw = args.pw, ww = args.ww, top = args.top, ratio = args.ratio)
                    # Clusters
                    workCore.DBSCAN()
                    # Feature
                    workCore.gdensity()
                    score = str(workCore.gden)
                else:
                    # Bad Domain!
                    score = '-1'
                # Line by Line
                wL = [c, str(d['start']), str(d['end']), score, str(gaps)]
                outstream.write('\t'.join(wL) + '\n')
        logger.info('Done!')
        logger.info('Write results to %s ...', args.output)
    # Leaving the with-block has flushed and closed the output file.
    logger.info('Done!\n')


def _setup_logging():
    """Attach an INFO console handler and a DEBUG rotating-file handler
    ('calfea.log') to the root logger and return that logger."""
    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(10)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler('calfea.log',
                                                       maxBytes = 30000,
                                                       backupCount = 5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('DEBUG')
    ## Unified Formatter
    formatter = logging.Formatter(fmt = '%(name)-14s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)
    return logger

# Run the CALFEA pipeline only when this file is executed as a script.
if __name__ == '__main__':
    pipe()
