#!/usr/bin/env python
import argparse
import logging
import re

import fetch_db
from gene_pred_ext import GenePredExt


class Pfam(object):
    '''Compute the fraction of a protein region covered by Pfam domains.

    Pfam records are fetched per transcript from ``db`` and compared against a
    protein-coordinate region. All intervals are treated as closed
    (inclusive), and overlapping domains are merged before measuring so that
    shared residues are counted only once.
    '''

    def __init__(self, db):
        '''
        :db: sqlite database contains Pfam information generated by SpineD.
             Must provide ``query_pfam(transcript_id)`` returning an iterable
             of ``(transcript_id, start, end, family, name, clan)`` records.
        '''
        self.db = db
        self.logger = logging.getLogger(__name__)

    def query_pfam_info(self, transcript_id, pstart, pend, pfam_info):
        '''Return the Pfam records whose span overlaps ``[pstart, pend]``.

        Intervals are closed, so a record touching the region at a single
        residue counts as overlapping.

        :transcript_id: unused here; kept for interface compatibility.
        :pstart, pend: inclusive bounds of the query region.
        :pfam_info: iterable of 6-field Pfam records.
        :returns: list of the overlapping records, in their original order.
        '''
        overlapping = []
        for record in pfam_info:
            # Full unpack also enforces the 6-field record shape.
            _trans_id, start, end, _family, _name, _clan = record
            if start <= pend and end >= pstart:
                overlapping.append(record)
        return overlapping

    def cal_pfam(self, transcript_id, pstart, pend):
        '''Return the fraction of ``[pstart, pend]`` covered by Pfam domains.

        :transcript_id: transcript whose Pfam records are looked up.
        :pstart, pend: inclusive bounds of the protein region; if either is
            falsy (e.g. ``None``) the region cannot be scored.
        :returns: ``'NA'`` when the region is missing; otherwise a float in
            ``[0.0, 1.0]``, which is 0.0 when no Pfam record overlaps.
        '''
        if not (pstart and pend):
            return 'NA'
        pfam_info = self.db.query_pfam(transcript_id)
        if not pfam_info:
            return 0.0
        query_result = self.query_pfam_info(transcript_id, pstart, pend, pfam_info)
        if not query_result:
            self.logger.debug('region does not contain Pfam. %s %s-%s\n',
                              transcript_id, pstart, pend)
            return 0.0

        region_len = pend - pstart + 1
        merged = self._merge_interval(query_result)
        covered = 0.0
        for start, end in merged:
            # Closed intervals: the clipped overlap includes both endpoints,
            # matching the inclusive ``region_len`` denominator.
            covered += min(pend, end) - max(pstart, start) + 1
        # Internal invariant: merged spans are disjoint, so the clipped total
        # can never exceed the region length.
        assert covered <= region_len, '{0}, {1}, {2}, {3}\n'.format(transcript_id, pstart, pend, merged)
        return covered / region_len

    def _merge_interval(self, intervals):
        '''Merge overlapping Pfam spans into disjoint ``[start, end]`` pairs.

        :intervals: Pfam records; only fields 1 (start) and 2 (end) are used.
        :returns: list of ``[start, end]`` lists sorted by start, in which
            spans sharing at least one residue are combined. Empty input
            yields an empty list.
        '''
        # Sorting is required: the sweep below only merges into the last
        # emitted span, which is correct only when starts are non-decreasing.
        spans = sorted([record[1], record[2]] for record in intervals)
        merged = []
        for span in spans:
            if merged and span[0] <= merged[-1][1]:
                merged[-1][1] = max(merged[-1][1], span[1])
            else:
                merged.append(span)
        return merged


def main():
    '''Append a Pfam-coverage column to each row of a bedtools-closest file.

    For every row of ``bfname``, the transcript id is parsed from column 11
    (the text before ``_exon``). The span in columns 9-10 is converted with
    ``GenePredExt.get_protein_coord`` and scored with ``Pfam.cal_pfam``. The
    row is printed tab-separated to stdout with the score appended. Rows whose
    transcript id cannot be parsed, or is absent from the GenePred table,
    get ``'NA'``.
    '''
    parser = argparse.ArgumentParser(description='''
            Parse GenePred table (Extended) and extract features.''')
    parser.add_argument('ensembldb',
            help='ensembl sqlite db file, containing protein Pfam info.')
    parser.add_argument('gfname',
            help='GenePred table (Extended) file name, from UCSC table browser.')
    parser.add_argument('bfname',
            help='bedtools closest region distances output')
    args = parser.parse_args()

    gene_pred = GenePredExt(args.gfname)

    db = fetch_db.DB(args.ensembldb)
    pfam = Pfam(db)
    # Same pattern for every row: compile once outside the loop.
    exon_name = re.compile(r'(\w+)_exon')
    with open(args.bfname) as f:
        for line in f:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no row to annotate.
                continue
            cols = line.rstrip().split('\t')
            # Reset per row so a skipped transcript never reuses the previous
            # row's value (or hits an unbound name on the first row).
            pfam_value = 'NA'
            match = exon_name.search(cols[10])
            if match:
                transcript_id = match.group(1)
                if transcript_id in gene_pred.transcripts:
                    estart = int(cols[8])
                    eend = int(cols[9])
                    pstart, pend = gene_pred.get_protein_coord(transcript_id, estart, eend)
                    pfam_value = pfam.cal_pfam(transcript_id, pstart, pend)
            print('\t'.join(map(str, cols + [pfam_value])))

if __name__ == '__main__':
    main()
