#! /bin/env python
'''
Created on Oct 30, 2013

@author: Shunping Huang
'''

from __future__ import print_function

import os
import sys
import gc
import pysam
import argparse as ap
import logging
import multiprocessing as mp
import time

from modtools.alias import Alias
from modtools.mod import Mod
from modtools.utils import *

# from lapels.matefixer import *
from lapels.annotator import Annotator
from lapels.matefixer import fixmate
from lapels.version import __version__


# Program description shown in the command-line help.
DESC = "A remapper and annotator of in silico (pseudo) genome alignments."
# Default verbosity level (not referenced by the functions in this file).
VERBOSITY = 1
# Version of this script; reported by -V together with PKG_VERSION.
VERSION = '0.2.0'
# Version of the Lapels package; also written to the PG header of the output.
PKG_VERSION = __version__

# Shared state, filled in by run() and read by annotate():
# reference name in the input BAM -> value of the third idxstats column.
nReadsInChroms = dict()
# Header dict used when writing the per-chromosome output BAM files.
outHeader = None

# Root logger, set by init_logger().
logger = None
# Parsed command-line arguments, set by parse_arguments().
args = None
# Alias table for chromosome names, set by run().
alias = None


def init_logger():
    """Bind the module-level ``logger`` to the root logger and configure it.

    The root logger's level is set to DEBUG and a stream handler using the
    layout ``[time] name: LEVEL: message`` is attached to it.  The handler
    itself gets no level, so the logger level alone decides what is shown.
    """
    global logger

    # Build the fully configured console handler first ...
    console = logging.StreamHandler()
    console.setFormatter(
        logging.Formatter(
            "[%(asctime)s] %(name)-10s: %(levelname)s: %(message)s",
            "%Y-%m-%d %H:%M:%S"))

    # ... then hook it onto the root logger and publish that logger.
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    root.addHandler(console)
    logger = root


def parse_arguments():
    """Parse the command line into the module-level ``args`` and validate it.

    Besides parsing, this function
      * sets the logger level from ``-q`` / ``-v``;
      * checks that the MOD file, the input BAM and (if given) the alias
        file are readable;
      * creates an index for the input BAM if ``<in.bam>.bai`` is missing;
      * derives the output file name when ``-o`` is not given and checks
        that it is writable (honouring ``-f``);
      * turns ``args.chroms`` into a list of chromosome names, defaulting to
        all chromosomes of the MOD file in sorted order.

    Raises:
        IOError: if the input BAM index is missing and cannot be created.
        ValueError: if the output file name does not end with '.bam'.
    """
    global args
    # Usage:
    # pylapels [-V][-q|-v][-f][-a alias.csv][-t][--nofixmate][-n]
    #          [-p nProcesses][-c chromList][-o out.bam] in.mod in.bam

    p = ap.ArgumentParser(description=DESC,
                          formatter_class=ap.RawTextHelpFormatter)
    p.add_argument('-V', '--version', action='version', version='%(prog)s' +
                   ' %s in Lapels %s' % (VERSION, PKG_VERSION))
    group = p.add_mutually_exclusive_group()
    group.add_argument("-q", dest='quiet', action='store_true',
                       help='quiet mode')
    group.add_argument('-v', dest='verbosity', action="store_const", const=2,
                       default=1, help="verbose mode")
    p.add_argument("-f", dest='force', action='store_true',
                   help='overwrite existing output (bam)')
    p.add_argument('-a', metavar='alias.csv', dest='alias_fn',
                   default=None,
                   help='the csv file for alias classes of sequence name'
                   ' (default: None)')
    p.add_argument('-t', dest='keepTemp', action='store_true',
                   help="keep temporary files (default: no)")
    p.add_argument('--nofixmate', dest='noFixMate', action='store_true',
                   help="disable mate fixing (default: no)")
    p.add_argument('-n', dest='sortByName', action='store_true',
                   help='output bam file sorted by read names (default: no)')
    p.add_argument('-p', metavar='nProcesses', dest='nProcesses',
                   type=int, default=1,
                   help='number of processes to run (default: 1)')
    p.add_argument('-c', metavar='chromList', dest='chroms', default=None,
                   help='a comma-separated list of chromosomes in output'
                   ', e.g. 1,2,3 (default: all)')
    p.add_argument('-o', dest='bamout_fn', metavar='out.bam', default=None,
                   help='the output bam file (default: in.annotated.bam)')

    p.add_argument('mod_fn', metavar='in.mod',
                   help='the mod file of the in silico genome')
    p.add_argument('bamin_fn', metavar='in.bam',
                   help='the input bam file')

    args = p.parse_args()

    # -q and -v are mutually exclusive; plain runs log at INFO.
    if args.quiet:
        logger.setLevel(logging.CRITICAL)
    elif args.verbosity == 2:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    logger.debug(args)

    is_file_readable(args.mod_fn)
    is_file_readable(args.bamin_fn)

    # Check if input bam index exists.
    if not os.path.isfile(args.bamin_fn + '.bai'):
        logger.info("creating index for input BAM file")
        pysam.index(args.bamin_fn)
        if os.path.isfile(args.bamin_fn + '.bai'):
            logger.info("index created")
        else:
            # Not inside an exception handler, so there is no traceback to
            # attach: use error() rather than exception().
            logger.error("index failed")
            raise IOError("Failed to create index. " +
                          "Make sure the bam file is sorted by position.")

    if args.bamout_fn is None:
        if args.bamin_fn.endswith('.bam'):
            args.bamout_fn = args.bamin_fn[:-4] + '.annotated.bam'
        else:
            args.bamout_fn = args.bamin_fn + '.annotated.bam'
    elif not args.bamout_fn.endswith('.bam'):
        # run() strips the last four characters to build the sort prefix.
        raise ValueError("Output bam file should have '.bam' suffix")
    is_file_writable(args.bamout_fn, args.force)

    mod = Mod(args.mod_fn)

    if args.chroms is None:
        args.chroms = sorted(mod.all_chroms)
    else:
        # Tolerate blanks around the commas and ignore empty items, so that
        # "1, 2," is read as ['1', '2'].
        args.chroms = [name.strip() for name in args.chroms.split(',')
                       if name.strip()]

    if args.alias_fn is not None:
        is_file_readable(args.alias_fn)


def annotate(bam_fn, mod_fn, mergePool, chrom, lock=None):
    """Annotate the reads of one chromosome and register the result.

    The reads of ``chrom`` are fetched from ``bam_fn`` and handed to an
    ``Annotator`` that writes to ``<out>.<chrom>.unsorted.bam``.  When
    ``Annotator.execute()`` returns a positive value, that file is sorted by
    read name into ``<out>.<chrom>.sorted.bam`` and the sorted file name is
    appended to ``mergePool``.  The unsorted file is removed afterwards.

    Relies on the module-level ``alias``, ``nReadsInChroms``, ``outHeader``
    and ``args`` having been set up by ``run()``.

    Args:
        bam_fn: file name of the input BAM.
        mod_fn: file name of the MOD file.
        mergePool: list (or list proxy) collecting the sorted file names.
        chrom: chromosome name as requested on the command line.
        lock: optional lock serialising log output between processes.

    Returns:
        None.  If no alias of the chromosome is present in the BAM, a
        warning is logged and nothing is produced.

    Raises:
        ValueError: if ``chrom`` has no known name in the alias table.
    """
    def log_info(msg, *msg_args):
        # Emit one INFO record, holding the logging lock if there is one.
        if lock:
            lock.acquire()
        try:
            logger.info(msg, *msg_args)
        finally:
            if lock:
                lock.release()

    # Garbage collection is switched off while the chromosome is processed
    # (presumably for speed); the finally clause guarantees it is switched
    # back on for every exit path, including the early return below.
    gc.disable()
    try:
        bam = pysam.Samfile(bam_fn, 'rb')
        try:
            mod = Mod(mod_fn, checkMode=False)

            # Choose the appropriate chromosome name for the MOD
            modChrom = alias.getName(chrom)
            if modChrom is None:
                raise ValueError("Unable to determine chromosome '%s' "
                                 "in MOD." % chrom)
            log_info("alias '%s' used for '%s' in MOD", modChrom, chrom)

            mod.load(modChrom)
            chromLength = mod.refmeta[modChrom].length

            # Choose the appropriate chromosome name for the BAM
            for bamChrom in alias.getAliases(modChrom):
                if bamChrom in nReadsInChroms:
                    log_info("alias '%s' used for '%s' in BAM",
                             bamChrom, chrom)
                    break
            else:
                logger.warning("Unable to determine chromosome '%s' in BAM."
                               % chrom)
                return

            bam_iter = bam.fetch(bamChrom.encode("utf-8"))
            nReads = nReadsInChroms[bamChrom]
            unsorted_fn = "%s.%s.unsorted.bam" % (args.bamout_fn, chrom)
            unsorted = pysam.Samfile(unsorted_fn, 'wb', header=outHeader)
            try:
                a = Annotator(mod, bam_iter, modChrom, chromLength, nReads,
                              unsorted, lock)
                ret = a.execute()
            finally:
                unsorted.close()
        finally:
            bam.close()

        if ret > 0:
            # Sort by read name, required by fixmate
            log_info("sorting reads in '%s' by names ...", chrom)
            sorted_fn = "%s.%s.sorted.bam" % (args.bamout_fn, chrom)
            # The sort output is given as a prefix, hence the stripped
            # '.bam' suffix.
            pysam.sort('-n', unsorted_fn.encode("utf-8"),
                       sorted_fn[:-4].encode("utf-8"))
            mergePool.append(sorted_fn)

        os.remove(unsorted_fn)
    finally:
        gc.enable()


def worker(workerId, bam_fn, mod_fn, mergePool, queue, qhead, qlock, llock):
    """Annotate chromosomes taken from a shared queue until none are left.

    Each worker repeatedly claims the next index from ``qhead`` and calls
    ``annotate`` on the chromosome at that index.

    Args:
        workerId: number used to identify this worker in log messages.
        bam_fn: file name of the input BAM, passed on to ``annotate``.
        mod_fn: file name of the MOD file, passed on to ``annotate``.
        mergePool: list (or list proxy) passed on to ``annotate``.
        queue: sequence of chromosome names to process.
        qhead: object whose ``value`` attribute is the next index to claim.
        qlock: lock guarding ``qhead``, or a false value for no locking.
        llock: lock guarding log output, or a false value for no locking.
    """
    total = len(queue)

    def claim_next():
        # Take the current head index and advance the head by one.
        if qlock:
            qlock.acquire()
        claimed = qhead.value
        qhead.value += 1
        if qlock:
            qlock.release()
        return claimed

    idx = claim_next()
    while idx < total:
        chrom = queue[idx]

        if llock:
            llock.acquire()
        logger.info("workder %d processes chromosome '%s'" %
                    (workerId, chrom))
        if llock:
            llock.release()

        annotate(bam_fn, mod_fn, mergePool, chrom, llock)
        idx = claim_next()


def run():
    """Run the whole remap/annotate pipeline described by ``args``.

    Steps, in order:
      1. load the alias table and the MOD file;
      2. read per-reference counts from the input BAM index statistics;
      3. build the output header: a new PG entry, SQ lengths taken from the
         MOD reference metadata, and the HD sort order;
      4. annotate every requested chromosome, in worker processes when
         ``-p`` is greater than 1 (and ``-v`` is off), otherwise in this
         process;
      5. merge the per-chromosome name-sorted files;
      6. fix mates unless ``--nofixmate`` is given, then either sort by
         position and index the result, or keep the name-sorted file.

    Sets the module-level ``alias``, ``nReadsInChroms`` and ``outHeader``,
    which ``annotate`` reads.

    Raises:
        ValueError: if a reference of the input BAM has no length in the MOD.
        RuntimeError: if the worker processes cannot be started, or if no
            chromosome produced any output to merge.
    """
    global alias
    global nReadsInChroms
    global outHeader

    logger.info("input MOD file: %s" % args.mod_fn)
    logger.info("input BAM file: %s" % args.bamin_fn)
    logger.info("output BAM file: %s" % args.bamout_fn)

    alias = Alias()
    try:
        alias.load(args.alias_fn)
    except Exception as err:
        # Loading stays best effort (there may be no alias file at all), but
        # a failure on a file the user explicitly supplied is reported.
        if args.alias_fn is not None:
            logger.warning("unable to load alias file '%s': %s",
                           args.alias_fn, err)

    mod = Mod(args.mod_fn)

    # Get the number of reads in each chromosome
    nReadsInChroms = dict()
    for idxstat in pysam.idxstats(args.bamin_fn):
        tup = idxstat.rstrip('\n').split('\t')
        nReadsInChroms[tup[0]] = int(tup[2])

    inFile = pysam.Samfile(args.bamin_fn, 'rb')
    try:
        outHeader = dict(inFile.header.items())
        references = inFile.references
    finally:
        inFile.close()

    # Prepend a PG entry to the header of output bam, chained (PP) to the
    # previously first program entry if there is one.
    prevPG = outHeader.get('PG')
    if prevPG:
        outHeader['PG'] = [{'ID': 'Lapels', 'VN': PKG_VERSION,
                            'PP': prevPG[0]['ID'],
                            'CL': ' '.join(sys.argv)}] + prevPG
    else:
        outHeader['PG'] = [{'ID': 'Lapels', 'VN': PKG_VERSION,
                            'CL': ' '.join(sys.argv)}]

    # Reset lengths in the header to reference coordinate system.
    outHeader['SQ'] = []
    for sn in references:
        chrom = alias.getName(sn)
        try:
            length = mod.refmeta[chrom].length  # Length in reference
        except KeyError:
            raise ValueError("Unable to determine length of chromosome '%s'"
                             " from MOD." % sn)
        outHeader['SQ'].append({'SN': sn, 'LN': length})

    # Correct order of reads.  The SAM specification spells the name order
    # 'queryname'.
    outHeader['HD'] = outHeader.get('HD', {})
    if args.sortByName:
        outHeader['HD']['SO'] = 'queryname'
    else:
        outHeader['HD']['SO'] = 'coordinate'
    outHeader['HD']['VN'] = '1.0'

    logger.debug(outHeader)

    # Annotate using multiple processes or a single process
    nProcesses = args.nProcesses
    if args.verbosity == 2:
        logger.warning("force to use a single process in verbose mode")
        nProcesses = 1

    if nProcesses > 1:
        logger.info("use multiple processes: %d", nProcesses)
        llock = mp.Lock()  # lock for logger
        qlock = mp.Lock()  # lock for queue

        manager = mp.Manager()
        sharedPool = manager.list()
        qhead = mp.Value('i', 0, lock=False)
        procs = []
        try:
            for i in range(nProcesses):
                p = mp.Process(target=worker,
                               args=(i, args.bamin_fn, args.mod_fn,
                                     sharedPool, args.chroms, qhead, qlock,
                                     llock))
                p.start()
                procs.append(p)
        except Exception:
            raise RuntimeError("Cannot use multiple processes.")

        # Wait for exactly the workers started above, and report any that
        # did not finish cleanly instead of losing its chromosomes silently.
        for p in procs:
            p.join()
            if p.exitcode != 0:
                logger.error("worker process '%s' exited with code %s",
                             p.name, p.exitcode)

        # Take a plain copy so the remaining steps do not go through the
        # manager proxy.
        mergePool = list(sharedPool)
    else:
        logger.info("use a single process")
        mergePool = []
        for chrom in args.chroms:
            logger.info("workder %d processes chromosome '%s'" % (0, chrom))
            annotate(args.bamin_fn, args.mod_fn, mergePool, chrom)

    nMerges = len(mergePool)
    if nMerges == 0:
        raise RuntimeError("No annotated reads in any of the requested "
                           "chromosomes; nothing to merge.")

    merged_fn = args.bamout_fn + '.merged.bam'
    matefixed_fn = args.bamout_fn + '.matefixed.bam'

    if nMerges > 1:
        # Merge
        logger.info("merging %d files ...", nMerges)
        mergeParams = ['-f', '-n', merged_fn.encode("utf-8")] +\
                      [fn.encode("utf-8") for fn in mergePool]
        pysam.merge(*mergeParams)
        if not args.keepTemp:
            for fn in mergePool:
                os.remove(fn)
    else:
        logger.info("merging %d file ...", nMerges)
        os.rename(mergePool[0], merged_fn)

    # Fix mates; current_fn names the file holding the latest result.
    if not args.noFixMate:
        logger.info("fixing mate ...")
        fixmate(merged_fn, matefixed_fn)
        if not args.keepTemp:
            os.remove(merged_fn)
        current_fn = matefixed_fn
    else:
        current_fn = merged_fn

    if not args.sortByName:
        # Sort by position; the output is given as a prefix, hence the
        # stripped '.bam' suffix.
        logger.info("sorting reads by positions ...")
        pysam.sort(current_fn.encode("utf-8"),
                   args.bamout_fn[:-4].encode("utf-8"))
        if not args.keepTemp:
            os.remove(current_fn)

        # Build index for output
        logger.info("creating bam index for output")
        pysam.index(args.bamout_fn.encode("utf-8"))
        if os.path.isfile(args.bamout_fn + '.bai'):
            logger.info("index created")
        else:
            logger.warning("index failed")
    else:
        # The file is already sorted by read name: it becomes the output.
        os.rename(current_fn, args.bamout_fn)

    logger.info("All Done!")
    logging.shutdown()


if __name__ == '__main__':
    # Script entry point.  The order matters: parse_arguments() and run()
    # both use the logger created by init_logger(), and run() reads the
    # arguments stored by parse_arguments().
    init_logger()
    parse_arguments()
    run()
