wsme/wsme/root.py

368 lines
12 KiB
Python

import logging
import sys
import weakref
import webob
from wsme.exc import ClientSideError, UnknownFunction
from wsme.protocol import getprotocol
from wsme.rest import scan_api
import wsme.api
import wsme.types
log = logging.getLogger(name=__name__)

# Page skeleton used by WSRoot._html_format to wrap a response body for
# display in a browser. It is filled with %-style named substitutions:
# ``css`` (stylesheet text) and ``content`` (the page body markup).
html_body = (
    "\n"
    "<html>\n"
    "<head>\n"
    "<style type='text/css'>\n"
    "%(css)s\n"
    "</style>\n"
    "</head>\n"
    "<body>\n"
    "%(content)s\n"
    "</body>\n"
    "</html>\n"
)
def default_prepare_response_body(request, results):
    """Join the encoded results of every call into one response body.

    Used by WSRoot when the selected protocol does not provide its own
    ``prepare_response_body``.

    :param request: The incoming request (not used by this default).
    :param results: An iterable of encoded call results, expected to be
                    all ``str`` or all ``bytes``.
    :return: The results separated by newlines, with the separator type
             (``str`` or ``bytes``) chosen from the first result, or
             ``None`` when *results* is empty.
    :raises TypeError: if text and binary results are mixed.
    """
    parts = list(results)
    if not parts:
        return None
    # The first result decides whether we join text or bytes; a later
    # result of the other kind makes ``join`` raise TypeError.
    sep = '\n' if isinstance(parts[0], str) else b'\n'
    # A single join is linear in the body size, unlike repeated ``+=``.
    return sep.join(parts)
class DummyTransaction:
    """No-op transaction returned by ``WSRoot.begin`` when no transaction
    manager is configured.

    It offers the ``commit``/``abort`` pair the call dispatcher invokes,
    and both methods deliberately do nothing.
    """

    def commit(self):
        """Do nothing: there is no real transaction to commit."""
        return None

    def abort(self):
        """Do nothing: there is no real transaction to roll back."""
        return None
class WSRoot(object):
    """
    Root controller for webservices.

    :param protocols: An iterable of protocols to enable, each given as a
                      registered protocol name or a protocol instance
                      (see :meth:`addprotocol`).

    :param webpath: The web path where the webservice is published.

    :type  transaction: A `transaction
                        <http://pypi.python.org/pypi/transaction>`_-like
                        object or ``True``.
    :param transaction: If specified, a transaction will be created and
                        handled on a per-call basis.

                        This option *can* be enabled along with `repoze.tm2
                        <http://pypi.python.org/pypi/repoze.tm2>`_
                        (it will only make it void).

                        If ``True``, the default :mod:`transaction`
                        module will be imported and used.

    :param scan_api: Callable used to discover the exposed functions. It is
                     called with this root and must yield
                     ``(path, function, args)`` triples, where each function
                     carries a ``_wsme_definition`` attribute.
    """

    __registry__ = wsme.types.registry

    def __init__(self, protocols=(), webpath='', transaction=None,
                 scan_api=scan_api):
        # NOTE(review): forwarded to wsme.api.format_exception when encoding
        # errors; presumably enables debug details in error responses.
        # Confirm before serving untrusted clients.
        self._debug = True
        self._webpath = webpath
        self.protocols = []
        self._scan_api = scan_api
        self._transaction = transaction
        if self._transaction is True:
            # Only require the ``transaction`` package when asked for it.
            import transaction
            self._transaction = transaction
        for protocol in protocols:
            self.addprotocol(protocol)
        # Populated lazily by getapi().
        self._api = None

    def wsgiapp(self):
        """Return a WSGI application that dispatches to this root."""
        from webob.dec import wsgify
        return wsgify(self._handle_request)

    def begin(self):
        """Start the transaction wrapping a single call.

        :return: ``self._transaction.begin()`` when a transaction manager
                 is configured, otherwise a :class:`DummyTransaction`.
                 Callers use ``commit()`` / ``abort()`` on the result.
        """
        if self._transaction:
            return self._transaction.begin()
        else:
            return DummyTransaction()

    def addprotocol(self, protocol, **options):
        """
        Enable a new protocol on the controller.

        :param protocol: A registered protocol name or an instance
                         of a protocol.
        :param options: Keyword options passed to the protocol factory
                        when *protocol* is given by name.
        """
        if isinstance(protocol, str):
            protocol = getprotocol(protocol, **options)
        self.protocols.append(protocol)
        # A weak proxy, so the protocol does not keep the root alive.
        protocol.root = weakref.proxy(self)

    def getapi(self):
        """
        Returns the api description.

        Exposed functions are scanned once and cached. Their types are
        resolved against :attr:`__registry__` on every call.

        :rtype: list of (path, :class:`FunctionDefinition`)
        """
        if self._api is None:
            self._api = [
                (path, f, f._wsme_definition, args)
                for path, f, args in self._scan_api(self)
            ]
        for path, f, fdef, args in self._api:
            fdef.resolve_types(self.__registry__)
        return [
            (path, fdef)
            for path, f, fdef, args in self._api
        ]

    def _get_protocol(self, name):
        """Return the enabled protocol called *name*, or ``None``."""
        for protocol in self.protocols:
            if protocol.name == name:
                return protocol
        return None

    def _select_protocol(self, request):
        """Pick the protocol that will handle *request*.

        An explicit ``wsmeproto`` request parameter wins (the result may be
        ``None`` if no such protocol is enabled). Otherwise the first
        protocol whose ``accept`` returns true is used.

        :raises ClientSideError: the last error raised by a protocol's
                                 ``accept``, or a 406 if none accepted.
        """
        log.debug("Selecting a protocol for the following request :\n"
                  "headers: %s\nbody: %s", request.headers.items(),
                  request.content_length and (
                      request.content_length > 512 and
                      request.body[:512] or
                      request.body) or '')
        protocol = None
        error = ClientSideError(status_code=406)
        path = str(request.path)
        # NOTE(review): ``assert`` is stripped under ``python -O``. When it
        # fires, _handle_request reports it as an unexpected error (500).
        assert path.startswith(self._webpath)
        if 'wsmeproto' in request.params:
            return self._get_protocol(request.params['wsmeproto'])

        for p in self.protocols:
            try:
                if p.accept(request):
                    protocol = p
                    break
            except ClientSideError as e:
                error = e
        # If we could not select a protocol, we raise the last exception
        # that we got, or the default one.
        if not protocol:
            raise error
        return protocol

    def _do_call(self, protocol, context):
        """Run one call extracted from a request and encode its outcome.

        The call is recorded on ``request.calls``. Any exception is encoded
        with ``protocol.encode_error`` and tallied on the request's
        client/server error counters instead of propagating.

        :return: The protocol-encoded result or error.
        """
        request = context.request
        request.calls.append(context)
        try:
            if context.path is None:
                context.path = protocol.extract_path(context)

            if context.path is None:
                raise ClientSideError(
                    'The %s protocol was unable to extract a function '
                    'path from the request' % protocol.name)

            context.func, context.funcdef, args = \
                self._lookup_function(context.path)
            kw = protocol.read_arguments(context)
            args = list(args)

            txn = self.begin()
            try:
                result = context.func(*args, **kw)
                txn.commit()
            except Exception:
                txn.abort()
                raise
            else:
                # TODO make sure result type == a._wsme_definition.return_type
                return protocol.encode_result(context, result)

        except Exception as e:
            infos = wsme.api.format_exception(sys.exc_info(), self._debug)
            if isinstance(e, ClientSideError):
                request.client_errorcount += 1
                request.client_last_status_code = e.code
            else:
                request.server_errorcount += 1
            return protocol.encode_error(context, infos)

    def find_route(self, path):
        """Return ``(routepath, func)`` for the first protocol route that is
        a prefix of *path*, or ``(None, None)`` if there is none."""
        for p in self.protocols:
            for routepath, func in p.iter_routes():
                if path.startswith(routepath):
                    return routepath, func
        return None, None

    @staticmethod
    def _set_response_body(res, body):
        """Store *body* on *res*: as text for ``str``, as raw bytes otherwise."""
        if isinstance(body, str):
            res.text = body
        else:
            res.body = body

    def _handle_request(self, request):
        """Handle one request (the callable wrapped by :meth:`wsgiapp`).

        Protocol routes found by :meth:`find_route` are served directly.
        Otherwise a protocol is selected, and each call it extracts is
        dispatched through :meth:`_do_call`. The status and Content-Type
        are then derived from the outcome.

        :return: A :class:`webob.Response`.
        """
        res = webob.Response()
        res_content_type = None

        path = request.path
        if path.startswith(self._webpath):
            path = path[len(self._webpath):]
        routepath, func = self.find_route(path)
        if routepath:
            content = func()
            if isinstance(content, str):
                res.text = content
            elif isinstance(content, bytes):
                res.body = content
            res.content_type = func._cfg['content-type']
            return res

        msg = None
        error_status = 500
        try:
            protocol = self._select_protocol(request)
        except ClientSideError as e:
            error_status = e.code
            msg = e.faultstring
            protocol = None
        except Exception as e:
            msg = ("Unexpected error while selecting protocol: %s" % str(e))
            log.exception(msg)
            protocol = None
            error_status = 500

        if protocol is None:
            if not msg:
                msg = ("None of the following protocols can handle this "
                       "request : %s" % ','.join(
                           p.name for p in self.protocols))
            res.status = error_status
            res.content_type = 'text/plain'
            res.text = str(msg)
            log.error(msg)
            return res

        # Per-request bookkeeping updated by _do_call and used below to
        # choose the response status.
        request.calls = []
        request.client_errorcount = 0
        request.client_last_status_code = None
        request.server_errorcount = 0

        # The generator expression below binds its own ``context``, so this
        # name is still None when the error handler passes it on.
        context = None
        try:
            if hasattr(protocol, 'prepare_response_body'):
                prepare_response_body = protocol.prepare_response_body
            else:
                prepare_response_body = default_prepare_response_body

            body = prepare_response_body(request, (
                self._do_call(protocol, context)
                for context in protocol.iter_calls(request)))
            self._set_response_body(res, body)

            if len(request.calls) == 1:
                if hasattr(protocol, 'get_response_status'):
                    res.status = protocol.get_response_status(request)
                elif request.client_errorcount == 1:
                    res.status = request.client_last_status_code
                elif request.client_errorcount:
                    res.status = 400
                elif request.server_errorcount:
                    res.status = 500
                else:
                    res.status = 200
            else:
                res.status = protocol.get_response_status(request)
            res_content_type = protocol.get_response_contenttype(request)
        except ClientSideError as e:
            # NOTE(review): tallied as a server error although it is a
            # client-side one; left unchanged in case callers rely on it.
            request.server_errorcount += 1
            res.status = e.code
            res.text = e.faultstring
        except Exception:
            infos = wsme.api.format_exception(sys.exc_info(), self._debug)
            request.server_errorcount += 1
            # encode_error output may be str or bytes (the same values flow
            # through prepare_response_body above), so don't assume text.
            self._set_response_body(res, protocol.encode_error(context, infos))
            res.status = 500

        if res_content_type is None:
            # Attempt to correctly guess what content-type we should return.
            ctypes = [ct for ct in protocol.content_types if ct]
            if ctypes:
                offers = request.accept.acceptable_offers(ctypes)
                if offers:
                    res_content_type = offers[0][0]

            # If not we will attempt to convert the body to an accepted
            # output format.
            if res_content_type is None and "text/html" in request.accept:
                res.text = self._html_format(res.body,
                                             protocol.content_types)
                res_content_type = "text/html"

            if res_content_type is None:
                # Never emit a literal "None" media type: fall back to the
                # protocol's primary type, or plain text if it has none.
                res_content_type = ctypes[0] if ctypes else 'text/plain'

        # TODO should we consider the encoding asked by
        # the web browser ?
        res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type
        return res

    def _lookup_function(self, path):
        """Find the exposed function registered under *path*.

        :return: ``(function, function_definition, args)``.
        :raises UnknownFunction: if no exposed function matches *path*.
        """
        if not self._api:
            self.getapi()

        for fpath, f, fdef, args in self._api:
            if path == fpath:
                return f, fdef, args
        raise UnknownFunction('/'.join(path))

    def _html_format(self, content, content_types):
        """Render *content* as a standalone HTML page for a browser.

        Tries to syntax-highlight *content* with pygments, using the first
        of *content_types* that pygments has a lexer for. If pygments is
        unavailable, no lexer matches, or highlighting fails, the content is
        HTML-escaped and wrapped in ``<pre>`` instead.

        :param content: The response body as ``bytes`` (decoded as UTF-8,
                        the charset _handle_request declares) or ``str``.
        :param content_types: Candidate MIME types used to pick a lexer.
        :return: The complete HTML document as ``str``.
        """
        import html

        if isinstance(content, bytes):
            # Interpolating bytes into the str template would embed a
            # ``b'...'`` repr, so work on text throughout.
            content = content.decode('utf-8', 'replace')

        try:
            from pygments import highlight
            from pygments.lexers import get_lexer_for_mimetype
            from pygments.formatters import HtmlFormatter

            lexer = None
            for ct in content_types:
                try:
                    lexer = get_lexer_for_mimetype(ct)
                    break
                except Exception:
                    # Best effort: try the next content type.
                    pass

            if lexer is None:
                raise ValueError("No lexer found")
            formatter = HtmlFormatter()
            return html_body % dict(
                css=formatter.get_style_defs(),
                content=highlight(content, lexer, formatter))
        except Exception as e:
            log.warning(
                "Could not pygment the content because of the following "
                "error :\n%s", e)
            return html_body % dict(
                css='',
                content='<pre>%s</pre>' % html.escape(content, quote=False))