From 6e48ce443d94cc2b80a27aa1b7b3a9cac3d9c9db Mon Sep 17 00:00:00 2001 From: Christophe de Vienne Date: Tue, 25 Oct 2011 12:38:23 +0200 Subject: [PATCH] Protocols can now implement batch-calls --- wsme/controller.py | 105 +++++++++++++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 33 deletions(-) diff --git a/wsme/controller.py b/wsme/controller.py index 33f28ea..8542239 100644 --- a/wsme/controller.py +++ b/wsme/controller.py @@ -211,7 +211,7 @@ class validate(object): class CallContext(object): def __init__(self, request): - self.request = request + self.request = weakref.proxy(request) self.path = None self.func = None @@ -278,7 +278,46 @@ class WSRoot(object): break return protocol + def _do_call(self, protocol, context): + request = context.request + request.calls.append(context) + try: + if context.path is None: + context.path = protocol.extract_path(context) + + if context.path is None: + raise exc.ClientSideError( + u'The %s protocol was unable to extract a function ' + u'path from the request' % protocol.name) + + context.func, context.funcdef = self._lookup_function(context.path) + kw = protocol.read_arguments(context) + + for arg in context.funcdef.arguments: + if arg.mandatory and arg.name not in kw: + raise exc.MissingArgument(arg.name) + + result = context.func(**kw) + + if context.funcdef.protocol_specific \ + and context.funcdef.return_type is None: + return result + else: + # TODO make sure result type == a._wsme_definition.return_type + return protocol.encode_result(context, result) + + except Exception, e: + infos = self._format_exception(sys.exc_info()) + if isinstance(e, exc.ClientSideError): + request.client_errorcount += 1 + else: + request.server_errorcount += 1 + return protocol.encode_error(context, infos) + def _handle_request(self, request): + def default_prepare_response_body(request, results): + return '\n'.join(results) + res = webob.Response() res_content_type = None try: @@ -294,44 +333,44 @@ class WSRoot(object): return res context = None - calls = list(protocol.iter_calls(request)) - if len(calls) != 1: - raise NotImplementedError("Batch calls are not yet supported") + request.calls = [] + request.client_errorcount = 0 + request.server_errorcount = 0 - context = calls[0] - context.path = protocol.extract_path(context) - if context.path is None: - raise exc.ClientSideError( - u'The %s protocol was unable to extract a function ' - u'path from the request' % protocol.name) - context.func, context.funcdef = self._lookup_function(context.path) - kw = protocol.read_arguments(context) - - for arg in context.funcdef.arguments: - if arg.mandatory and arg.name not in kw: - raise exc.MissingArgument(arg.name) - - result = context.func(**kw) - - res.status = 200 - - if context.funcdef.protocol_specific and context.funcdef.return_type is None: - if isinstance(result, unicode): - res.unicode_body = result - else: - res.body = result + if hasattr(protocol, 'prepare_response_body'): + prepare_response_body = protocol.prepare_response_body else: - # TODO make sure result type == a._wsme_definition.return_type - res.body = protocol.encode_result(context, result) - res_content_type = context.funcdef.contenttype + prepare_response_body = default_prepare_response_body + + body = prepare_response_body(request, ( + self._do_call(protocol, context) + for context in protocol.iter_calls(request))) + if isinstance(body, unicode): + res.unicode_body = body + else: + res.body = body + + if len(request.calls) == 1: + if hasattr(protocol, 'get_response_status'): + res.status = protocol.get_response_status(request) + else: + if request.client_errorcount: + res.status = 400 + elif request.server_errorcount: + res.status = 500 + else: + res.status = 200 + if request.calls[0].funcdef: + res_content_type = request.calls[0].funcdef.contenttype + else: + res.status = protocol.get_response_status(request) + res_content_type = protocol.get_response_contenttype(request) except Exception, e: infos = self._format_exception(sys.exc_info()) - if isinstance(e, exc.ClientSideError): - res.status = 400 - else: - res.status = 500 + request.server_errorcount += 1 res.body = protocol.encode_error(context, infos) + res.status = 500 if res_content_type is None: # Attempt to correctly guess what content-type we should return.