#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Sun May 22 14:37:37 2016

@author: wxt

"""
import argparse, sys, os, logging, logging.handlers, traceback, tadlib

# Version string of the installed tadlib package; shown by the -v/--version option.
currentVersion = tadlib.__version__

def datasets_convert(metadata):
    """
    Parse a Hi-C metadata file into a nested dictionary.

    The file is read line by line. A line starting with ``res`` opens a new
    resolution section (``res:<integer>``). Every following non-blank line is
    read as ``<label>:<uri>`` and attached to the most recent section. The
    URI is everything after the first colon, so it may itself contain colons.
    Whitespace around the label and the URI is ignored.

    Parameters
    ----------
    metadata : str
        Path to the metadata file.

    Returns
    -------
    dict
        ``{resolution (int): {label (str): uri (str)}}``. Each URI is passed
        through ``os.path.expanduser`` and ``os.path.abspath``.

    Raises
    ------
    ValueError
        If a ``res`` line has no colon or a non-integer value, if a dataset
        line has no colon or an empty URI, or if a dataset line appears
        before any ``res`` line. ``ValueError`` is used so that argparse,
        which calls this function as a ``type=`` converter, reports a usage
        error instead of an unhandled traceback.
    """
    datasets = {}
    current = None  # entries of the resolution section currently being filled
    with open(metadata, 'r') as source:
        for lineno, line in enumerate(source, 1):
            if line.isspace():
                continue

            if line.startswith('res'):
                fields = line.rstrip().split(':')
                if len(fields) < 2:
                    raise ValueError('{0}, line {1}: expected "res:<integer>"'.format(
                        metadata, lineno))
                res = int(fields[1])
                # A repeated resolution header starts a fresh section.
                current = {}
                datasets[res] = current
                continue

            # Split on the first colon only: cooler URIs contain "::".
            label, sep, uri = line.strip().partition(':')
            label = label.strip()
            uri = uri.strip()
            if not sep or not uri:
                raise ValueError('{0}, line {1}: expected "<label>:<uri>"'.format(
                    metadata, lineno))
            if current is None:
                raise ValueError('{0}, line {1}: dataset listed before any '
                                 '"res:<integer>" line'.format(metadata, lineno))
            current[label] = os.path.abspath(os.path.expanduser(uri))

    return datasets

def getargs():
    """
    Build the command-line parser and parse ``sys.argv[1:]``.

    When the command line is empty, ``-h`` is substituted so that the help
    message is printed; argparse then exits the process. ``-O/--output`` and
    ``-d/--datasets`` are mandatory, as the usage string states, so a missing
    one is reported by argparse right away rather than failing later in the
    pipeline. ``-h`` and ``-v`` still work on their own because their actions
    exit while the command line is being parsed.

    Returns
    -------
    args : argparse.Namespace
        Parsed option values. ``args.datasets`` is the dictionary produced by
        ``datasets_convert``.
    commands : list of str
        The argument list that was handed to the parser.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(usage = '%(prog)s <-d datasets -O output> [options]',
                                     description = '''A highly sensitive and reproducible
                                     hierarchical domain caller.''',
                                     formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    
    parser.add_argument('-v', '--version', action = 'version',
                        version = ' '.join(['%(prog)s', currentVersion]),
                        help = 'Print version number and exit')

    # Mandatory input / output
    parser.add_argument('-O', '--output', required = True, help = 'Output file name.')
    parser.add_argument('-d', '--datasets', type = datasets_convert, required = True,
                        help = 'Metadata file describing Hi-C datasets. Refer to our '
                        'online documentation for more details.')

    # Optional tuning parameters
    parser.add_argument('-W', '--weight-col', default='weight',
                        help = '''Name of the column in .cool to be used to construct the
                        normalized matrix. Specify "-W RAW" if you want to run with the raw matrix.''')
    parser.add_argument('--exclude', nargs = '*', default = ['chrY','chrM'],
                        help = '''List of chromosomes to exclude.''')
    parser.add_argument('--minimum-chrom-size', default=1000000, type=int,
                        help='''Minimum chromosome size. Only chromosomes with a size greater than
                        this value will be considered.''')
    parser.add_argument('--maxsize', default=4000000, type=int, help = 'Maximum domain size in '
                        'base-pair unit.')
    parser.add_argument('-D', '--DI-col', default='DIs',
                        help = '''Name of the column in .cool to output the DI track''')
    parser.add_argument('-p', '--cpu-core', type = int, default = 1,
                        help = 'Number of processes to launch.')
    parser.add_argument('--removeCache', action = 'store_true',
                        help = '''Remove cache data before exiting.''')
    
    parser.add_argument('--logFile', default = 'hitad.log', help = '''Logging file name.''')
    
    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No arguments at all: show the help text instead of a usage error.
        commands.append('-h')
    args = parser.parse_args(commands)
    
    return args, commands

def _setup_root_logger(logfile):
    """
    Configure the root logger with a console and a rotating file handler.

    The console handler emits records at INFO and above; the file handler
    writes records at DEBUG and above to *logfile*, rotating at 200000 bytes
    and keeping 5 backups. Both handlers share one format.

    Returns the configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(logfile,
                                                       maxBytes = 200000,
                                                       backupCount = 5)
    console.setLevel('INFO')
    filehandler.setLevel('DEBUG')

    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')
    for handler in (console, filehandler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def run():
    """
    Command-line entry point: parse arguments, set up logging and call domains.

    The steps are: build a ``Genome`` from the parsed datasets, learn the HMM
    parameters, call hierarchical domains, write them to the output file, and
    wipe the cache if ``--removeCache`` was given.

    Any ``Exception`` raised along the way is logged with its traceback
    through the root logger, so it reaches both the console and the log file,
    and the process exits with status 1. ``KeyboardInterrupt`` and
    ``SystemExit`` are deliberately not intercepted.
    """
    # Parse Arguments
    args, commands = getargs()
    # Nothing to run when only help or version information was requested
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    logger = _setup_root_logger(args.logFile)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Output file name = {0}'.format(args.output),
               '# Hi-C datasets = {0}'.format(args.datasets),
               '# Column for matrix balancing = {0}'.format(args.weight_col),
               '# Excluded chromosomes = {0}'.format(args.exclude),
               '# Minimum Chromosome Size = {0}'.format(args.minimum_chrom_size),
               '# Maximum domain size = {0}'.format(args.maxsize),
               '# Column for DI track = {0}'.format(args.DI_col),
               '# Number of processes used = {0}'.format(args.cpu_core),
               '# Remove cache data = {0}'.format(args.removeCache),
               '# Log file name = {0}'.format(args.logFile)
               ]

    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Imported here rather than at module level so that it is only loaded
    # once the arguments have been parsed.
    from tadlib.hitad.genomeLev import Genome

    try:
        logger.info('Parsing Hi-C datasets ...')
        cachefolder = os.path.abspath(os.path.expanduser('.hitad'))
        G = Genome(args.datasets, balance_type=args.weight_col, maxsize=args.maxsize,
                   cache=cachefolder, exclude=args.exclude, DIcol=args.DI_col,
                   min_chrom_size=args.minimum_chrom_size)
        logger.info('Done!')
        logger.debug('Learning HMM parameters for each dataset ...')
        G.learning(cpu_core=args.cpu_core)
        logger.info('Identifying hierarchical domains ...')
        G.callHierDomain(cpu_core=args.cpu_core)
        logger.info('Done!')
        logger.info('Output domains ...')
        G.outputDomain(args.output)
        logger.info('Done!\n')
        if args.removeCache:
            G.wipeDisk()
    except Exception:
        # Report through the logger instead of opening the log file a second
        # time: the rotating handler already owns that file.
        logger.exception('Domain calling failed with the following error:')
        sys.exit(1)

# Start the command-line workflow only when this file is executed as a script.
if __name__ == '__main__':
    run()