#!/usr/bin/env python
import argparse
import logging
import os.path

import pandas as pd
from sklearn.externals import joblib


class Predictor(object):
    '''
    Predict the disease-causing probability of SNVs with two sets of models.

    SNVs are split by their ``distance`` column into an "on splicing site"
    group (``ON_SS_MIN_DISTANCE <= distance <= ON_SS_MAX_DISTANCE`` and
    ``distance != 0``) and an "off splicing site" group (everything outside
    that window). Each group is run through its own transformer and
    classifier. SNVs whose distance is 0 or missing belong to neither group
    and are left out of the output (a warning reports how many).

    NOTE(review): ``joblib`` comes from ``sklearn.externals`` at file level,
    which newer scikit-learn releases no longer ship -- confirm the pinned
    scikit-learn version before upgrading.
    '''

    # Inclusive bounds of the "on splicing site" window (distance 0 excluded).
    ON_SS_MIN_DISTANCE = -13
    ON_SS_MAX_DISTANCE = 7
    # Columns removed from the feature matrix for both groups.
    ID_COLUMNS = ('#chrom', 'pos', 'ref', 'alt', 'name', 'strand', 'distance')
    # Columns additionally removed for the off-splicing-site group.
    OFF_SS_EXTRA_COLUMNS = ('aic_change', 'dic_change')
    # Column position at which the result columns are inserted.
    RESULT_COLUMN_LOC = 4

    def __init__(self, on_ss_imp, off_ss_imp, on_ss_clf, off_ss_clf):
        '''
        Load the four model files.
        :on_ss_imp: path to the transformer for on-splicing-site features
                    (presumably an imputer, judging by the name).
        :off_ss_imp: path to the transformer for off-splicing-site features.
        :on_ss_clf: path to the classifier for on-splicing-site SNVs.
        :off_ss_clf: path to the classifier for off-splicing-site SNVs.
        '''
        # NOTE(security): joblib.load unpickles the files, which can execute
        # arbitrary code -- only load model files from a trusted source.
        self.on_ss_imp = joblib.load(on_ss_imp)
        self.off_ss_imp = joblib.load(off_ss_imp)
        self.on_ss_clf = joblib.load(on_ss_clf)
        self.off_ss_clf = joblib.load(off_ss_clf)
        self.logger = logging.getLogger(__name__)

    def predict(self, ifname, ofname):
        '''
        Predict disease-causing probability of all SNVs in a feature file.
        :ifname: tab-separated input file with a header row. It must contain
                 the columns ``#chrom``, ``pos`` and ``distance`` plus every
                 column dropped by the group-specific prediction methods.
        :ofname: tab-separated output file. It holds the input columns with
                 ``disease``, ``prob`` and ``splicing_site`` inserted starting
                 at column position 4, sorted by ``#chrom`` and ``pos``;
                 missing values are written as ``NA``.
        '''
        self.logger.info('Loading all the features.')
        data = pd.read_csv(ifname, sep='\t', header=0)

        distance = data['distance']
        on_mask = ((distance >= self.ON_SS_MIN_DISTANCE) &
                   (distance <= self.ON_SS_MAX_DISTANCE) &
                   (distance != 0))
        off_mask = ((distance < self.ON_SS_MIN_DISTANCE) |
                    (distance > self.ON_SS_MAX_DISTANCE))

        # Rows with distance 0 (or a missing distance, for which every
        # comparison is False) match neither mask; report them instead of
        # dropping them silently.
        n_skipped = int((~(on_mask | off_mask)).sum())
        if n_skipped:
            self.logger.warning(
                'Skipping %d SNV(s) whose distance is 0 or missing.',
                n_skipped)

        data_on_ss = self._annotate(
            data[on_mask], 'on', self._predict_on_ss,
            'Predicting on splicing site SNVs.')
        data_off_ss = self._annotate(
            data[off_mask], 'off', self._predict_off_ss,
            'Predicting off splicing site SNVs')

        self.logger.info('Generate final output.')
        result = pd.concat([data_on_ss, data_off_ss])
        result = result.sort_values(by=['#chrom', 'pos'])
        result.to_csv(ofname, sep='\t', na_rep='NA', index=False)

    def _annotate(self, subset, label, predict_fn, message):
        '''
        Return a copy of ``subset`` with the three result columns added.
        :subset: data frame holding the SNVs of one group (may be empty).
        :label: value written to ``splicing_site`` for every row.
        :predict_fn: callable returning ``(preds, scores)`` for ``subset``.
        :message: progress message logged before predicting.
        '''
        # Work on a copy so the caller's frame is never modified.
        subset = subset.copy()
        if subset.empty:
            # Nothing to predict; still add the columns so both groups share
            # the same layout when they are concatenated.
            site, prob, disease = '', 0.0, 0
        else:
            self.logger.info(message)
            preds, scores = predict_fn(subset)
            site = [label] * len(preds)
            # Second predict_proba column; presumably the probability of the
            # disease-causing class -- confirm against classifier.classes_.
            prob = scores[:, 1]
            disease = preds
        # Each insert pushes the earlier ones to the right, so the final
        # order is: disease, prob, splicing_site.
        subset.insert(self.RESULT_COLUMN_LOC, 'splicing_site', site)
        subset.insert(self.RESULT_COLUMN_LOC, 'prob', prob)
        subset.insert(self.RESULT_COLUMN_LOC, 'disease', disease)
        return subset

    @staticmethod
    def _score(data, dropped_columns, transformer, classifier):
        '''
        Drop non-feature columns, transform the rest and run the classifier.
        :data: m x n pandas data frame. m is the number of SNVs.
        :dropped_columns: names of the columns that are not features.
        :transformer: object whose ``transform`` prepares the features.
        :classifier: object providing ``predict`` and ``predict_proba``.
        Returns the tuple ``(preds, scores)``.
        '''
        features = data.drop(list(dropped_columns), axis=1)
        features = transformer.transform(features)
        preds = classifier.predict(features)
        scores = classifier.predict_proba(features)
        return preds, scores

    def _predict_on_ss(self, data):
        '''
        Predict disease-causing probability of on-splicing-site SNVs
        :data: m x n pandas data frame. m is the number of SNVs, n is the number of features.
        Returns ``(preds, scores)`` from the on-splicing-site classifier.
        '''
        return self._score(data, self.ID_COLUMNS,
                           self.on_ss_imp, self.on_ss_clf)

    def _predict_off_ss(self, data):
        '''
        Predict disease-causing probability of off-splicing-site SNVs
        :data: m x n pandas data frame. m is the number of SNVs, n is the number of features.
        Returns ``(preds, scores)`` from the off-splicing-site classifier.
        '''
        # The off-site models do not use aic_change / dic_change.
        return self._score(data, self.ID_COLUMNS + self.OFF_SS_EXTRA_COLUMNS,
                           self.off_ss_imp, self.off_ss_clf)


def main():
    '''
    Command-line entry point.

    Parses the model directory, the SNV feature file and the output file
    from the command line, loads the four model files (``on_ss_imp.pkl``,
    ``off_ss_imp.pkl``, ``on_ss_clf.pkl``, ``off_ss_clf.pkl``) from the model
    directory and writes the predictions to the output file.
    '''
    parser = argparse.ArgumentParser(description='''
            Given model files and SNV features generated by feature calculator,
            predict the disease-causing probability.''')
    parser.add_argument('db_path',
                        help='path to models')
    parser.add_argument('sfname',
                        help='SNV file with features generated by feature calculator')
    parser.add_argument('ofname',
                        help='output file')
    args = parser.parse_args()

    # The predictor reports progress at INFO level; without any logging
    # configuration those messages are dropped (the default threshold is
    # WARNING). basicConfig does nothing if handlers are already installed.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(name)s: %(message)s')

    # Order matches the Predictor constructor:
    # on_ss_imp, off_ss_imp, on_ss_clf, off_ss_clf.
    model_files = ('on_ss_imp.pkl', 'off_ss_imp.pkl',
                   'on_ss_clf.pkl', 'off_ss_clf.pkl')
    model_paths = [os.path.join(args.db_path, fname) for fname in model_files]
    predictor = Predictor(*model_paths)
    predictor.predict(args.sfname, args.ofname)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
