#!/usr/bin/env python
import argparse
import logging
import re

import fetch_db
import numpy as np
from gene_pred_ext import GenePredExt


class SpineD(object):
    '''
    Summarise SpineD per-residue structure predictions over a protein region.

    Per-transcript predictions are fetched from ``db`` and reduced by
    ``cal_spined`` to three fixed-length feature sequences: 12 disorder
    features, 12 secondary-structure features and 3 ASA features.
    '''

    # Residues scoring strictly below this are counted as structured and
    # strictly above as disordered; scores exactly on it (or NaN) belong to
    # neither group.
    DISORDER_THRESHOLD = 0.5

    def __init__(self, db):
        '''
        :db: handle exposing ``query_structure(transcript_id)``, which returns a
            dict of SpineD predictions for the transcript, or a falsy value
            when none are available.
        '''
        self.db = db
        self.logger = logging.getLogger(__name__)

    def query_structure_info(self, transcript_id, pstart, pend, structure_info):
        '''
        Slice every per-residue track of ``structure_info`` to ``[pstart:pend]``.

        :transcript_id: unused; kept for interface compatibility.
        :pstart: slice start of the region (inclusive).
        :pend: slice end of the region (exclusive).
        :structure_info: dict with keys 'transcript_id', 'aa', 'beta_sheet',
            'random_coil', 'alpha_helix', 'asa' and 'disorder'.
        :return: dict with the sliced tracks, the transcript id and
            'length' (= pend - pstart), or an empty dict when the sliced
            'disorder' track is empty.
        '''
        query_info = {'transcript_id': structure_info['transcript_id']}
        for key in ('aa', 'beta_sheet', 'random_coil', 'alpha_helix', 'asa', 'disorder'):
            query_info[key] = structure_info[key][pstart:pend]
        query_info['length'] = pend - pstart
        # len() instead of truthiness: the truth value of a multi-element
        # numpy array is ambiguous and would raise.
        if len(query_info['disorder']) == 0:
            return {}
        return query_info

    def cal_spined(self, transcript_id, pstart, pend):
        '''
        Compute disorder, secondary-structure and ASA features for a region.

        :transcript_id: transcript whose predictions are looked up in ``db``.
        :pstart: slice start of the region, or None when unknown.
        :pend: slice end of the region, or None when unknown.
        :return: tuple ``(disorder, ss, asa)`` of lengths 12, 12 and 3. Each
            is a list of 'NA' strings when a coordinate is missing, the
            transcript has no structure data, or the sliced region is empty.
        '''
        structure_info = self.db.query_structure(transcript_id)
        disorder = ['NA'] * 12
        ss = ['NA'] * 12
        asa = ['NA'] * 3
        # Explicit None checks: 0 is a valid slice start (the first residue)
        # and must not be mistaken for a missing coordinate.
        if pstart is not None and pend is not None and structure_info:
            query_result = self.query_structure_info(transcript_id, pstart, pend, structure_info)
            if query_result:
                disorder = self._cal_disorder(query_result)
                ss = self._cal_ss(query_result)
                asa = self._cal_asa(query_result)
            else:
                self.logger.debug('region may only contain stop codon? %s %s-%s',
                                  transcript_id, pstart, pend)
        return disorder, ss, asa

    def _cal_disorder(self, structure_info):
        '''
        Summarise the per-residue 'disorder' track of a sliced region.

        :return: 12-tuple of
            min / max / mean disorder score,
            mean score over structured residues, mean over disordered residues,
            number of structured<->disordered switches,
            min / max / mean length of disordered runs,
            min / max / mean length of structured runs.
            Means and lengths of an empty group are reported as 0.0.
        '''
        threshold = self.DISORDER_THRESHOLD
        disorder = np.array(structure_info['disorder'])
        structured_mask = disorder < threshold
        disordered_mask = disorder > threshold
        min_disorder = min(disorder)
        max_disorder = max(disorder)
        mean_disorder = np.mean(disorder)
        mean_disorder_structured_region = (np.mean(disorder[structured_mask])
                                           if structured_mask.any() else 0.0)
        mean_disorder_disorder_region = (np.mean(disorder[disordered_mask])
                                         if disordered_mask.any() else 0.0)

        disorder_len = []
        structured_len = []
        current_disorder_len = 0.0
        current_structured_len = 0.0
        switch_num = 0.0
        # Run-length scan. Scores exactly on the threshold neither extend nor
        # break the current run.
        for d in disorder:
            if d < threshold:
                if current_disorder_len > 0:
                    disorder_len.append(current_disorder_len)
                    current_disorder_len = 0.0
                    switch_num += 1
                current_structured_len += 1
            elif d > threshold:
                if current_structured_len > 0:
                    structured_len.append(current_structured_len)
                    current_structured_len = 0.0
                    switch_num += 1
                current_disorder_len += 1
        if current_disorder_len > 0:
            disorder_len.append(current_disorder_len)
        if current_structured_len > 0:
            structured_len.append(current_structured_len)
        # Every switch closes one run, so runs == switches + 1. The exception
        # is a region with no classified residue at all, which has zero runs
        # and zero switches.
        run_count = len(disorder_len) + len(structured_len)
        assert switch_num == max(run_count - 1, 0)
        min_disorder_len = min(disorder_len) if disorder_len else 0.0
        max_disorder_len = max(disorder_len) if disorder_len else 0.0
        mean_disorder_len = np.mean(disorder_len) if disorder_len else 0.0
        min_structured_len = min(structured_len) if structured_len else 0.0
        max_structured_len = max(structured_len) if structured_len else 0.0
        mean_structured_len = np.mean(structured_len) if structured_len else 0.0
        return (min_disorder,
                max_disorder,
                mean_disorder,
                mean_disorder_structured_region,
                mean_disorder_disorder_region,
                switch_num,
                min_disorder_len,
                max_disorder_len,
                mean_disorder_len,
                min_structured_len,
                max_structured_len,
                mean_structured_len)

    def _cal_ss(self, structure_info):
        '''
        Summarise the three secondary-structure score tracks of a region.

        :return: 12-tuple of min / max / mean of the per-residue maximum over
            (alpha_helix, beta_sheet, random_coil), followed by min / max / mean
            of alpha_helix, beta_sheet and random_coil in that order.
        '''
        alpha_helix = np.array(structure_info['alpha_helix'])
        beta_sheet = np.array(structure_info['beta_sheet'])
        random_coil = np.array(structure_info['random_coil'])
        ss_combine = np.vstack((alpha_helix, beta_sheet, random_coil))
        assert ss_combine.shape[0] == 3
        # Score of the highest-scoring class at each residue.
        predicted_ss = np.max(ss_combine, axis=0)
        assert len(predicted_ss) == len(alpha_helix)
        return (min(predicted_ss),
                max(predicted_ss),
                np.mean(predicted_ss),
                min(alpha_helix),
                max(alpha_helix),
                np.mean(alpha_helix),
                min(beta_sheet),
                max(beta_sheet),
                np.mean(beta_sheet),
                min(random_coil),
                max(random_coil),
                np.mean(random_coil))

    def _cal_asa(self, structure_info):
        '''
        Summarise the 'asa' track of a region.

        :return: (min, max, mean) of the ASA values.
        '''
        asa = np.array(structure_info['asa'])
        return min(asa), max(asa), np.mean(asa)


def main():
    '''
    Append SpineD structure features to every row of a bedtools-closest output.

    Each input row is echoed to stdout as tab-separated columns, followed by
    12 disorder, 12 secondary-structure and 3 ASA features. Rows whose
    transcript cannot be identified, or is absent from the GenePred table,
    get 'NA' placeholders instead.
    '''
    parser = argparse.ArgumentParser(description='''
            Parse GenePred table (Extended) and extract features.''')
    parser.add_argument('ensembldb',
            help='ensembl sqlite db file, containing protein structure info.')
    parser.add_argument('gfname',
            help='GenePred table (Extended) file name, from UCSC table browser.')
    parser.add_argument('bfname',
            help='bedtools closest region distances output')
    args = parser.parse_args()

    gene_pred = GenePredExt(args.gfname)

    db = fetch_db.DB(args.ensembldb)
    spined = SpineD(db)
    # Same pattern for every row, so compile it once.
    transcript_re = re.compile(r'(\w+)_exon')
    with open(args.bfname) as f:
        for line in f:
            cols = line.rstrip().split('\t')
            estart = int(cols[8])
            eend = int(cols[9])
            # Reset on every row. Otherwise a row for an unknown transcript
            # would reuse the previous row's features, or hit an unbound name
            # when it is the first row.
            disorder, ss, asa = ['NA'] * 12, ['NA'] * 12, ['NA'] * 3
            match = transcript_re.search(cols[10])
            # A name column without an '<id>_exon' prefix (e.g. a placeholder)
            # is treated like an unknown transcript instead of crashing.
            transcript_id = match.group(1) if match else None
            if transcript_id is not None and transcript_id in gene_pred.transcripts:
                pstart, pend = gene_pred.get_protein_coord(transcript_id, estart, eend)
                disorder, ss, asa = spined.cal_spined(transcript_id, pstart, pend)
            print('\t'.join(map(str, cols + list(disorder) + list(ss) + list(asa))))

if __name__ == '__main__':
    # Script entry point only; importing this module does not run the CLI.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
