# This program is free software; you can redistribute it and/or modify
# it under the terms of the (LGPL) GNU Lesser General Public License as
# published by the Free Software Foundation; either version 3 of the 
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library Lesser General Public License for more details at
# ( http://www.gnu.org/licenses/lgpl.html ).
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# written by: Jeff Ortel ( jortel@redhat.com )

"""
The I{wsdl} module provides an objectification of the WSDL.
The primary class is I{Definitions} as it represents the root element
found in the document.
"""

from logging import getLogger
from suds import *
from suds.sax import splitPrefix
from suds.sax.parser import Parser
from suds.sax.element import Element
from suds.bindings.document import Document
from suds.bindings.rpc import RPC, Encoded
from suds.xsd import qualify, Namespace
from suds.xsd.schema import Schema, SchemaCollection
from suds.xsd.query import ElementQuery
from suds.sudsobject import Object
from suds.sudsobject import Factory as SFactory
from urlparse import urljoin

# Module-level logger.
log = getLogger(__name__)

# (prefix, namespace-uri) tuples used when matching/looking up WSDL elements.
# The prefix slot is None; presumably matching is by URI only -- TODO confirm
# against suds.sax Element semantics.
wsdlns = (None, "http://schemas.xmlsoap.org/wsdl/")
soapns = (None, 'http://schemas.xmlsoap.org/wsdl/soap/')
soap12ns = (None, 'http://schemas.xmlsoap.org/wsdl/soap12/')


class Factory:
    """
    Simple factory that builds WSDL objects from XML elements.
    @cvar tags: Maps a WSDL element tag name to a callable taking
        (root, definitions) and returning the matching L{WObject}.
        The callables defer the class lookup until they are invoked,
        because the target classes are defined later in this module.
    @type tags: dict
    """

    tags = {
        'import': lambda root, definitions: Import(root, definitions),
        'types': lambda root, definitions: Types(root, definitions),
        'message': lambda root, definitions: Message(root, definitions),
        'portType': lambda root, definitions: PortType(root, definitions),
        'binding': lambda root, definitions: Binding(root, definitions),
        'service': lambda root, definitions: Service(root, definitions),
    }

    @classmethod
    def create(cls, root, definitions):
        """
        Create an object based on the root tag name.
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        @return: The created object, or None when the tag is not one
            of the names in L{tags}.
        @rtype: L{WObject}
        """
        constructor = cls.tags.get(root.name)
        if constructor is None:
            return None
        return constructor(root, definitions)


class WObject(Object):
    """
    Base class for every object built from a WSDL element.
    @ivar root: The XML I{root} element this object was built from.
    @type root: L{Element}
    """

    def __init__(self, root, definitions=None):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object (not used by this base).
        @type definitions: L{Definitions}
        """
        Object.__init__(self)
        self.root = root
        printmeta = SFactory.metadata()
        # exclude the raw XML element from the printed representation
        printmeta.excludes = ['root']
        printmeta.wrappers = {'qname': lambda x: repr(x)}
        self.__metadata__.__print__ = printmeta

    def resolve(self, definitions):
        """
        Resolve named references to other WSDL objects.
        The base implementation has nothing to resolve; subclasses override.
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        return None

        
class NamedObject(WObject):
    """
    A WSDL object identified by its I{name} attribute.
    @ivar name: The value of the I{name} attribute.
    @type name: str
    @ivar qname: The I{qualified} name: (name, target-namespace-uri).
    @type qname: (name, I{namespace-uri}).
    """

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object whose target namespace
            URI qualifies this object's name.
        @type definitions: L{Definitions}
        """
        WObject.__init__(self, root, definitions)
        name = root.get('name')
        self.name = name
        self.qname = (name, definitions.tns[1])
        self.__metadata__.__print__.wrappers['qname'] = lambda x: repr(x)


class Definitions(WObject):
    """
    Represents the I{root} container of the WSDL objects as defined
    by <wsdl:definitions/>.
    @ivar id: The object id.
    @type id: str
    @ivar options: An options dictionary.
    @type options: L{options.Options}
    @ivar url: The URL used to load the object.
    @type url: str
    @ivar tns: The target namespace for the WSDL as a (prefix, uri) tuple.
    @type tns: (str, str)
    @ivar types: The L{Types} children, including those merged from
        imported WSDLs.
    @type types: [L{Types},...]
    @ivar schema: The collective WSDL schema object.
    @type schema: L{SchemaCollection}
    @ivar children: The raw list of child objects.
    @type children: [L{WObject},...]
    @ivar imports: The list of L{Import} children.
    @type imports: [L{Import},...]
    @ivar messages: The dictionary of L{Message} children key'd by I{qname}
    @type messages: {qname: L{Message}}
    @ivar port_types: The dictionary of L{PortType} children key'd by I{qname}
    @type port_types: {qname: L{PortType}}
    @ivar bindings: The dictionary of L{Binding} children key'd by I{qname}
    @type bindings: {qname: L{Binding}}
    @ivar service: The service object (None when the WSDL has none).
    @type service: L{Service}
    """

    Tag = 'definitions'

    def __init__(self, url, options):
        """
        Parse the document at I{url} and build the full object model:
        children are created, imports loaded, references resolved, the
        schema built, messages flagged (wrapped|bare) and, when a service
        exists, the method view is built.
        @param url: A URL to the WSDL.
        @type url: str
        @param options: An options dictionary.
        @type options: L{options.Options}
        """
        log.debug('reading wsdl at: %s ...', url)
        p = Parser(options.transport)
        root = p.parse(url=url).root()
        WObject.__init__(self, root)
        self.id = objid(self)
        self.options = options
        self.url = url
        self.tns = self.mktns(root)
        self.types = []
        self.schema = None
        self.children = []
        self.imports = []
        self.messages = {}
        self.port_types = {}
        self.bindings = {}
        self.service = None
        self.add_children(self.root)
        # ordering relies on the __gt__ methods defined by the child classes
        self.children.sort()
        pmd = self.__metadata__.__print__
        pmd.excludes.append('children')
        # NOTE(review): no 'wsdl' attribute is set on this class; kept as-is.
        pmd.excludes.append('wsdl')
        pmd.wrappers['schema'] = lambda x: repr(x)
        self.open_imports()
        self.resolve()
        self.build_schema()
        self.set_wrapped()
        if self.service is not None:
            self.add_methods()
        log.debug("wsdl at '%s' loaded:\n%s", url, self)

    def mktns(self, root):
        """
        Get/create the target namespace.
        @param root: The <definitions/> root element.
        @type root: L{Element}
        @return: (prefix, uri); the prefix defaults to 'tns' when the
            targetNamespace URI is not mapped to any prefix.
        @rtype: (str, str)
        """
        tns = root.get('targetNamespace')
        prefix = root.findPrefix(tns)
        if prefix is None:
            log.debug('warning: tns (%s), not mapped to prefix', tns)
            prefix = 'tns'
        return (prefix, tns)

    def add_children(self, root):
        """
        Create child objects (via L{Factory}) for each WSDL-namespace child
        element and file them into the matching collection.
        @param root: The <definitions/> root element.
        @type root: L{Element}
        """
        for c in root.getChildren(ns=wsdlns):
            child = Factory.create(c, self)
            if child is None:
                continue
            self.children.append(child)
            if isinstance(child, Import):
                self.imports.append(child)
            elif isinstance(child, Types):
                self.types.append(child)
            elif isinstance(child, Message):
                self.messages[child.qname] = child
            elif isinstance(child, PortType):
                self.port_types[child.qname] = child
            elif isinstance(child, Binding):
                self.bindings[child.qname] = child
            elif isinstance(child, Service):
                self.service = child

    def open_imports(self):
        """ Load each I{imported} document and merge it into this object. """
        for imp in self.imports:
            imp.load(self)

    def resolve(self):
        """ Tell all children to resolve their named references. """
        for c in self.children:
            c.resolve(self)

    def build_schema(self):
        """
        Process L{Types} objects and create the schema collection.
        Local <types/> schemas are loaded into a new collection (an empty
        types/schema path is built when there are none); schemas of
        imported <types/> are then merged in.
        @return: The loaded schema.
        """
        container = SchemaCollection(self)
        for t in [t for t in self.types if t.local()]:
            for r in t.contents():
                entry = (r, self)
                container.add(entry)
        if not len(container): # empty
            r = Element.buildPath(self.root, 'types/schema')
            entry = (r, self)
            container.add(entry)
        self.schema = container.load()
        for s in [t.schema() for t in self.types if t.imported()]:
            self.schema.merge(s)
        return self.schema

    def add_methods(self):
        """
        Build the method view for the service.  Each method is registered
        in service.methods under both its bare name and 'port:name'.
        """
        bindings = {
            'document/literal' : Document(self),
            'rpc/literal' : RPC(self),
            'rpc/encoded' : Encoded(self)
        }
        for p in self.service.ports:
            binding = p.binding
            ptype = binding.type
            operations = ptype.operations.values()
            for name in [op.name for op in operations]:
                m = SFactory.object('Method')
                m.name = name
                m.location = p.location
                m.binding = SFactory.object('binding')
                op = binding.operation(name)
                m.soap = op.soap
                # style/use selects the binding; unknown pairs map to None
                key = '/'.join((op.soap.style, op.soap.input.body.use))
                m.binding.input = bindings.get(key)
                key = '/'.join((op.soap.style, op.soap.output.body.use))
                m.binding.output = bindings.get(key)
                op = ptype.operation(name)
                m.message = SFactory.object('message')
                m.message.input = op.input
                m.message.output = op.output
                m.qname = ':'.join((p.name, name))
                self.service.methods[m.name] = m
                self.service.methods[m.qname] = m

    def set_wrapped(self):
        """
        Set the (wrapped|bare) flag on messages.  A message is flagged
        wrapped when it has exactly one part whose I{element} resolves to
        a non-builtin schema type.
        @raise TypeNotFound: When a part's element is not in the schema.
        """
        for m in self.messages.values():
            m.wrapped = False
            if len(m.parts) != 1:
                continue
            for p in m.parts:
                if p.element is None:
                    continue
                query = ElementQuery(p.element)
                pt = query.execute(self.schema)
                if pt is None:
                    raise TypeNotFound(query.ref)
                resolved = pt.resolve()
                if resolved.builtin():
                    continue
                m.wrapped = True
            


class Import(WObject):
    """
    Represents the <wsdl:import/>.
    @ivar location: The value of the I{location} attribute.
    @type location: str
    @ivar ns: The value of the I{namespace} attribute.
    @type ns: str
    @ivar imported: The imported object (set only for WSDL imports).
    @type imported: L{Definitions}
    """

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        WObject.__init__(self, root, definitions)
        self.location = root.get('location')
        self.ns = root.get('namespace')
        self.imported = None
        pmd = self.__metadata__.__print__
        pmd.wrappers['imported'] = ( lambda x: x.id )

    def load(self, definitions):
        """
        Load the imported document and merge it into I{definitions}.
        A location without '://' is resolved relative to the importing
        document's URL.
        @param definitions: The importing definitions object.
        @type definitions: L{Definitions}
        @raise Exception: When the document is neither a WSDL nor a schema.
        """
        url = self.location
        log.debug('importing (%s)', url)
        if '://' not in url:
            url = urljoin(definitions.url, url)
        d = Definitions(url, definitions.options)
        if d.root.match(Definitions.Tag, wsdlns):
            self.import_definitions(definitions, d)
            return
        if d.root.match(Schema.Tag, Namespace.xsdns):
            self.import_schema(definitions, d)
            return
        raise Exception('document at "%s" is unknown' % url)

    def import_definitions(self, definitions, d):
        """
        Import/merge WSDL definitions: the imported types, messages,
        port types and bindings are added to I{definitions}.
        """
        definitions.types += d.types
        definitions.messages.update(d.messages)
        definitions.port_types.update(d.port_types)
        definitions.bindings.update(d.bindings)
        self.imported = d
        log.debug('imported (WSDL):\n%s', d)

    def import_schema(self, definitions, d):
        """
        Import a schema document as <types/> content.  The schema root is
        appended to the last L{Types} object, which is created (and added
        to I{definitions}) when none exists yet.
        """
        if not definitions.types:
            types = Types.create(definitions)
            definitions.types.append(types)
        else:
            # the last Types object itself -- not a slice of the list
            types = definitions.types[-1]
        types.root.append(d.root)
        log.debug('imported (XSD):\n%s', d.root)

    def __gt__(self, other):
        """ An import is never ordered after another child. """
        return False
        

class Types(WObject):
    """
    Represents <types><schema/></types>.
    @ivar definitions: The definitions object that owns this <types/>.
    @type definitions: L{Definitions}
    """

    @classmethod
    def create(cls, definitions):
        """
        Create an empty <types/> element, insert it into the definitions
        root and wrap it in a new L{Types} object.
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        @rtype: L{Types}
        """
        element = Element('types', ns=wsdlns)
        definitions.root.insert(element)
        return Types(element, definitions)

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        WObject.__init__(self, root, definitions)
        self.definitions = definitions

    def contents(self):
        """ Get the XSD <schema/> child elements. """
        xsd = Namespace.xsdns
        return self.root.getChildren('schema', xsd)

    def schema(self):
        """ Get the schema of the owning definitions object. """
        owner = self.definitions
        return owner.schema

    def local(self):
        """ True while the owning definitions has no schema built yet. """
        built = self.definitions.schema
        return built is None

    def imported(self):
        """ The inverse of L{local}. """
        is_local = self.local()
        return not is_local

    def __gt__(self, other):
        """ Types are ordered after imports only. """
        return isinstance(other, Import)
    

class Part(NamedObject):
    """
    Represents <message><part/></message>.
    @ivar element: The value of the {element} attribute.
        Stored as a I{qref} as converted by L{suds.xsd.qualify}.
    @type element: str
    @ivar type: The value of the {type} attribute.
        Stored as a I{qref} as converted by L{suds.xsd.qualify}.
    @type type: str
    """

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        NamedObject.__init__(self, root, definitions)
        # NOTE(review): this replaces the print metadata installed by
        # WObject/NamedObject, so the 'root' exclusion and 'qname' wrapper
        # are not kept for parts -- confirm that is intended.
        printmeta = SFactory.metadata()
        printmeta.wrappers = {
            'element': lambda x: repr(x),
            'type': lambda x: repr(x),
        }
        self.__metadata__.__print__ = printmeta
        tns = definitions.tns
        self.element = self.__getref('element', tns)
        self.type = self.__getref('type', tns)

    def __getref(self, a, tns):
        """
        Get the qualified value of the attribute named I{a}.
        @return: The qualified reference, or None when the attribute
            is absent.
        """
        value = self.root.get(a)
        if value is not None:
            value = qualify(value, self.root, tns)
        return value


class Message(NamedObject):
    """
    Represents <message/>.
    @ivar parts: The message parts, in document order.
    @type parts: [I{Part},...]
    """

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        NamedObject.__init__(self, root, definitions)
        self.parts = [Part(e, definitions) for e in root.getChildren('part')]

    def __gt__(self, other):
        """ Messages are ordered after imports and types. """
        return isinstance(other, (Import, Types))
    
    
class PortType(NamedObject):
    """
    Represents <portType/>.
    @ivar operations: The contained operations key'd by name.  Each
        operation's I{input}/I{output} hold message references until
        L{resolve} replaces them with L{Message} objects.
    @type operations: dict
    """

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        NamedObject.__init__(self, root, definitions)
        self.operations = {}
        for c in root.getChildren('operation'):
            op = SFactory.object('Operation')
            op.name = c.get('name')
            op.tns = definitions.tns
            inp = c.getChild('input')
            op.input = inp.get('message')
            # NOTE(review): a missing <output/> falls back to <input/>, so
            # op.output is only None when <input/> is also missing; presumably
            # this keeps resolve() from looking up a None reference -- confirm.
            out = c.getChild('output', default=inp)
            if out is None:
                op.output = None
            else:
                op.output = out.get('message')
            self.operations[op.name] = op

    def resolve(self, definitions):
        """
        Resolve named references to other WSDL objects.
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        @raise Exception: When a referenced message is not found.
        """
        for op in self.operations.values():
            op.input = self._resolve_message(op.input, definitions)
            op.output = self._resolve_message(op.output, definitions)

    def _resolve_message(self, ref, definitions):
        """ Look up the L{Message} named by I{ref} (raises when missing). """
        qref = qualify(ref, self.root, definitions.tns)
        msg = definitions.messages.get(qref)
        if msg is None:
            raise Exception("msg '%s', not-found" % ref)
        return msg

    def operation(self, name):
        """
        Shortcut used to get a contained operation by name.
        @param name: An operation name.
        @type name: str
        @return: The named operation.
        @rtype: Operation
        @raise L{MethodNotFound}: When not found.
        """
        try:
            return self.operations[name]
        except KeyError:
            raise MethodNotFound(name)

    def __gt__(self, other):
        """ Port types are ordered after imports, types and messages. """
        return isinstance(other, (Import, Types, Message))


class Binding(NamedObject):
    """
    Represents <binding/>.
    @ivar operations: The contained operations key'd by name.
    @type operations: dict
    @ivar type: The portType reference; replaced by the L{PortType}
        object in L{resolve}.
    @ivar soap: The soap binding properties, or None when the binding
        has no soap (1.1 or 1.2) <binding/> child.
    """

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        NamedObject.__init__(self, root, definitions)
        self.operations = {}
        self.type = root.get('type')
        sr = self.soaproot()
        if sr is None:
            self.soap = None
            log.debug('binding: "%s" not a soap binding', self.name)
            return
        soap = SFactory.object('soap')
        self.soap = soap
        self.soap.style = sr.get('style', default='document')
        self.add_operations(self.root, definitions)

    def soaproot(self):
        """
        Get the soap:binding (soap 1.1 first, then soap 1.2).
        @return: The element, or None when neither is present.
        """
        for ns in (soapns, soap12ns):
            sr = self.root.getChild('binding', ns=ns)
            if sr is not None:
                return sr
        return None

    def add_operations(self, root, definitions):
        """ Add <operation/> children with their soap properties. """
        # used when an operation has no soap:operation child
        dsop = Element('operation', ns=soapns)
        for c in root.getChildren('operation'):
            op = SFactory.object('Operation')
            op.name = c.get('name')
            sop = c.getChild('operation', default=dsop)
            soap = SFactory.object('soap')
            soap.action = '"%s"' % sop.get('soapAction', default='')
            soap.style = sop.get('style', default=self.soap.style)
            soap.input = SFactory.object('Input')
            soap.input.body = SFactory.object('Body')
            soap.input.headers = []
            soap.output = SFactory.object('Output')
            soap.output.body = SFactory.object('Body')
            soap.output.headers = []
            op.soap = soap
            inp = c.getChild('input')
            if inp is None:
                inp = Element('input', ns=wsdlns)
            self.body(definitions, soap.input.body, inp.getChild('body'))
            for header in inp.getChildren('header'):
                self.header(definitions, soap.input, header)
            out = c.getChild('output')
            if out is None:
                out = Element('output', ns=wsdlns)
            # read use/namespace from the <body/> child, exactly as for
            # the input, not from the <output/> element itself
            self.body(definitions, soap.output.body, out.getChild('body'))
            for header in out.getChildren('header'):
                self.header(definitions, soap.output, header)
            self.operations[op.name] = op

    def body(self, definitions, body, root):
        """
        Add the input/output body properties.
        @param body: The body object to populate (use, namespace).
        @param root: The <body/> element, or None (defaults are used).
        """
        if root is None:
            body.use = 'literal'
            body.namespace = definitions.tns
            return
        body.use = root.get('use', default='literal')
        ns = root.get('namespace')
        if ns is None:
            body.namespace = definitions.tns
        else:
            prefix = root.findPrefix(ns, 'b0')
            body.namespace = (prefix, ns)

    def header(self, definitions, parent, root):
        """
        Add the input/output header properties.
        @param parent: The input/output object whose I{headers} list
            receives the new header.
        @param root: The <header/> element; None is ignored.
        """
        if root is None:
            return
        header = SFactory.object('Header')
        parent.headers.append(header)
        header.use = root.get('use', default='literal')
        ns = root.get('namespace')
        if ns is None:
            header.namespace = definitions.tns
        else:
            prefix = root.findPrefix(ns, 'h0')
            header.namespace = (prefix, ns)
        msg = root.get('message')
        if msg is not None:
            header.message = msg
        part = root.get('part')
        if part is not None:
            header.part = part

    def resolve(self, definitions):
        """
        Resolve named references to other WSDL objects.
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        self.resolveport(definitions)
        self.resolveheaders(definitions)

    def resolveport(self, definitions):
        """
        Resolve port_type reference.
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        @raise Exception: When the portType is not found.
        """
        ref = qualify(self.type, self.root, definitions.tns)
        port_type = definitions.port_types.get(ref)
        if port_type is None:
            raise Exception("portType '%s', not-found" % self.type)
        else:
            self.type = port_type

    def resolveheaders(self, definitions):
        """
        Resolve soap header I{message} references.  Each header's message
        is replaced by a 'Message' object holding only the part named by
        the header's I{part}.
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        @raise Exception: When a referenced message is not found.
        """
        for op in self.operations.values():
            soap = op.soap
            headers = soap.input.headers + soap.output.headers
            for header in headers:
                mn = header.message
                ref = qualify(mn, self.root, definitions.tns)
                message = definitions.messages.get(ref)
                if message is None:
                    raise Exception("message '%s', not-found" % mn)
                header.message = SFactory.object('Message')
                header.message.name = message.name
                header.message.qname = message.qname
                header.message.parts = []
                for p in message.parts:
                    if p.name == header.part:
                        header.message.parts.append(p)
                        break

    def operation(self, name):
        """
        Shortcut used to get a contained operation by name.
        @param name: An operation name.
        @type name: str
        @return: The named operation.
        @rtype: Operation
        @raise L{MethodNotFound}: When not found.
        """
        try:
            return self.operations[name]
        except KeyError:
            raise MethodNotFound(name)

    def __gt__(self, other):
        """ Bindings are ordered after everything except the service. """
        return ( not isinstance(other, Service) )


class Port(NamedObject):
    """
    Represents a service port.
    @ivar binding: A binding name (replaced by the L{Binding} object
        when the owning service resolves).
    @type binding: str
    @ivar location: The service location (url), utf-8 encoded.
    @type location: str
    """

    def __init__(self, root, definitions, service):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        @param service: The owning service object.
        @type service: L{Service}
        """
        NamedObject.__init__(self, root, definitions)
        self.__service = service
        self.binding = root.get('binding')
        location = root.getChild('address').get('location')
        self.location = location.encode('utf-8')

    def method(self, name):
        """
        Get a method defined in this port by name.
        @param name: A method name.
        @type name: str
        @return: The requested method object (None when not found).
        @rtype: I{Method}
        """
        owner = self.__service
        return owner.method(':'.join([self.name, name]))
        

class Service(NamedObject):
    """
    Represents <service/>.
    @ivar ports: The contained ports.
    @type ports: [Port,..]
    @ivar methods: The methods for all ports, key'd by both the bare
        method name and 'port:method'.
    @type methods: dict
    """

    def __init__(self, root, definitions):
        """
        @param root: An XML root element.
        @type root: L{Element}
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        """
        NamedObject.__init__(self, root, definitions)
        self.ports = []
        self.methods = {}
        self.ports.extend(
            Port(e, definitions, self) for e in root.getChildren('port'))

    def port(self, name):
        """
        Locate a port by name.
        @param name: A port name.
        @type name: str
        @return: The first port with that name, or None.
        @rtype: L{Port}
        """
        found = None
        for candidate in self.ports:
            if candidate.name == name:
                found = candidate
                break
        return found

    def method(self, name):
        """
        Get a method defined in one of the ports by name.
        @param name: A method name ('name' or 'port:name').
        @type name: str
        @return: The requested method object, or None.
        @rtype: I{Method}
        """
        methods = self.methods
        return methods.get(name)

    def setlocation(self, url, names=None):
        """
        Override the invocation location (url) for service methods.
        @param url: A url location.
        @type url: A url.
        @param names: A list of method names.  None=ALL
        @type names: [str,..]
        """
        for m in self.methods.values():
            if names is not None and m.name not in names:
                continue
            m.location = url

    def resolve(self, definitions):
        """
        Resolve each port's binding reference.
        Ports whose binding is not a soap binding are discarded.
        @param definitions: A definitions object.
        @type definitions: L{Definitions}
        @raise Exception: When a referenced binding is not found.
        """
        kept = []
        for candidate in self.ports:
            ref = qualify(candidate.binding, self.root, definitions.tns)
            binding = definitions.bindings.get(ref)
            if binding is None:
                raise Exception("binding '%s', not-found" % candidate.binding)
            if binding.soap is None:
                log.debug('binding "%s" - not a soap, discarded', binding.name)
                continue
            candidate.binding = binding
            kept.append(candidate)
        self.ports = kept

    def __gt__(self, other):
        """ The service is ordered after every other child. """
        return True
