#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import argparse
import logging
from scipy.sparse import lil_matrix
import numpy as np

# Directory in which the default group file ('ref.gene2transcripts.tsv') is
# looked up: the GBRS_DATA environment variable, else the current directory.
DATA_DIR = os.environ.get('GBRS_DATA', '.')


def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with the attributes ``quantfile``, ``strains``,
        ``tidfile``, ``grpfile``, ``outbase`` and ``verbose``.

    ``-t/--tid-file`` is mandatory: the conversion always loads the
    transcript ID list from it, so leaving it out used to fail much later
    inside ``np.loadtxt(None)`` instead of producing a usage message.
    """
    parser = argparse.ArgumentParser(
        description="Convert a quantification result into GBRS quant format "
                    "(<outbase>.isoforms and <outbase>.genes).")
    parser.add_argument("-i", '--quant-file', action='store', dest='quantfile', type=str, required=True,
                        help="Quantification file to convert")
    parser.add_argument('-s', '--strains', action='store', dest='strains', type=str, required=True,
                        help="Comma-separated list of strain suffixes, e.g. A,B,C")
    parser.add_argument('-t', '--tid-file', action='store', dest='tidfile', type=str, required=True,
                        help="File whose first column lists the transcript IDs")
    parser.add_argument('-g', '--group-file', action='store', dest='grpfile', type=str, default=None,
                        help="Tab-delimited file with a gene ID followed by its transcript IDs "
                             "on each line (default: $GBRS_DATA/ref.gene2transcripts.tsv)")
    parser.add_argument("-o", action='store', dest='outbase', type=str, default=None,
                        help="Basename of the output files (default: kallisto.abundance)")
    parser.add_argument("-v", "--verbose", help="Toggle DEBUG verbosity", action="store_true")
    return parser.parse_args()


def _read_quant_table(quantfile, tid, hid):
    """Build the transcript-by-strain abundance matrix from the quant file.

    NOTE(review): the input layout is not defined anywhere in this script.
    This assumes a kallisto ``abundance.tsv``: one header line, then
    tab-separated ``target_id, length, eff_length, est_counts, tpm`` rows,
    with ``target_id`` spelled ``<transcript>_<strain suffix>``. The value
    taken is ``est_counts`` (4th column) -- confirm both assumptions.

    Args:
        quantfile: path of the quantification table.
        tid: dict mapping transcript ID -> row index.
        hid: dict mapping strain suffix -> column index.

    Returns:
        A dense array of shape (len(tid), len(hid) + 1); the last column is
        the per-transcript total over all strains.

    Raises:
        ValueError: if a data line cannot be parsed.
        KeyError: if a line names a transcript or strain that is not in
            ``tid`` / ``hid``.
    """
    num_strains = len(hid)
    tmat = np.zeros((len(tid), num_strains + 1))
    with open(quantfile) as fh:
        next(fh, None)  # skip the header line
        for curline in fh:
            if not curline.strip():
                continue
            item = curline.rstrip("\r\n").split("\t")
            try:
                # Split on the LAST underscore so transcript IDs may
                # themselves contain underscores.
                transcript, suffix = item[0].rsplit("_", 1)
                value = float(item[3])
            except (ValueError, IndexError):
                raise ValueError("Cannot parse quant line: %r" % curline)
            tmat[tid[transcript], hid[suffix]] = value
    tmat[:, num_strains] = tmat[:, :num_strains].sum(axis=1)
    return tmat


def main(args, loglevel):
    """Convert a quantification result into GBRS isoform and gene tables.

    Writes ``<outbase>.isoforms`` (one row per transcript) and
    ``<outbase>.genes`` (transcript rows summed per gene). Both are
    tab-delimited with one column per strain plus a ``Total`` column.

    Args:
        args: namespace with ``quantfile``, ``strains`` (comma-separated),
            ``tidfile``, ``grpfile`` (None -> ``ref.gene2transcripts.tsv``
            under DATA_DIR) and ``outbase`` (None -> 'kallisto.abundance').
        loglevel: level handed to ``logging.basicConfig``.

    Raises:
        ValueError: if ``args.tidfile`` is None or the quant file is
            malformed.
        KeyError: if the group or quant file names an unknown transcript
            (or, for the quant file, an unknown strain suffix).
    """
    logging.basicConfig(format="[gbrs::%(levelname)s] %(message)s", level=loglevel)
    logging.debug("Conversion requested: %s", args.quantfile)

    if args.tidfile is None:
        raise ValueError("A transcript ID file (-t/--tid-file) is required")

    logging.info("Reading in transcript IDs...")
    # atleast_1d: loadtxt collapses a one-line file to a 0-d array.
    tname = np.atleast_1d(np.loadtxt(args.tidfile, usecols=(0,), dtype=str))
    num_transcripts = len(tname)
    tid = dict(zip(tname.tolist(), range(num_transcripts)))

    logging.info("Getting suffices for strains...")
    strains = args.strains.split(',')
    num_strains = len(strains)
    hid = dict(zip(strains, range(num_strains)))

    logging.info("Reading in quantification results...")
    tmat = _read_quant_table(args.quantfile, tid, hid)

    logging.info("Processing group file...")
    if args.grpfile is None:
        grpfile = os.path.join(DATA_DIR, 'ref.gene2transcripts.tsv')
    else:
        grpfile = args.grpfile
    gname = list()
    groups = list()
    with open(grpfile) as fh:
        for curline in fh:
            if not curline.strip():
                continue  # a blank line would otherwise become a gene named ''
            item = curline.rstrip().split("\t")
            gname.append(item[0])
            groups.append([tid[t] for t in item[1:]])
    num_genes = len(gname)
    # Indicator matrix: entry (t, g) is 1 when transcript t belongs to gene g.
    grp_conv_mat = lil_matrix((num_transcripts, num_genes))
    for i, tid_list in enumerate(groups):
        if tid_list:
            grp_conv_mat[tid_list, i] = 1.0
    grp_conv_mat = grp_conv_mat.tocsc()

    logging.info("Exporting to GBRS quant format...")
    if args.outbase is None:
        outbase = 'kallisto.abundance'
    else:
        outbase = args.outbase
    toutfile = outbase + ".isoforms"
    theader = "\t".join(["Transcript_ID"] + strains + ["Total"])
    np.savetxt(toutfile, np.column_stack((tname, np.char.mod('%f', tmat))),
               header=theader, fmt='%s', delimiter='\t')
    # Sum transcript rows into gene rows.
    gmat = grp_conv_mat.transpose().dot(tmat)
    goutfile = outbase + '.genes'
    gheader = "\t".join(["Gene_ID"] + strains + ["Total"])
    np.savetxt(goutfile, np.column_stack((gname, np.char.mod('%f', gmat))),
               header=gheader, fmt='%s', delimiter='\t')

    logging.info("DONE")

if __name__ == '__main__':
    arguments = parse_args()
    # -v/--verbose lowers the logging threshold from INFO to DEBUG.
    log_level = logging.DEBUG if arguments.verbose else logging.INFO
    sys.exit(main(arguments, log_level))
