import pickle, os

from b2bTools.singleSeq.DynaMine.Predictor import DynaMine
from b2bTools.general.Io import B2bIo

class EFoldMine(B2bIo):

  # Version string of this predictor; reported under 'version' in informationPerPredictor.
  version = '2.0'
  # scriptName = "b2bTools.singleSeq.EFoldMine.Predictor"
  # Identifier of this script; reported under 'origin' in informationPerPredictor.
  scriptName = "python.b2bTools.singleSeq.EFoldMine.Predictor"
  # Predictor name; used as the key of informationPerPredictor.
  name = "EFoldMine"

  # pickleProbabilisticModelFile = os.path.abspath(os.path.join(os.path.dirname(__file__),'models',"efModelRBF2016.proba.cPickle"))
  # Absolute path of the pickled probabilistic model, located in the 'models' directory next to this file.
  pickleProbabilisticModelFile = os.path.abspath(os.path.join(os.path.dirname(__file__),'models',"efoldmine_converted.cPickle"))

  # Probability threshold used by discretePreds: values strictly above it are classed as 1, all others as 0.
  discreteCutoff = 0.196 # or 0.192?

  def __init__(self, dynaMine=None):
    self.allPredictions = {}
    self.window = 2

    # print("Reading the early folding probabilistic model...")
    self.model = pickle.load(open(self.pickleProbabilisticModelFile,'rb'),encoding='latin1')
    self.dynaMine = dynaMine if dynaMine else DynaMine()

    # Additional info for writing files
    self.references = ['doi: 10.1038/ncomms3741 (2013)', 'doi: 10.1093/nar/gku270 (2014)', 'doi: 10.1038/s41598-017-08366-3 (2017)']
    self.infoTexts = ['Generated by EFoldMine','Based on the DynaMine project','See http://bio2byte.be']

    self.informationPerPredictor = {
      self.name: {
          'references': self.references,
          'info':    ";".join(self.infoTexts),
          'version': self.version,
          'origin':  self.scriptName
                }
          }

  def predictSeqs(self,seqs,dynaMinePreds=None, includeDiscreteClass=False):

    """
    :param seqs: List/tuple of (seqId,sequenceString) tuples
    :param dynaMinePreds: DynaMine predictions matching info in seqs
          in self.allPredictions[seqId][predictionType] = [(aminoAcidTypeString,predValue),...] format
    :keyword includeDiscreteClass: If set to True, prediction output will include a class prediction
             (1 for early folding, 0 for not early folding)
    :return: True if all sequences predicted, False if problems

    See self.allPredictions for the predictions themselves. Will include DynaMine values.
    """

    # print("Start predictions...")

    # These can be fed in if available, if not then will run these automatically
    if not dynaMinePreds:

      #
      # DynaMine predictions if necessary
      #

      self.dynaMine.predictSeqs(seqs)
      self.allPredictions = self.dynaMine.allPredictions

    else:
      self.allPredictions = dynaMinePreds

    # Now do early folding predictions
    seqDoneCount = 0
    for (seqId,sequence) in seqs:

      if len(sequence) < 7:
        print(("Sequence with ID {} too short, ignoring".format(seqId)))
        continue

      #
      # Separate variable where the value of the DynaMine backbone preds is 1 - necessary for the early folding prediction!
      #

      dynaMineBackboneNormalised = self.allPredictions[seqId]['backbone'][:] # Make a copy, will be shifted to maximum value of 1, might mess up other code

      maxValue = max([values[1] for values in dynaMineBackboneNormalised])
      correction = 1 - maxValue
      for seqIndex in range(len(dynaMineBackboneNormalised)):
        dynaMineBackboneNormalised[seqIndex] = (dynaMineBackboneNormalised[seqIndex][0],dynaMineBackboneNormalised[seqIndex][1] + correction)

      #
      # Run the early folding prediction
      #

      x = self.buildVectors((sequence,
                             dynaMineBackboneNormalised,
                             self.allPredictions[seqId]["helix"],
                             self.allPredictions[seqId]["sheet"],
                             self.allPredictions[seqId]["coil"],
                             self.allPredictions[seqId]["sidechain"]),
                             self.window)

      yp = self.model.predict_proba(x)[:,1]

      #
      # Add to prediction variable
      #

      assert len(yp) == len(sequence)

      # Adding to DynaMine predictions for consistency
      self.allPredictions[seqId]['earlyFolding'] = []

      for i in range(len(sequence)):
        if includeDiscreteClass:
          predTuple = (sequence[i], yp[i], self.discretePreds(yp[i]))
        else:
          predTuple = (sequence[i], yp[i])

        self.allPredictions[seqId]['earlyFolding'].append(predTuple)

      seqDoneCount += 1

    allDone = False
    if len(seqs) == seqDoneCount:
      allDone = True

    return allDone

  def discretePreds(self,v):

    if v > self.discreteCutoff:
      return 1
    return 0

  def buildVectors(self,feats, window):
    x = []
    y = []
     #          			 0		1		2		3	4		5		  6	  7		8		9			10
    tmp = feats#db[name] = (seq, backbone, helix, sheet, coil, sidechain, s2, rsa, ground, espritzNmr, espritzXray)
    i = 0
    while i < len(tmp[0]):
      #tx, ty = getSingleFeats(tmp, i)
      # 1 backbone, 2 helix,  3 sheet,  4 coil, 5 sidechain
      tx = self.getWindowFeats(tmp, 1, i, window)+ self.getWindowFeats(tmp, 2, i, window)+self.getWindowFeats(tmp, 3, i, window)+self.getWindowFeats(tmp, 4, i, window)+self.getWindowFeats(tmp, 5, i, window)
      assert len(tx) == 25
      x.append(tx)
      i += 1
    return x

  def getWindowFeats(self,data, featureNum, pos, w):	#feature indicates : 1 backbone, 2 helix,  3 sheet,  4 coil, 5 sidechain
    # 1 backbone, 2 helix,  3 sheet,  4 coil, 5 sidechain
    chosenFeat = data[featureNum][max(pos-w, 0):min(pos+w+1,len(data[1]))]
    chosenFeatW = []
    if pos-w < 0:
      chosenFeatW = [-1] * -(pos-w)
    for i in chosenFeat:
      chosenFeatW.append(round(i[1],3))
    if pos+w+1 > len(data[featureNum]):
      chosenFeatW += [-1] * (pos+w+1 - len(data[featureNum]))
    assert len(chosenFeatW) == (w*2 +1)
    #print chosenFeatW
    return chosenFeatW
