392 lines
12 KiB
Python
392 lines
12 KiB
Python
import inspect
|
|
import traceback
|
|
import weakref
|
|
import logging
|
|
import webob
|
|
import sys
|
|
|
|
from wsme import exc
|
|
from wsme.types import register_type
|
|
|
|
#: Public names exported by ``from <module> import *``.
__all__ = ['expose', 'validate']

#: Module logger.
log = logging.getLogger(__name__)

#: Protocols registered through :func:`register_protocol`, keyed by their
#: ``name`` attribute. :meth:`WSRoot.addprotocol` looks a protocol up here
#: by name and calls the stored value to create an instance.
registered_protocols = {}

#: Maximum sub-controller nesting depth accepted by :func:`scan_api`;
#: deeper trees make it raise ``ValueError``.
APIPATH_MAXLEN = 20

#: HTML page template used by :meth:`WSRoot._html_format` to render a
#: response body for browsers. It is filled with the ``css`` and
#: ``content`` keys via ``%`` formatting.
html_body = """
<html>
<head>
<style type='text/css'>
%(css)s
</style>
</head>
<body>
%(content)s
</body>
</html>
"""
|
|
|
|
|
|
def scan_api(controller, path=None):
    """
    Recursively iterate a controller api entries, while setting
    their :attr:`FunctionDefinition.path`.

    Public methods carrying a ``_wsme_definition`` are yielded; classes are
    skipped; any other public attribute is scanned as a sub-controller.

    :param controller: object whose public attributes are scanned.
    :param path: list of attribute names leading from the root controller
                 to *controller*. Defaults to an empty path.
    :raises ValueError: when sub-controllers are nested deeper than
                        ``APIPATH_MAXLEN``.
    """
    # A fresh list per call: a shared ``[]`` default would end up stored on
    # every top-level definition, and mutating one would corrupt them all.
    if path is None:
        path = []
    for name in dir(controller):
        if name.startswith('_'):
            continue
        a = getattr(controller, name)
        if inspect.ismethod(a):
            if hasattr(a, '_wsme_definition'):
                # Store a private copy so that definitions never alias the
                # caller's list (or each other's).
                a._wsme_definition.path = list(path)
                yield a._wsme_definition
        elif inspect.isclass(a):
            continue
        else:
            if len(path) > APIPATH_MAXLEN:
                raise ValueError("Path is too long: " + str(path))
            for i in scan_api(a, path + [name]):
                yield i
|
|
|
|
|
|
class FunctionArgument(object):
    """
    Description of a single argument of an api entry.

    Attributes:

    * ``name`` -- the argument name;
    * ``datatype`` -- its data type;
    * ``mandatory`` -- True when the caller must supply it;
    * ``default`` -- the value used when the argument is omitted.
    """

    def __init__(self, name, datatype, mandatory, default):
        self.name, self.datatype = name, datatype
        self.mandatory, self.default = mandatory, default
|
|
|
|
|
|
class FunctionDefinition(object):
    """
    Metadata describing one api entry (an exposed function).

    A definition is attached to its function as the ``_wsme_definition``
    attribute; obtain it through :meth:`get`.

    Attributes:

    * ``name`` / ``doc`` -- the function's ``__name__`` and ``__doc__``;
    * ``return_type`` -- the declared return type (``None`` until set);
    * ``arguments`` -- list of :class:`FunctionArgument`;
    * ``protocol_specific`` -- True if the function is exposed by a
      protocol rather than by the api tree, i.e. not part of the api;
    * ``contenttype`` -- overrides the content type of the returned value,
      meaningful only for ``protocol_specific`` functions;
    * ``path`` -- location of the function in the api tree.
    """

    def __init__(self, func):
        # Identity, copied from the function itself.
        self.name, self.doc = func.__name__, func.__doc__
        # Populated later by the decorators and by the api scan.
        self.return_type = None
        self.arguments = []
        self.protocol_specific = False
        self.contenttype = None
        self.path = None

    @classmethod
    def get(cls, func):
        """
        Return the :class:`FunctionDefinition` attached to *func*, creating
        and attaching a new one when there is none yet.
        """
        existing = getattr(func, '_wsme_definition', None)
        if existing is not None:
            return existing
        created = FunctionDefinition(func)
        func._wsme_definition = created
        return created

    def get_arg(self, name):
        """
        Return the first :class:`FunctionArgument` called *name*, or
        ``None`` if the function has no such argument.
        """
        return next((arg for arg in self.arguments if arg.name == name), None)
|
|
|
|
|
|
def register_protocol(protocol):
    """
    Register *protocol* under its ``name`` attribute so that
    :meth:`WSRoot.addprotocol` can later create an instance of it by name.
    """
    # The module-level dict is only mutated, never rebound, so no
    # ``global`` declaration is required.
    registered_protocols.update({protocol.name: protocol})
|
|
|
|
|
|
class expose(object):
    """
    Decorator publishing a method as part of the web service api.

    :param return_type: Return type of the function

    Example::

        class MyController(object):
            @expose(int)
            def getint(self):
                return 1
    """

    def __init__(self, return_type=None):
        # Hand the return type to ``register_type`` as soon as the decorator
        # is built, then remember it for __call__.
        register_type(return_type)
        self.return_type = return_type

    def __call__(self, func):
        # Record the return type on the function's definition and give the
        # function back unchanged.
        FunctionDefinition.get(func).return_type = self.return_type
        return func
|
|
|
|
|
|
class pexpose(object):
    """
    Decorator exposing a protocol-specific function.

    The decorated function's definition is flagged ``protocol_specific``
    (it is not part of the api tree) and may override the content type of
    its response.

    :param return_type: Return type of the function.
    :param contenttype: Content type of the returned value, or ``None``.
    """

    def __init__(self, return_type=None, contenttype=None):
        self.return_type, self.contenttype = return_type, contenttype
        register_type(return_type)

    def __call__(self, func):
        definition = FunctionDefinition.get(func)
        settings = (
            ('return_type', self.return_type),
            ('protocol_specific', True),
            ('contenttype', self.contenttype),
        )
        for attr, value in settings:
            setattr(definition, attr, value)
        return func
|
|
|
|
|
|
class validate(object):
    """
    Decorator that define the arguments types of a function.

    Example::

        class MyController(object):
            @expose(str)
            @validate(datetime.date, datetime.time)
            def format(self, d, t):
                return d.isoformat() + ' ' + t.isoformat()
    """

    def __init__(self, *param_types):
        #: Data types of the function arguments, in declaration order
        #: (a leading ``self`` excluded).
        self.param_types = param_types

    def __call__(self, func):
        """
        Append one :class:`FunctionArgument` per argument of *func* to its
        :class:`FunctionDefinition`, then return *func* unchanged.

        An argument is mandatory unless the signature gives it a default
        value, in which case that value becomes its ``default``.

        :raises IndexError: if fewer types than arguments were given.
        """
        fd = FunctionDefinition.get(func)
        # inspect.getargspec() no longer exists on Python >= 3.11; prefer
        # getfullargspec() when available. Both results carry the argument
        # names at index 0 and the default values at index 3.
        getspec = getattr(inspect, 'getfullargspec', None)
        if getspec is None:
            getspec = inspect.getargspec
        spec = getspec(func)
        args, defaults = list(spec[0]), spec[3] or ()
        if args and args[0] == 'self':
            args = args[1:]
        # Default values always belong to the trailing arguments: with N
        # arguments and D defaults, arguments N - D and later are optional.
        first_optional = len(args) - len(defaults)
        for i, argname in enumerate(args):
            datatype = self.param_types[i]
            mandatory = i < first_optional
            default = None if mandatory else defaults[i - first_optional]
            fd.arguments.append(FunctionArgument(argname, datatype,
                                                 mandatory, default))
        return func
|
|
|
|
|
|
class WSRoot(object):
    """
    Root controller for webservices.

    :param protocols: A list of protocols to enable (see :meth:`addprotocol`)
    :param webpath: The web path where the webservice is published.
    """

    def __init__(self, protocols=None, webpath=''):
        # ``None`` replaces a shared mutable ``[]`` default; omitting the
        # argument still enables no protocol.
        self._debug = True
        self._webpath = webpath
        self.protocols = {}
        for protocol in protocols or ():
            self.addprotocol(protocol)

        self._api = None

    def addprotocol(self, protocol):
        """
        Enable a new protocol on the controller.

        :param protocol: A registered protocol name or an instance
                         of a protocol.
        """
        if isinstance(protocol, str):
            protocol = registered_protocols[protocol]()
        self.protocols[protocol.name] = protocol
        # A weak proxy, so the protocol does not keep the root alive.
        protocol.root = weakref.proxy(self)

    def getapi(self):
        """
        Returns the api description.

        The first scan result is cached on the instance.

        :rtype: list of :class:`FunctionDefinition`
        """
        if self._api is None:
            self._api = list(scan_api(self))
        return self._api

    def _select_protocol(self, request):
        """
        Pick the protocol that will handle *request*.

        An explicit ``wsmeproto`` request parameter wins; otherwise the
        first enabled protocol whose ``accept()`` returns true is used.

        :returns: the protocol instance, or ``None`` when no enabled
                  protocol matches (including an unknown ``wsmeproto``).
        """
        # Slicing also covers bodies shorter than 512 bytes.
        log.debug("Selecting a protocol for the following request :\n"
                  "headers: %s\nbody: %s", request.headers,
                  request.body[:512])
        if 'wsmeproto' in request.params:
            return self.protocols.get(request.params['wsmeproto'])
        for protocol in self.protocols.values():
            if protocol.accept(request):
                return protocol
        return None

    def _handle_request(self, request):
        """
        Dispatch *request* to the matching api function and build the
        response.

        Client-side errors give a 400 status, any other exception a 500.
        The error is encoded by the selected protocol when one was chosen,
        and sent as plain text otherwise.

        :returns: a :class:`webob.Response`.
        """
        res = webob.Response()
        res_content_type = None
        # Bound before the ``try`` so the error handler can tell whether a
        # protocol had been selected when the failure happened.
        protocol = None
        try:
            protocol = self._select_protocol(request)
            if protocol is None:
                msg = ("None of the following protocols can handle this "
                       "request : %s" % ','.join(self.protocols.keys()))
                res.status = 500
                res.content_type = 'text/plain'
                res.body = msg
                log.error(msg)
                return res

            path = protocol.extract_path(request)
            if path is None:
                raise exc.ClientSideError(
                    u'The %s protocol was unable to extract a function '
                    u'path from the request' % protocol.name)

            func, funcdef = self._lookup_function(path)
            kw = protocol.read_arguments(funcdef, request)

            for arg in funcdef.arguments:
                if arg.mandatory and arg.name not in kw:
                    raise exc.MissingArgument(arg.name)

            result = func(**kw)

            res.status = 200

            if funcdef.protocol_specific and funcdef.return_type is None:
                res.body = result
            else:
                # TODO make sure result type == a._wsme_definition.return_type
                res.body = protocol.encode_result(funcdef, result)
            res_content_type = funcdef.contenttype
        except Exception as e:
            infos = self._format_exception(sys.exc_info())
            if isinstance(e, exc.ClientSideError):
                res.status = 400
            else:
                res.status = 500
            if protocol is None:
                # The failure happened before a protocol was chosen, so none
                # can encode the error: send its description as plain text.
                fault = infos['faultstring']
                if not isinstance(fault, bytes):
                    fault = fault.encode('utf8')
                res.content_type = 'text/plain'
                res.body = fault
                return res
            res.body = protocol.encode_error(infos)

        if res_content_type is None:
            res_content_type = self._guess_content_type(request, protocol)

        # If not we will attempt to convert the body to an accepted
        # output format.
        if res_content_type is None and "text/html" in request.accept:
            res.body = self._html_format(res.body, protocol.content_types)
            res_content_type = "text/html"

        if res_content_type is None:
            # Never emit a literal "None; charset=UTF-8" header: fall back
            # on the first content type the protocol declares.
            res_content_type = next(
                (ct for ct in protocol.content_types if ct), None)

        # TODO should we consider the encoding asked by
        # the web browser ?
        if res_content_type is not None:
            res.headers['Content-Type'] = (
                "%s; charset=UTF-8" % res_content_type)

        return res

    def _guess_content_type(self, request, protocol):
        """
        Choose a response content type from the request ``Accept`` header.

        :returns: the content type of *protocol* accepted with the highest
                  quality factor, or ``None`` when none is acceptable.
        """
        if hasattr(request.accept, '_parsed'):
            # Starting at 0 with a strict comparison means entries whose
            # q is 0 are never selected.
            best_type, best_q = None, 0
            for mimetype, q in request.accept._parsed:
                if mimetype in protocol.content_types and q > best_q:
                    best_type, best_q = mimetype, q
            return best_type
        return request.accept.best_match(
            [ct for ct in protocol.content_types if ct])

    def _lookup_function(self, path):
        """
        Resolve *path* (a sequence of attribute names) to an exposed
        function.

        A path starting with ``'_protocol', <name>`` is resolved on the
        enabled protocol called ``<name>`` instead of on the root.

        :returns: a ``(callable, FunctionDefinition)`` tuple.
        :raises exc.UnknownFunction: if the path does not lead to an
                                     exposed function.
        """
        fullpath = '/'.join(path)
        target = self

        if path and path[0] == '_protocol':
            if len(path) < 2 or path[1] not in self.protocols:
                raise exc.UnknownFunction(fullpath)
            target = self.protocols[path[1]]
            path = path[2:]

        for name in path:
            target = getattr(target, name, None)
            if target is None:
                break

        if not hasattr(target, '_wsme_definition'):
            raise exc.UnknownFunction(fullpath)

        return target, target._wsme_definition

    def _format_exception(self, excinfo):
        """
        Extract informations that can be sent to the client.

        :param excinfo: a ``sys.exc_info()`` triple.
        :returns: a dict with ``faultcode`` (``"Client"`` or ``"Server"``),
                  ``faultstring`` and ``debuginfo``. ``debuginfo`` is the
                  formatted traceback for server errors when ``_debug`` is
                  true, ``None`` otherwise.
        """
        error = excinfo[1]
        if isinstance(error, exc.ClientSideError):
            r = dict(faultcode="Client",
                     faultstring=error.faultstring)
            log.warning("Client-side error: %s", r['faultstring'])
            r['debuginfo'] = None
            return r

        faultstring = str(error)
        debuginfo = "\n".join(traceback.format_exception(*excinfo))

        log.error('Server-side error: "%s". Detail: \n%s',
                  faultstring, debuginfo)

        r = dict(faultcode="Server", faultstring=faultstring)
        # Tracebacks are only disclosed to clients in debug mode.
        r['debuginfo'] = debuginfo if self._debug else None
        return r

    def _html_format(self, content, content_types):
        """
        Wrap *content* in an HTML page, syntax-highlighted with pygments
        when a lexer matches one of *content_types*, HTML-escaped inside a
        ``<pre>`` otherwise.
        """
        try:
            from pygments import highlight
            from pygments.lexers import get_lexer_for_mimetype
            from pygments.formatters import HtmlFormatter

            lexer = None
            for ct in content_types:
                try:
                    lexer = get_lexer_for_mimetype(ct)
                    break
                except Exception:
                    # No lexer for this content type: try the next one.
                    continue

            if lexer is None:
                raise ValueError("No lexer found")
            formatter = HtmlFormatter()
            return html_body % dict(
                css=formatter.get_style_defs(),
                content=highlight(content, lexer, formatter).encode('utf8'))
        except Exception as e:
            log.warning(
                "Could not pygment the content because of the following "
                "error :\n%s", e)
            # Escape '&' first so the entities produced next stay intact;
            # the body may echo client-supplied data.
            escaped = (content.replace('&', '&amp;')
                       .replace('<', '&lt;')
                       .replace('>', '&gt;'))
            return html_body % dict(
                css='',
                content='<pre>%s</pre>' % escaped)
|