From 3a2280bd6a9bea9fbf3ab3c2eb5b4cee3b198e99 Mon Sep 17 00:00:00 2001 From: Christophe de Vienne Date: Wed, 21 Sep 2011 22:20:50 +0200 Subject: [PATCH] Move as much as possible the request handling code out of the protocol --- setup.py | 5 +- wsme/controller.py | 119 ++++++++++++++++++++++++++++++---- wsme/rest.py | 82 +---------------------- wsme/tests/test_controller.py | 28 +++++--- 4 files changed, 131 insertions(+), 103 deletions(-) diff --git a/setup.py b/setup.py index 456a188..048f1a3 100644 --- a/setup.py +++ b/setup.py @@ -3,5 +3,8 @@ from setuptools import setup setup( name='wsme', packages=['wsme'], - install_requires=['webob'], + install_requires=[ + 'simplegeneric', + 'webob', + ], ) diff --git a/wsme/controller.py b/wsme/controller.py index d066a7e..9101ccb 100644 --- a/wsme/controller.py +++ b/wsme/controller.py @@ -2,6 +2,8 @@ import inspect import traceback import weakref import logging +import webob +import sys from wsme import exc from wsme.types import register_type @@ -13,6 +15,20 @@ log = logging.getLogger(__name__) registered_protocols = {} +html_body = """ + + + + + +%(content)s + + +""" + + def scan_api(controller, path=[]): for name in dir(controller): if name.startswith('_'): @@ -95,7 +111,7 @@ class WSRoot(object): protocol = registered_protocols[protocol]() self.protocols[protocol.name] = protocol - def _handle_request(self, request): + def _select_protocol(self, request): protocol = None if 'wsmeproto' in request.params: protocol = self.protocols[request.params['wsmeproto']] @@ -104,8 +120,70 @@ class WSRoot(object): if p.accept(self, request): protocol = p break + return protocol - return protocol.handle(self, request) + def _handle_request(self, request): + res = webob.Response() + try: + protocol = self._select_protocol(request) + if protocol is None: + msg = ("None of the following protocols can handle this " + "request : %s" % ','.join(self.protocols.keys())) + res.status = 500 + res.text = msg + log.error(msg) + return res + path = protocol.extract_path(request) + func, funcdef = self._lookup_function(path) + kw = protocol.read_arguments(request, funcdef.arguments) + + result = func(**kw) + + # TODO make sure result type == a._wsme_definition.return_type + res.status = 200 + res.body = protocol.encode_result(result, funcdef.return_type) + except Exception, e: + res.status = 500 + res.body = protocol.encode_error( + self._format_exception(sys.exc_info())) + + # Attempt to correctly guess what content-type we should return. + res_content_type = None + + last_q = 0 + if hasattr(request.accept, '_parsed'): + for mimetype, q in request.accept._parsed: + if mimetype in protocol.content_types and last_q < q: + res_content_type = mimetype + else: + res_content_type = request.accept.best_match([ + ct for ct in protocol.content_types if ct]) + + # If not we will attempt to convert the body to an accepted + # output format. + if res_content_type is None: + if "text/html" in request.accept: + res.body = self._html_format(res.body, protocol.content_types) + res_content_type = "text/html" + + # TODO should we consider the encoding asked by + # the web browser ? + res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type + + return res + + def _lookup_function(self, path): + a = self + + for name in path: + a = getattr(a, name, None) + if a is None: + break + + if not hasattr(a, '_wsme_definition'): + raise exc.UnknownFunction('/'.join(path)) + + return a, a._wsme_definition def _format_exception(self, excinfo): """Extract informations that can be sent to the client.""" @@ -126,15 +204,32 @@ class WSRoot(object): r['debuginfo'] = debuginfo return r - def _lookup_function(self, path): - a = self + def _html_format(self, content, content_types): + try: + from pygments import highlight + from pygments.lexers import get_lexer_for_mimetype + from pygments.formatters import HtmlFormatter - for name in path: - a = getattr(a, name, None) - if a is None: - break + lexer = None + for ct in content_types: + try: + print ct + lexer = get_lexer_for_mimetype(ct) + break + except: + pass - if not hasattr(a, '_wsme_definition'): - raise exc.UnknownFunction('/'.join(path)) - - return a, a._wsme_definition + if lexer is None: + raise ValueError("No lexer found") + formatter = HtmlFormatter() + return html_body % dict( + css=formatter.get_style_defs(), + content=highlight(content, lexer, formatter).encode('utf8')) + except Exception, e: + log.warning( + "Could not pygment the content because of the following " + "error :\n%s" % e) + return html_body % dict( + css='', + content='
%s
' % + content.replace('>', '>').replace('<', '<')) diff --git a/wsme/rest.py b/wsme/rest.py index 550ff1e..676f8f7 100644 --- a/wsme/rest.py +++ b/wsme/rest.py @@ -1,24 +1,9 @@ -import webob -import sys import logging from wsme.exc import UnknownFunction, MissingArgument, UnknownArgument log = logging.getLogger(__name__) -html_body = """ - - - - - -%(content)s - - -""" - class RestProtocol(object): name = None @@ -66,73 +51,10 @@ class RestProtocol(object): raise UnknownArgument(parsed_args.keys()[0]) return kw - def handle(self, root, request): + def extract_path(self, request): path = request.path.strip('/').split('/') if path[-1].endswith('.' + self.dataformat): path[-1] = path[-1][:-len(self.dataformat) - 1] - res = webob.Response() - - try: - func, funcdef = root._lookup_function(path) - kw = self.read_arguments(request, funcdef.arguments) - result = func(**kw) - # TODO make sure result type == a._wsme_definition.return_type - res.body = self.encode_result(result, funcdef.return_type) - res.status = "200 OK" - except Exception, e: - res.status = 500 - res.body = self.encode_error( - root._format_exception(sys.exc_info())) - - # Attempt to correctly guess what content-type we should return. - res_content_type = None - - last_q = 0 - if hasattr(request.accept, '_parsed'): - for mimetype, q in request.accept._parsed: - if mimetype in self.content_types and last_q < q: - res_content_type = mimetype - else: - res_content_type = request.accept.best_match([ - ct for ct in self.content_types if ct]) - - # If not we will attempt to convert the body to an accepted - # output format. - if res_content_type is None: - if "text/html" in request.accept: - res.body = self.html_format(res.body) - res_content_type = "text/html" - - res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type - - return res - - def html_format(self, content): - try: - from pygments import highlight - from pygments.lexers import get_lexer_for_mimetype - from pygments.formatters import HtmlFormatter - - lexer = None - for ct in self.content_types: - try: - print ct - lexer = get_lexer_for_mimetype(ct) - break - except: - pass - - if lexer is None: - raise ValueError("No lexer found") - formatter = HtmlFormatter() - return html_body % dict( - css=formatter.get_style_defs(), - content=highlight(content, lexer, formatter).encode('utf8')) - except Exception, e: - log.warning( - "Could not pygment the content because of the following error :\n%s" % e) - return html_body % dict( - css='', - content='
%s
' % content.replace('>', '>').replace('<', '<')) + return path diff --git a/wsme/tests/test_controller.py b/wsme/tests/test_controller.py index a1d7eaf..12344c7 100644 --- a/wsme/tests/test_controller.py +++ b/wsme/tests/test_controller.py @@ -9,6 +9,7 @@ from wsme.controller import scan_api class DummyProtocol(object): name = 'dummy' + content_types = ['', None] def __init__(self): self.hits = 0 @@ -16,12 +17,19 @@ class DummyProtocol(object): def accept(self, root, req): return True - def handle(self, root, req): - self.lastreq = req - self.lastroot = root - res = webob.Response() + def extract_path(self, request): + return ['touch'] + + def read_arguments(self, request, arguments): + self.lastreq = request self.hits += 1 - return res + return {} + + def encode_result(self, result, return_type): + return str(result) + + def encode_error(self, infos): + return str(infos) def serve_ws(req, root): @@ -92,7 +100,9 @@ class TestController(unittest.TestCase): def test_handle_request(self): class MyRoot(WSRoot): - pass + @expose() + def touch(self): + pass p = DummyProtocol() r = MyRoot(protocols=[p]) @@ -103,11 +113,9 @@ class TestController(unittest.TestCase): res = app.get('/') assert p.lastreq.path == '/' - assert p.lastroot == r assert p.hits == 1 - res = app.get('/?wsmeproto=dummy') + res = app.get('/touch?wsmeproto=dummy') - assert p.lastreq.path == '/' - assert p.lastroot == r + assert p.lastreq.path == '/touch' assert p.hits == 2