#!/usr/bin/python
import numpy.oldnumeric as Numeric
import sys, os, os.path, struct, math, string
import copy
import gzip
import types
from math import *
#from PDB7 import *
# Parameter ID used to set the radius of c4d.Osphere primitives in this file
# (``sphere[PRIM_SPHERE_RAD] = radius``); presumably mirrors Cinema 4D's own
# PRIM_SPHERE_RAD description ID -- confirm against the SDK version in use.
PRIM_SPHERE_RAD = 1110


import MolKit
from MolKit.molecule import Atom, AtomSet, BondSet, Molecule , MoleculeSet
from MolKit.protein import Protein, ProteinSet, Residue, Chain, ResidueSet
from MolKit.stringSelector import CompoundStringSelector
from MolKit.tree import TreeNode, TreeNodeSet
from MolKit.molecule import Molecule, Atom
from MolKit.protein import Residue

from types import StringType, ListType

from c4d import gui

import c4d
import c4d.documents
import c4d.plugins
from c4d import plugins
from c4d import tools
from c4d.gui import *
from c4d.plugins import *


# Atom identifiers that lookupDGFunc keeps as-is when colouring with the
# DavidGoodsell palette. Each entry is the atom's parent (residue) type
# followed by the atom name, e.g. 'ASPOD1' = ASP + OD1; 'HN' is matched on the
# atom name alone. Any identifier not in this list falls back to the atom's
# element symbol.
DGatomIds=['ASPOD1','ASPOD2','GLUOE1','GLUOE2', 'SERHG',
                        'THRHG1','TYROH','TYRHH',
                        'LYSNZ','LYSHZ1','LYSHZ2','LYSHZ3','ARGNE','ARGNH1','ARGNH2',
                        'ARGHH11','ARGHH12','ARGHH21','ARGHH22','ARGHE','GLNHE21',
                        'GLNHE22','GLNHE2',
                        'ASNHD2','ASNHD21', 'ASNHD22','HISHD1','HISHE2' ,
                        'CYSHG', 'HN']



def addObjectToScene(doc,obj,parent=None):
    """Insert ``obj`` into the document ``doc``.

    doc    -- document exposing ``insert_object(op=..., parent=...)``.
    obj    -- the object to insert.
    parent -- optional parent object; when given, ``obj`` is inserted as
              its child, otherwise at the top level.
    """
    # ``is not None`` rather than ``!= None``: scene objects may overload
    # comparison, which would make ``!=`` give a misleading answer.
    if parent is not None:
        doc.insert_object(op=obj, parent=parent)
    else:
        doc.insert_object(op=obj)
    # NOTE: undo support (start_undo / add_undo / end_undo) is not wired up.

def rotatePoint(pt,m,ax):
    """Rotate ``pt`` in place about an axis through the origin, then translate.

    pt -- mutable sequence whose first three items (x, y, z) are overwritten.
    m  -- translation (mx, my, mz) added after the rotation.
    ax -- (u, v, w, angle): axis direction and rotation angle in radians.
          The formula is the closed-form axis-angle rotation for a UNIT
          axis; a non-normalised (u, v, w) gives a scaled/skewed result.

    Returns ``pt`` itself, now holding the transformed point.
    """
    x, y, z = pt[0], pt[1], pt[2]
    u, v, w = ax[0], ax[1], ax[2]
    ux, uy, uz = u*x, u*y, u*z
    vx, vy, vz = v*x, v*y, v*z
    wx, wy, wz = w*x, w*y, w*z
    s = sin(ax[3])
    c = cos(ax[3])
    # Projection of the point onto the axis (u*x + v*y + w*z).
    dot = ux+vy+wz
    pt[0] = (u*dot + (x*(v*v+w*w) - u*(vy+wz))*c + (-wy+vz)*s) + m[0]
    pt[1] = (v*dot + (y*(u*u+w*w) - v*(ux+wz))*c + (wx-uz)*s) + m[1]
    pt[2] = (w*dot + (z*(u*u+v*v) - w*(ux+vy))*c + (-vx+uy)*s) + m[2]
    return pt


def dist(A,B):
    """Return the Euclidean distance between points A and B.

    Only the first three components of each point are used.
    """
    dx = A[0] - B[0]
    dy = A[1] - B[1]
    dz = A[2] - B[2]
    return Numeric.sqrt(dx**2 + dy**2 + dz**2)

def norm(A):
    """Return the Euclidean norm of A, i.e. sqrt(sum(A*A)).

    A must support element-wise ``*`` (e.g. a Numeric/numpy array).
    """
    squares = A * A
    return sqrt(sum(squares))

def normsq(A):
    """Return the squared norm of A, computed as abs(sum(A*A)).

    A must support element-wise ``*`` (e.g. a Numeric/numpy array).
    """
    total = sum(A * A)
    return abs(total)

def normalize(A):
    """Return A scaled to unit length.

    A zero-length vector is returned unchanged (the same object) instead of
    dividing by zero.
    """
    # Compute the norm once; the original evaluated it twice and mixed tab
    # and space indentation (a TabError under Python 3).
    length = norm(A)
    if length == 0.0:
        return A
    return A / length




class pybObject():
    """Record bundling a scene object with the atoms it represents.

    Attributes:
        b_obj -- the scene object(s) built for the atoms
        name  -- object name; the colouring helpers also use it as the key
                 into each atom's ``colors`` dict
        Atoms -- the atoms the object represents
    """

    def __init__(self, obj, name, atms):
        self.name = name
        self.Atoms = atms
        self.b_obj = obj

class Surface(pybObject):
    """pybObject for an MSMS molecular surface mesh.

    Besides the pybObject attributes it keeps:
        msmsAtoms -- the atoms the surface was computed for (same as Atoms)
        msmsSurf  -- the MSMS surface object, needed to map atom properties
                     onto the mesh vertices
    """

    def __init__(self, obj, name, atms, srf):
        pybObject.__init__(self, obj=obj, name=name, atms=atms)
        self.msmsSurf = srf
        self.msmsAtoms = atms

# Atomic radii keyed by element letter, stored as STRINGS (callers convert
# with float()). spheresMesh and AtomPrim look up only the first character of
# the atom name (atN[0]), so the two-letter "CA" key is never matched there;
# unknown letters fall back to the "H" radius.
# NOTE(review): values look like van der Waals radii in Angstrom -- confirm.
AtmRadi = {"N":"1.54","C":"1.7","CA":"1.7","O":"1.52","S":"1.85","H":"1.2"}

def computeRadius(protein,center=None):
    """Return the largest distance from ``center`` to any atom of ``protein``.

    protein -- object exposing ``allAtoms`` (each atom with ``_coords[0]``)
               and, when ``center`` is omitted, ``getCenter()``.
    center  -- reference point (list, tuple or array); defaults to
               ``protein.getCenter()``.

    Returns 0.0 when ``protein`` has no atoms.
    """
    # ``is None`` rather than ``== None``: with an array centre ``==`` is
    # element-wise and its truth value is ambiguous (raises in numpy).
    if center is None:
        center = protein.getCenter()
    rs = 0.
    for atom in protein.allAtoms:
        r = dist(center, atom._coords[0])
        if r > rs:
            rs = r
    return rs

def spheresMesh(name,typMes,x,scn,armObj=None,scale=1.0,Res=32,R=None,join=0):
    """Create one Cinema 4D sphere primitive per atom in ``x``.

    name, typMes, scn, armObj, Res, join -- accepted for signature
        compatibility; currently unused.
    x     -- sequence of atoms, each with ``name`` and ``_coords[0]``.
    scale -- radius multiplier; 0.0 is treated as 1.0. The value written to
             the sphere is ``radius * scale * 2``.
    R     -- if not None, used as the radius of every sphere; otherwise the
             radius comes from AtmRadi by the first letter of the atom name,
             with the 'H' entry as fallback.

    Returns the list of sphere objects; they are not inserted into any
    document.
    """
    if scale == 0.0:
        scale = 1.
    scale = scale * 2.
    spher = []
    for at in x:
        atN = at.name
        atC = at._coords[0]
        if R is not None:
            rad = R
        elif atN[0] in AtmRadi:
            rad = AtmRadi[atN[0]]
        else:
            rad = AtmRadi['H']
        me = c4d.BaseObject(c4d.Osphere)
        me[PRIM_SPHERE_RAD] = float(rad) * scale
        me.set_pos(c4d.Vector(float(atC[0]), float(atC[1]), float(atC[2])))
        spher.append(me)
    return spher
	
def AtomPrim(name,typPrim,x,armObj,scale,res=32,R=None,join=0):
    """Create one Cinema 4D primitive per atom of the residues in ``x``.

    name    -- key under which each atom's colour (white) and opacity (1.0)
               are stored in ``atom.colors`` / ``atom.opacities``; also the
               name of the returned pybObject.
    typPrim -- "Cube" or "Sphere". "Mb" (metaballs) is not implemented and
               raises NotImplementedError; any other value raises ValueError.
    x       -- residue set: each residue exposes ``atoms``, and ``x`` itself
               must support ``findType(Atom)``.
    armObj, res, join -- currently unused.
    scale   -- multiplier applied to the atom radius.
    R       -- if not None, radius of every primitive; otherwise looked up in
               AtmRadi by the first letter of the atom name ('H' fallback).

    Returns ``(spher, vdwObj)``: the list of primitives (not inserted into a
    document) and a pybObject wrapping the atoms of ``x``.
    """
    spher = []
    # Never populated here; passed through unchanged as pybObject.b_obj.
    obj = []
    # Materials, armature modifiers and joining from the Blender version are
    # not ported.
    for residue in x:
        for at in residue.atoms:
            atN = at.name
            atC = at._coords[0]
            at.colors[name] = (1., 1., 1.)
            at.opacities[name] = 1.0
            if R is not None:
                rad = R
            elif atN[0] in AtmRadi:
                rad = AtmRadi[atN[0]]
            else:
                rad = AtmRadi['H']
            size = float(rad) * scale
            if typPrim == "Cube":
                me = c4d.BaseObject(c4d.Ocube)
                # Parameter 1100 receives a uniform size vector.
                me[1100] = c4d.Vector(size, size, size)
            elif typPrim == "Sphere":
                me = c4d.BaseObject(c4d.Osphere)
                me[PRIM_SPHERE_RAD] = size
            elif typPrim == "Mb":
                # Previously a bare ``pass`` that left ``me`` unbound and
                # crashed with UnboundLocalError.
                raise NotImplementedError(
                    "AtomPrim: metaball primitives ('Mb') are not supported")
            else:
                raise ValueError(
                    "AtomPrim: unknown primitive type %r" % (typPrim,))
            me.set_pos(c4d.Vector(float(atC[0]), float(atC[1]), float(atC[2])))
            spher.append(me)
    vdwObj = pybObject(obj=obj, name=name, atms=x.findType(Atom))
    return spher, vdwObj

def createsNmesh(name,vertices,vnormals,faces,smooth=False):
    """Build a Cinema 4D polygon object from a triangle mesh.

    name, vnormals, smooth -- currently unused (the object is not renamed and
        no normals or smoothing are applied).
    vertices -- sequence of (x, y, z) points.
    faces    -- sequence of vertex-index rows; the first three indices of
                each row form a triangle.

    Returns the new PolygonObject; it is not inserted into any document.
    """
    PDBgeometry = c4d.PolygonObject(len(vertices), len(faces))
    for k, v in enumerate(vertices):
        PDBgeometry.set_point(k, c4d.Vector(float(v[0]), float(v[1]), float(v[2])))
    # (A stray tab-indented ``print len(faces)`` used to run once per vertex
    # here and is a TabError under Python 3; the debug output is removed.)
    for g, face in enumerate(faces):
        A = int(face[0])
        B = int(face[1])
        C = int(face[2])
        # Triangles are stored as quads whose 4th index repeats the 3rd,
        # the usual Cinema 4D triangle encoding.
        PDBgeometry.set_polygon(id=g, polygon=[A, B, C, C])
    return PDBgeometry

def blenderColor(col):
    """Return ``col`` scaled to the 0-255 range when it looks normalised.

    If the largest component is <= 1.0, a new list with every component
    multiplied by 255 is returned; otherwise ``col`` itself is returned.
    """
    if max(col) <= 1.0:
        return [channel * 255 for channel in col]
    return col


def changeColor(mesh,colors):
    """Placeholder for per-vertex colouring of ``mesh``.

    Not yet ported to Cinema 4D: it prints the number of colours, then each
    colour, and ignores ``mesh``.
    """
    print(len(colors))
    for col in colors:
        print(col)

# Blender implementation kept as a porting reference (previously stored as a
# module-level string literal, so it was never executed):
#   mesh.vertexColors = 1  # enable vertex colors
#   unic = False
#   ncolor = None
#   if len(colors) == 1:
#       unic = True
#       ncolor = blenderColor(colors[0])
#   for f in mesh.faces:
#       for i, v in enumerate(f):
#           col = f.col[i]
#           if not unic: ncolor = blenderColor(colors[v.index])
#           col.r = int(ncolor[0])
#           col.g = int(ncolor[1])
#           col.b = int(ncolor[2])
#   mesh.materials[0].setMode("VColPaint")
#   if unic:
#       mesh.materials[0].R = int(ncolor[0])
#       mesh.materials[0].G = int(ncolor[1])
#       mesh.materials[0].B = int(ncolor[2])

def atomPropToVertices(obj,name,srf,atoms, propName, propIndex=None):#propIndex:surfName
    """Map an atomic property onto the vertices of an MSMS surface.

    obj       -- surface wrapper exposing ``msmsAtoms`` (atoms the surface was
                 computed for, with a ``.data`` list); it may also carry
                 ``apbs_colors`` and ``apbs_dum1`` (see Returns).
    name      -- surface name; key into the property dict when ``propIndex``
                 is given.
    srf       -- MSMS surface object providing ``getTriangles``.
    atoms     -- atoms whose surface is displayed (``len()`` and a ``.data``
                 list); each carries the ``__surfIndex1__`` attribute set
                 when the surface was computed.
    propName  -- name of the atom attribute to map, e.g. 'colors'.
    propIndex -- when not None the attribute is a dict and ``attr[name]`` is
                 used; otherwise the attribute value itself.

    Returns None if ``atoms`` is empty. Otherwise a float32 array holding,
    for every vertex returned by getTriangles, the property of its closest
    atom. If ``obj`` has ``apbs_colors``, a list is returned instead: walking
    ``obj.apbs_dum1`` in order, each entry that matches the next surface
    vertex contributes that vertex's mapped value, except that a mapped
    value of (1.5, 1.5, 1.5) is replaced by the RGB of the corresponding
    ``apbs_colors`` entry.
    """
    if len(atoms) == 0:
        return None

    geomC = obj
    surfName = name
    surf = srf
    surfNum = 1
    # Property value of every atom the surface was computed for.
    prop = []
    if propIndex is not None:
        for a in geomC.msmsAtoms.data:
            d = getattr(a, propName)
            prop.append(d[surfName])
    else:
        for a in geomC.msmsAtoms.data:
            prop.append(getattr(a, propName))
    # Surface indices of the atoms whose surface is displayed.
    indName = '__surfIndex%d__' % surfNum
    atomIndices = [getattr(a, indName) for a in atoms.data]
    # vi[:, 1] holds the closest-atom index for each vertex; presumably
    # 1-based (MSMS numbering), hence the -1 below.
    dum1, vi, dum2 = surf.getTriangles(atomIndices, keepOriginalIndices=1)
    # axis=0 selects whole rows (the oldnumeric default); plain numpy.take
    # would flatten multi-component properties such as colours.
    mappedProp = Numeric.take(prop, vi[:, 1] - 1, axis=0).astype('f')
    if hasattr(geomC, 'apbs_colors'):
        # Read the APBS data from the object checked by hasattr (the name
        # ``geom`` used here before was never defined -> NameError).
        colors = []
        for i in range(len(geomC.apbs_dum1)):
            # Stop once every surface vertex has been consumed; indexing
            # dum1[0] would otherwise raise IndexError.
            if len(dum1) == 0:
                break
            ch = geomC.apbs_dum1[i] == dum1[0]
            if 0 not in ch:
                tmp_prop = mappedProp[0]
                mappedProp = mappedProp[1:]
                dum1 = dum1[1:]
                if (tmp_prop[0] == [1.5]) \
                   and (tmp_prop[1] == [1.5]) \
                   and (tmp_prop[2] == [1.5]):
                    colors.append(geomC.apbs_colors[i][:3])
                else:
                    colors.append(tmp_prop)
        mappedProp = colors
    return mappedProp


def msms(nodes, surfName='MSMS-MOL', pRadius=1.5, density=1.0,
         perMol=True, display=True, hdensity=6.0):
    """Compute an MSMS molecular surface and build a polygon mesh for it.

    Required Arguments:
        nodes    --- current selection; must expose ``top.uniq()``.
        surfName --- surface name, used as the key in each atom's
                     ``colors``/``opacities`` dicts and as the mesh name.
                     If empty, ``<molecule name>-MSMS`` is used.
    Optional Arguments:
        pRadius  --- probe radius (1.5)
        density  --- triangle density to represent the surface (1.0)
        perMol   --- only True is supported: a surface is computed for each
                     molecule having at least one node in the selection.
        display  --- currently unused.
        hdensity --- vertex density for high-density regions (6.0); no atom
                     is flagged for high density here.

    Returns:
        None if ``nodes`` is empty; 'ERROR' if ``density`` or ``pRadius`` is
        not a non-negative int/float; otherwise ``(ob, surface)`` for the
        LAST molecule processed, where ``ob`` is the mesh from createsNmesh
        and ``surface`` a Surface wrapping it ((None, None) if no molecule).

    Raises:
        NotImplementedError: if ``perMol`` is False.
    """
    if nodes is None or not nodes:
        return
    # Exact int/float types only (bool is rejected), and non-negative.
    if type(density) not in (int, float) or density < 0:
        return 'ERROR'
    if type(pRadius) not in (int, float) or pRadius < 0:
        return 'ERROR'
    if not perMol:
        # The per-selection variant was never implemented; fail clearly
        # instead of reaching an undefined ``molecules`` below.
        raise NotImplementedError('msms: perMol=False is not supported')

    # Imported after validation so bad input is reported even without mslib.
    from mslib import MSMS

    molecules = nodes.top.uniq()
    atmSets = [mol.allAtoms for mol in molecules]

    ob = surface = None
    indName = '__surfIndex%d__' % 1
    for mol, atms in zip(molecules, atmSets):
        if not surfName:
            surfName = mol.name + '-MSMS'
        for a in mol.allAtoms:
            a.colors[surfName] = (1., 1., 1.)
            a.opacities[surfName] = 1.0
        surfFlags = []
        hdFlags = []
        atmRadii = []
        for i, a in enumerate(atms):
            # Atom indices are 1-based in MSMS; atomPropToVertices hands this
            # stored index back to getTriangles.
            setattr(a, indName, i + 1)
            surfFlags.append(1)
            hdFlags.append(0)
            atmRadii.append(a.vdwRadius)
        # Build an MSMS object and compute the surface.
        srf = MSMS(coords=atms.coords, radii=atmRadii, surfflags=surfFlags,
                   hdflags=hdFlags)
        srf.compute(probe_radius=pRadius, density=density,
                    hdensity=hdensity)
        vf, vi, f = srf.getTriangles()
        vertices = vf[:, :3]
        vnormals = vf[:, 3:6]
        faces = f[:, :3]
        ob = createsNmesh(surfName, vertices, vnormals, faces)
        surface = Surface(ob, surfName, mol.allAtoms, srf)
    return ob, surface

def expandNodes(nodes,mols):
    """Expand ``nodes`` into a TreeNodeSet.

    nodes -- a TreeNode, a TreeNodeSet, or a selection string.
    mols  -- molecules searched when ``nodes`` is a string.

    A string may contain a series of set descriptors with operators,
    separated by '/' characters: a first set followed by (operator, set)
    pairs. All sets must describe nodes of the same level. Examples:
        '1crn:::CA*/+/1crn:::O*'  (union of all CA and all O atoms in 1crn)
        '1crn:::CA*/+/1crn:::O*/-/1crn::TYR29:'

    Returns a single TreeNode wrapped in its set class (with its string
    representation set), the first result of CompoundStringSelector for a
    string, or a TreeNodeSet unchanged.

    Raises ValueError for any other input.
    """
    from MolKit.stringSelector import CompoundStringSelector
    from MolKit.tree import TreeNode, TreeNodeSet

    if isinstance(nodes, TreeNode):
        expanded = nodes.setClass([nodes])
        expanded.setStringRepr(nodes.full_name())
        return expanded
    if type(nodes) == StringType:
        selector = CompoundStringSelector()
        return selector.select(mols, nodes)[0]
    if isinstance(nodes, TreeNodeSet):
        return nodes
    raise ValueError('Could not expand nodes %s\n' % str(nodes))

def colorByAtomType(nodes,obj):
    """Colour every atom of the molecules in ``nodes`` by its element.

    Uses Pmv's AtomElements palette (looked up on ``atom.element``) and
    stores each colour as a tuple in ``atom.colors[obj.name]``.
    """
    from Pmv.pmvPalettes import AtomElements
    # ColorPalette is not in this module's import block; without this import
    # the palette construction below raised NameError.
    from Pmv.colorPalette import ColorPalette

    c = 'Color palette for atom type'
    palette = ColorPalette('AtomElements', colorDict=AtomElements,
                           readonly=0, info=c,
                           lookupMember='element')
    # Colour all molecules touched by the selection, not just the first one.
    for mol in nodes.top.uniq():
        atoms = mol.allAtoms
        colors = palette.lookup(atoms)
        for a, col in zip(atoms, colors):
            a.colors[obj.name] = tuple(col)

def lookupDGFunc(atom):
    """Return the DavidGoodsell palette key for ``atom`` and store it.

    The key is 'HN' for atoms named 'HN', otherwise the parent's type
    followed by the atom name (e.g. 'ASPOD1'). Keys not listed in DGatomIds
    fall back to ``atom.element``. The result is also saved as
    ``atom.atomId``.
    """
    assert isinstance(atom, Atom)
    atom.atomId = (atom.name if atom.name == 'HN'
                   else atom.parent.type + atom.name)
    if atom.atomId in DGatomIds:
        return atom.atomId
    atom.atomId = atom.element
    return atom.atomId


def colorByDG(nodes,obj):
    """Colour every atom of the molecules in ``nodes`` with the DG palette.

    Keys come from lookupDGFunc; each colour is stored as a tuple in
    ``atom.colors[obj.name]``.
    """
    from Pmv.pmvPalettes import DavidGoodsell, DavidGoodsellSortedKeys
    # ColorPaletteFunction is not in this module's import block; without
    # this import the palette construction below raised NameError.
    from Pmv.colorPalette import ColorPaletteFunction

    c = 'Color palette for DG'
    palette = ColorPaletteFunction('DavidGoodsell',
                                   DavidGoodsell, readonly=0,
                                   info=c,
                                   sortedkeys=DavidGoodsellSortedKeys,
                                   lookupFunction=lookupDGFunc)
    # Colour all molecules touched by the selection, not just the first one.
    for mol in nodes.top.uniq():
        atoms = mol.allAtoms
        colors = palette.lookup(atoms)
        for a, col in zip(atoms, colors):
            a.colors[obj.name] = tuple(col)

def colorByResidueType(nodes,obj):
    """Colour the atoms of the residues in ``nodes`` by residue type.

    Uses Pmv's RasmolAmino palette (looked up on ``residue.type``); every
    atom of a residue receives that residue's colour as a tuple in
    ``atom.colors[obj.name]``.
    """
    from Pmv.pmvPalettes import RasmolAmino, RasmolAminoSortedKeys
    # ColorPalette is not in this module's import block; without this import
    # the palette construction below raised NameError.
    from Pmv.colorPalette import ColorPalette

    c = 'Color palette for Rasmol like residues types'
    palette = ColorPalette('RasmolAmino', RasmolAmino, readonly=0,
                           info=c,
                           sortedkeys=RasmolAminoSortedKeys,
                           lookupMember='type')
    residues = nodes.findType(Residue)
    colors = palette.lookup(residues)
    for r, col in zip(residues, colors):
        for a in r.atoms:
            a.colors[obj.name] = tuple(col)

def color(Type,nodes,obj):
    """Apply a colour scheme to ``nodes`` and refresh MSMS surface colours.

    Type  -- "AtomType", "DG" or "ResidueType"; any other value skips the
             recolouring step.
    nodes -- selection to colour.
    obj   -- pybObject/Surface whose ``name`` keys the per-atom colours.

    When ``obj.name`` starts with 'MSMS-MOL' the atom colours are mapped to
    the surface vertices and passed to changeColor. Other objects (atom
    primitives) are not updated yet.
    """
    if Type == "AtomType":
        colorByAtomType(nodes, obj)
    elif Type == "DG":
        colorByDG(nodes, obj)
    elif Type == "ResidueType":
        colorByResidueType(nodes, obj)
    if obj.name.startswith('MSMS-MOL'):
        vcolors = atomPropToVertices(obj, obj.name, obj.msmsSurf, obj.msmsAtoms,
                                     'colors', propIndex=obj.name)
        changeColor(obj.b_obj, vcolors)
