141 lines
4.0 KiB
Python
141 lines
4.0 KiB
Python
import inspect
|
|
import traceback
|
|
import weakref
|
|
import logging
|
|
|
|
from wsme import exc
|
|
from wsme.types import register_type
|
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    'expose',
    'validate',
    'WSRoot',
]

# Logger named after this module.
log = logging.getLogger(__name__)

# Protocol registry: maps ``protocol.name`` to the registered protocol.
registered_protocols = dict()
|
|
|
|
|
|
def scan_api(controller, path=None):
    """Recursively yield the exposed functions reachable from *controller*.

    Every public attribute of *controller* (name not starting with ``_``) is
    inspected.  An attribute carrying a ``_wsme_definition`` is reported; any
    other attribute is scanned in turn, with its name appended to the path.

    :param controller: Object whose attributes are scanned.
    :param path: List of attribute names leading from the root to
        *controller*.  A fresh empty list is used when omitted.
    :returns: A generator of ``(path, definition)`` tuples.  *path* is the
        list of attribute names leading to the object that owns the function
        (the function's own name is not included) and *definition* is the
        function's ``_wsme_definition``.
    """
    # A ``None`` sentinel replaces the former ``path=[]`` default: that single
    # list was shared by every call and handed out to callers, so mutating a
    # yielded path corrupted all later scans.
    if path is None:
        path = []
    for name in dir(controller):
        if name.startswith('_'):
            continue
        a = getattr(controller, name)
        if hasattr(a, '_wsme_definition'):
            yield path, a._wsme_definition
        else:
            # NOTE(review): any public attribute without a definition is
            # descended into, whatever its type -- confirm controllers only
            # expose sub-controllers and exposed functions publicly.
            for i in scan_api(a, path + [name]):
                yield i
|
|
|
|
|
|
class FunctionArgument(object):
    """Describe a single argument of an exposed function.

    The four constructor values are stored unchanged as the attributes
    ``name``, ``datatype``, ``mandatory`` and ``default``.
    """

    def __init__(self, name, datatype, mandatory, default):
        # Identification and declared type of the argument.
        self.name, self.datatype = name, datatype
        # Whether a value is required, and the default recorded for it.
        self.mandatory, self.default = mandatory, default
|
|
|
|
|
|
class FunctionDefinition(object):
    """Hold the description attached to an exposed function.

    ``name`` is the function's ``__name__``; ``return_type`` starts as
    ``None`` and ``arguments`` as an empty list.
    """

    def __init__(self, func):
        self.arguments = []
        self.return_type = None
        self.name = func.__name__

    @classmethod
    def get(cls, func):
        """Return the definition stored on *func*, creating it if needed.

        The definition is kept in the ``_wsme_definition`` attribute of
        *func*, so repeated calls for the same function return the same
        object.
        """
        existing = getattr(func, '_wsme_definition', None)
        if existing is not None:
            return existing
        created = FunctionDefinition(func)
        func._wsme_definition = created
        return created
|
|
|
|
|
|
def register_protocol(protocol):
    """Record *protocol* in the module registry under ``protocol.name``.

    An entry already stored under the same name is replaced.
    """
    key = protocol.name
    registered_protocols[key] = protocol
|
|
|
|
|
|
class expose(object):
    """Decorator recording the return type of an exposed function.

    The type is passed to ``register_type`` when the decorator object is
    created, and written to the function's definition when the decorator is
    applied.
    """

    def __init__(self, return_type=None):
        register_type(return_type)
        self.return_type = return_type

    def __call__(self, func):
        # Fetch (or create) the definition attached to ``func`` and stamp the
        # declared return type on it; the function itself is returned as is.
        FunctionDefinition.get(func).return_type = self.return_type
        return func
|
|
|
|
|
|
class validate(object):
    """Decorator declaring the types of an exposed function's arguments.

    The positional *args* are the datatypes of the function's arguments, in
    declaration order (a leading ``self`` is not counted).  Keyword arguments
    are accepted but ignored.
    """

    def __init__(self, *args, **kw):
        self.param_types = args

    def __call__(self, func):
        """Append one ``FunctionArgument`` per argument of *func*.

        The arguments are added to the ``FunctionDefinition`` attached to
        *func* (created on demand).

        :param func: The function being decorated.
        :returns: *func* itself.
        :raises IndexError: If fewer types were declared than *func* has
            arguments.
        """
        fd = FunctionDefinition.get(func)
        # inspect.getargspec() was removed in Python 3.11; use
        # getfullargspec() where it exists.  Both return the argument names
        # first and the defaults tuple (or None) fourth.
        getspec = getattr(inspect, 'getfullargspec', None)
        if getspec is None:
            getspec = inspect.getargspec
        spec = getspec(func)
        args, defaults = spec[0], spec[3]
        # ``args`` is empty for a function taking no argument, so check
        # before peeking at the first name.
        if args and args[0] == 'self':
            args = args[1:]
        # Defaults always belong to the *last* arguments: everything before
        # this index has no default and is therefore mandatory.
        first_optional = len(args) - len(defaults or ())
        for i, argname in enumerate(args):
            datatype = self.param_types[i]
            mandatory = i < first_optional
            default = None
            if not mandatory:
                default = defaults[i - first_optional]
            fd.arguments.append(
                FunctionArgument(argname, datatype, mandatory, default))
        return func
|
|
|
|
|
|
class WSRoot(object):
    """Root controller dispatching requests to a protocol implementation.

    :param protocols: Iterable of protocol objects and/or names of registered
        protocols.  A name is looked up in ``registered_protocols`` and the
        registered entry is called without arguments to build the protocol.
        ``None`` (the default) selects every registered protocol.
    """

    def __init__(self, protocols=None):
        # NOTE(review): with debug on, server-side tracebacks are included in
        # the fault built by _format_exception() -- confirm this default is
        # intended outside development.
        self._debug = True
        if protocols is None:
            protocols = registered_protocols.keys()
        # Enabled protocols, keyed by their ``name`` attribute.
        self.protocols = {}
        for protocol in protocols:
            if isinstance(protocol, str):
                protocol = registered_protocols[protocol]()
            self.protocols[protocol.name] = protocol

    def _handle_request(self, request):
        """Pick a protocol for *request* and let it handle the request.

        If ``request.params`` contains ``'wsmeproto'``, the protocol with
        that name is used.  Otherwise the first enabled protocol whose
        ``accept(self, request)`` returns a true value is used.

        :returns: Whatever the selected protocol's ``handle`` returns.
        :raises KeyError: If ``'wsmeproto'`` names a protocol that is not
            enabled.
        :raises LookupError: If no enabled protocol accepts the request.
        """
        protocol = None
        if 'wsmeproto' in request.params:
            protocol = self.protocols[request.params['wsmeproto']]
        else:
            for p in self.protocols.values():
                if p.accept(self, request):
                    protocol = p
                    break

        if protocol is None:
            # Fail with an explicit message instead of the AttributeError
            # that calling ``None.handle`` would produce.
            raise LookupError(
                "None of the enabled protocols (%s) can handle this request"
                % ', '.join(sorted(self.protocols)))

        return protocol.handle(self, request)

    def _format_exception(self, excinfo):
        """Extract information that can be sent to the client.

        :param excinfo: ``(type, value, traceback)`` triple describing the
            exception.
        :returns: A dict with ``faultcode`` (``"Client"`` or ``"Server"``)
            and ``faultstring``.  Server faults also carry ``debuginfo`` (the
            formatted traceback) when ``self._debug`` is true.
        """
        error = excinfo[1]
        if isinstance(error, exc.ClientSideError):
            try:
                text_type = unicode
            except NameError:
                # Python 3 has no ``unicode``; ``str`` is the text type there.
                text_type = str
            r = dict(faultcode="Client",
                     faultstring=text_type(error))
            log.warning("Client-side error: %s", r['faultstring'])
            return r
        else:
            faultstring = str(error)
            debuginfo = "\n".join(traceback.format_exception(*excinfo))

            log.error('Server-side error: "%s". Detail: \n%s',
                      faultstring, debuginfo)

            r = dict(faultcode="Server", faultstring=faultstring)
            if self._debug:
                r['debuginfo'] = debuginfo
            return r

    def _lookup_function(self, path):
        """Resolve *path* to an exposed function.

        :param path: Sequence of attribute names, followed one after the
            other starting from this root.
        :returns: ``(function, definition)`` where *definition* is the
            function's ``_wsme_definition``.
        :raises exc.UnknownFunction: If an attribute is missing (or ``None``)
            along the way, or the object reached has no ``_wsme_definition``.
        """
        a = self

        for name in path:
            a = getattr(a, name, None)
            if a is None:
                break

        if not hasattr(a, '_wsme_definition'):
            raise exc.UnknownFunction('/'.join(path))

        return a, a._wsme_definition
|