From 7343e158e61d5bfd2ae7204035e1d4bc39b90c64 Mon Sep 17 00:00:00 2001 From: Christophe de Vienne Date: Fri, 23 Sep 2011 14:43:02 +0200 Subject: [PATCH] Adapt the rest protocol implementation to the changes I did for the soap protocol --- wsme/rest.py | 4 ++-- wsme/restjson.py | 4 ++-- wsme/restxml.py | 4 ++-- wsme/tests/test_controller.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/wsme/rest.py b/wsme/rest.py index 676f8f7..19a0a9c 100644 --- a/wsme/rest.py +++ b/wsme/rest.py @@ -15,7 +15,7 @@ class RestProtocol(object): return True return request.headers.get('Content-Type') in self.content_types - def read_arguments(self, request, arguments): + def read_arguments(self, request, funcdef): if len(request.params) and request.body: raise ClientSideError( "Cannot read parameters from both a body and GET/POST params") @@ -38,7 +38,7 @@ class RestProtocol(object): kw = {} - for arg in arguments: + for arg in funcdef.arguments: if arg.name not in parsed_args: if arg.mandatory: raise MissingArgument(arg.name) diff --git a/wsme/restjson.py b/wsme/restjson.py index f49825e..8d3463c 100644 --- a/wsme/restjson.py +++ b/wsme/restjson.py @@ -106,8 +106,8 @@ class RestJsonProtocol(RestProtocol): raw_args = json.loads(body) return raw_args - def encode_result(self, result, return_type): - r = tojson(return_type, result) + def encode_result(self, result, funcdef): + r = tojson(funcdef.return_type, result) return json.dumps({'result': r}, ensure_ascii=False).encode('utf8') def encode_error(self, errordetail): diff --git a/wsme/restxml.py b/wsme/restxml.py index 301e2f4..2c77474 100644 --- a/wsme/restxml.py +++ b/wsme/restxml.py @@ -125,8 +125,8 @@ class RestXmlProtocol(RestProtocol): def parse_args(self, body): return dict((sub.tag, sub) for sub in et.fromstring(body)) - def encode_result(self, result, return_type): - return et.tostring(toxml(return_type, 'result', result)) + def encode_result(self, result, funcdef): + return et.tostring(toxml(funcdef.return_type, 'result', result)) def encode_error(self, errordetail): el = et.Element('error') diff --git a/wsme/tests/test_controller.py b/wsme/tests/test_controller.py index 12344c7..c8101fc 100644 --- a/wsme/tests/test_controller.py +++ b/wsme/tests/test_controller.py @@ -94,8 +94,8 @@ class TestController(unittest.TestCase): api = [i for i in scan_api(r)] assert len(api) == 1 - assert api[0][0] == ['ns'] - fd = api[0][1] + fd = api[0] + assert fd.path == ['ns'] assert fd.name == 'multiply' def test_handle_request(self):