Move as much as possible the request handling code out of the protocol
This commit is contained in:
parent
b3c3501255
commit
3a2280bd6a
5
setup.py
5
setup.py
@ -3,5 +3,8 @@ from setuptools import setup
|
||||
setup(
|
||||
name='wsme',
|
||||
packages=['wsme'],
|
||||
install_requires=['webob'],
|
||||
install_requires=[
|
||||
'simplegeneric',
|
||||
'webob',
|
||||
],
|
||||
)
|
||||
|
@ -2,6 +2,8 @@ import inspect
|
||||
import traceback
|
||||
import weakref
|
||||
import logging
|
||||
import webob
|
||||
import sys
|
||||
|
||||
from wsme import exc
|
||||
from wsme.types import register_type
|
||||
@ -13,6 +15,20 @@ log = logging.getLogger(__name__)
|
||||
registered_protocols = {}
|
||||
|
||||
|
||||
html_body = """
|
||||
<html>
|
||||
<head>
|
||||
<style type='text/css'>
|
||||
%(css)s
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
%(content)s
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def scan_api(controller, path=[]):
|
||||
for name in dir(controller):
|
||||
if name.startswith('_'):
|
||||
@ -95,7 +111,7 @@ class WSRoot(object):
|
||||
protocol = registered_protocols[protocol]()
|
||||
self.protocols[protocol.name] = protocol
|
||||
|
||||
def _handle_request(self, request):
|
||||
def _select_protocol(self, request):
|
||||
protocol = None
|
||||
if 'wsmeproto' in request.params:
|
||||
protocol = self.protocols[request.params['wsmeproto']]
|
||||
@ -104,8 +120,70 @@ class WSRoot(object):
|
||||
if p.accept(self, request):
|
||||
protocol = p
|
||||
break
|
||||
return protocol
|
||||
|
||||
return protocol.handle(self, request)
|
||||
def _handle_request(self, request):
|
||||
res = webob.Response()
|
||||
try:
|
||||
protocol = self._select_protocol(request)
|
||||
if protocol is None:
|
||||
msg = ("None of the following protocols can handle this "
|
||||
"request : %s" % ','.join(self.protocols.keys()))
|
||||
res.status = 500
|
||||
res.text = msg
|
||||
log.error(msg)
|
||||
return res
|
||||
path = protocol.extract_path(request)
|
||||
func, funcdef = self._lookup_function(path)
|
||||
kw = protocol.read_arguments(request, funcdef.arguments)
|
||||
|
||||
result = func(**kw)
|
||||
|
||||
# TODO make sure result type == a._wsme_definition.return_type
|
||||
res.status = 200
|
||||
res.body = protocol.encode_result(result, funcdef.return_type)
|
||||
except Exception, e:
|
||||
res.status = 500
|
||||
res.body = protocol.encode_error(
|
||||
self._format_exception(sys.exc_info()))
|
||||
|
||||
# Attempt to correctly guess what content-type we should return.
|
||||
res_content_type = None
|
||||
|
||||
last_q = 0
|
||||
if hasattr(request.accept, '_parsed'):
|
||||
for mimetype, q in request.accept._parsed:
|
||||
if mimetype in protocol.content_types and last_q < q:
|
||||
res_content_type = mimetype
|
||||
else:
|
||||
res_content_type = request.accept.best_match([
|
||||
ct for ct in protocol.content_types if ct])
|
||||
|
||||
# If not we will attempt to convert the body to an accepted
|
||||
# output format.
|
||||
if res_content_type is None:
|
||||
if "text/html" in request.accept:
|
||||
res.body = self._html_format(res.body, protocol.content_types)
|
||||
res_content_type = "text/html"
|
||||
|
||||
# TODO should we consider the encoding asked by
|
||||
# the web browser ?
|
||||
res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type
|
||||
|
||||
return res
|
||||
|
||||
def _lookup_function(self, path):
|
||||
a = self
|
||||
|
||||
for name in path:
|
||||
a = getattr(a, name, None)
|
||||
if a is None:
|
||||
break
|
||||
|
||||
if not hasattr(a, '_wsme_definition'):
|
||||
raise exc.UnknownFunction('/'.join(path))
|
||||
|
||||
return a, a._wsme_definition
|
||||
|
||||
def _format_exception(self, excinfo):
|
||||
"""Extract informations that can be sent to the client."""
|
||||
@ -126,15 +204,32 @@ class WSRoot(object):
|
||||
r['debuginfo'] = debuginfo
|
||||
return r
|
||||
|
||||
def _lookup_function(self, path):
|
||||
a = self
|
||||
def _html_format(self, content, content_types):
|
||||
try:
|
||||
from pygments import highlight
|
||||
from pygments.lexers import get_lexer_for_mimetype
|
||||
from pygments.formatters import HtmlFormatter
|
||||
|
||||
for name in path:
|
||||
a = getattr(a, name, None)
|
||||
if a is None:
|
||||
lexer = None
|
||||
for ct in content_types:
|
||||
try:
|
||||
print ct
|
||||
lexer = get_lexer_for_mimetype(ct)
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if not hasattr(a, '_wsme_definition'):
|
||||
raise exc.UnknownFunction('/'.join(path))
|
||||
|
||||
return a, a._wsme_definition
|
||||
if lexer is None:
|
||||
raise ValueError("No lexer found")
|
||||
formatter = HtmlFormatter()
|
||||
return html_body % dict(
|
||||
css=formatter.get_style_defs(),
|
||||
content=highlight(content, lexer, formatter).encode('utf8'))
|
||||
except Exception, e:
|
||||
log.warning(
|
||||
"Could not pygment the content because of the following "
|
||||
"error :\n%s" % e)
|
||||
return html_body % dict(
|
||||
css='',
|
||||
content='<pre>%s</pre>' %
|
||||
content.replace('>', '>').replace('<', '<'))
|
||||
|
82
wsme/rest.py
82
wsme/rest.py
@ -1,24 +1,9 @@
|
||||
import webob
|
||||
import sys
|
||||
import logging
|
||||
|
||||
from wsme.exc import UnknownFunction, MissingArgument, UnknownArgument
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
html_body = """
|
||||
<html>
|
||||
<head>
|
||||
<style type='text/css'>
|
||||
%(css)s
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
%(content)s
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class RestProtocol(object):
|
||||
name = None
|
||||
@ -66,73 +51,10 @@ class RestProtocol(object):
|
||||
raise UnknownArgument(parsed_args.keys()[0])
|
||||
return kw
|
||||
|
||||
def handle(self, root, request):
|
||||
def extract_path(self, request):
|
||||
path = request.path.strip('/').split('/')
|
||||
|
||||
if path[-1].endswith('.' + self.dataformat):
|
||||
path[-1] = path[-1][:-len(self.dataformat) - 1]
|
||||
|
||||
res = webob.Response()
|
||||
|
||||
try:
|
||||
func, funcdef = root._lookup_function(path)
|
||||
kw = self.read_arguments(request, funcdef.arguments)
|
||||
result = func(**kw)
|
||||
# TODO make sure result type == a._wsme_definition.return_type
|
||||
res.body = self.encode_result(result, funcdef.return_type)
|
||||
res.status = "200 OK"
|
||||
except Exception, e:
|
||||
res.status = 500
|
||||
res.body = self.encode_error(
|
||||
root._format_exception(sys.exc_info()))
|
||||
|
||||
# Attempt to correctly guess what content-type we should return.
|
||||
res_content_type = None
|
||||
|
||||
last_q = 0
|
||||
if hasattr(request.accept, '_parsed'):
|
||||
for mimetype, q in request.accept._parsed:
|
||||
if mimetype in self.content_types and last_q < q:
|
||||
res_content_type = mimetype
|
||||
else:
|
||||
res_content_type = request.accept.best_match([
|
||||
ct for ct in self.content_types if ct])
|
||||
|
||||
# If not we will attempt to convert the body to an accepted
|
||||
# output format.
|
||||
if res_content_type is None:
|
||||
if "text/html" in request.accept:
|
||||
res.body = self.html_format(res.body)
|
||||
res_content_type = "text/html"
|
||||
|
||||
res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type
|
||||
|
||||
return res
|
||||
|
||||
def html_format(self, content):
|
||||
try:
|
||||
from pygments import highlight
|
||||
from pygments.lexers import get_lexer_for_mimetype
|
||||
from pygments.formatters import HtmlFormatter
|
||||
|
||||
lexer = None
|
||||
for ct in self.content_types:
|
||||
try:
|
||||
print ct
|
||||
lexer = get_lexer_for_mimetype(ct)
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if lexer is None:
|
||||
raise ValueError("No lexer found")
|
||||
formatter = HtmlFormatter()
|
||||
return html_body % dict(
|
||||
css=formatter.get_style_defs(),
|
||||
content=highlight(content, lexer, formatter).encode('utf8'))
|
||||
except Exception, e:
|
||||
log.warning(
|
||||
"Could not pygment the content because of the following error :\n%s" % e)
|
||||
return html_body % dict(
|
||||
css='',
|
||||
content='<pre>%s</pre>' % content.replace('>', '>').replace('<', '<'))
|
||||
return path
|
||||
|
@ -9,6 +9,7 @@ from wsme.controller import scan_api
|
||||
|
||||
class DummyProtocol(object):
|
||||
name = 'dummy'
|
||||
content_types = ['', None]
|
||||
|
||||
def __init__(self):
|
||||
self.hits = 0
|
||||
@ -16,12 +17,19 @@ class DummyProtocol(object):
|
||||
def accept(self, root, req):
|
||||
return True
|
||||
|
||||
def handle(self, root, req):
|
||||
self.lastreq = req
|
||||
self.lastroot = root
|
||||
res = webob.Response()
|
||||
def extract_path(self, request):
|
||||
return ['touch']
|
||||
|
||||
def read_arguments(self, request, arguments):
|
||||
self.lastreq = request
|
||||
self.hits += 1
|
||||
return res
|
||||
return {}
|
||||
|
||||
def encode_result(self, result, return_type):
|
||||
return str(result)
|
||||
|
||||
def encode_error(self, infos):
|
||||
return str(infos)
|
||||
|
||||
|
||||
def serve_ws(req, root):
|
||||
@ -92,6 +100,8 @@ class TestController(unittest.TestCase):
|
||||
|
||||
def test_handle_request(self):
|
||||
class MyRoot(WSRoot):
|
||||
@expose()
|
||||
def touch(self):
|
||||
pass
|
||||
|
||||
p = DummyProtocol()
|
||||
@ -103,11 +113,9 @@ class TestController(unittest.TestCase):
|
||||
res = app.get('/')
|
||||
|
||||
assert p.lastreq.path == '/'
|
||||
assert p.lastroot == r
|
||||
assert p.hits == 1
|
||||
|
||||
res = app.get('/?wsmeproto=dummy')
|
||||
res = app.get('/touch?wsmeproto=dummy')
|
||||
|
||||
assert p.lastreq.path == '/'
|
||||
assert p.lastroot == r
|
||||
assert p.lastreq.path == '/touch'
|
||||
assert p.hits == 2
|
||||
|
Loading…
Reference in New Issue
Block a user