diff --git a/wsme/controller.py b/wsme/controller.py index 78611a4..2b2484b 100644 --- a/wsme/controller.py +++ b/wsme/controller.py @@ -1,4 +1,8 @@ import inspect +import traceback +import weakref + +from wsme import exc __all__ = ['expose', 'validate', 'WSRoot'] @@ -26,8 +30,8 @@ class FunctionArgument(object): class FunctionDefinition(object): - def __init__(self, name): - self.name = name + def __init__(self, func): + self.name = func.__name__ self.return_type = None self.arguments = [] @@ -35,7 +39,7 @@ class FunctionDefinition(object): def get(cls, func): fd = getattr(func, '_wsme_definition', None) if fd is None: - fd = FunctionDefinition(func.__name__) + fd = FunctionDefinition(func) func._wsme_definition = fd return fd @@ -79,6 +83,7 @@ class validate(object): class WSRoot(object): def __init__(self, protocols=None): + self.debug = True if protocols is None: protocols = registered_protocols.keys() self.protocols = {} @@ -98,3 +103,30 @@ class WSRoot(object): break return protocol.handle(self, request) + + def _format_exception(self, excinfo): + """Extract informations that can be sent to the client.""" + if isinstance(excinfo[1], exc.ClientSideError): + return dict(faultcode="Client", + faultstring=unicode(excinfo[1])) + else: + r = dict(faultcode="Server", + faultstring=str(excinfo[1])) + if self.debug: + r['debuginfo'] = ("Traceback:\n%s\n" % + "\n".join(traceback.format_exception(*excinfo))) + return r + + def _lookup_function(self, path): + a = self + + for name in path: + a = getattr(a, name, None) + if a is None: + break + + if not hasattr(a, '_wsme_definition'): + raise exc.UnknownFunction('/'.join(path)) + + return a, a._wsme_definition + diff --git a/wsme/exc.py b/wsme/exc.py new file mode 100644 index 0000000..2d1c3ed --- /dev/null +++ b/wsme/exc.py @@ -0,0 +1,30 @@ +import __builtin__ + +if '_' not in __builtin__.__dict__: + __builtin__._ = lambda s: s + +class ClientSideError(RuntimeError): + pass + +class InvalidInput(ClientSideError): + def __init__(self, fieldname, value, msg=''): + self.fieldname = fieldname + self.value = value + self.msg = msg + + def __unicode__(self): + return _(u"Invalid input for field/attribute %s. Value: '%s'. %s") % ( + self.fieldname, self.value, self.msg) + + def __str__(self): + return unicode(self).encode('utf8', 'ignore') + +class UnknownFunction(ClientSideError): + def __init__(self, name): + self.name = name + + def __unicode__(self): + return _(u"Unknown function name: %s") % (self.name) + + def __str__(self): + return unicode(self).encode('utf8', 'ignore') diff --git a/wsme/rest.py b/wsme/rest.py index 01f4710..6ea0ad4 100644 --- a/wsme/rest.py +++ b/wsme/rest.py @@ -1,4 +1,7 @@ import webob +import sys + +from wsme.exc import UnknownFunction class RestProtocol(object): name = None @@ -12,20 +15,20 @@ class RestProtocol(object): def handle(self, root, request): path = request.path.strip('/').split('/') - a = root - for name in path: - a = getattr(a, name) - if not hasattr(a, '_wsme_definition'): - raise ValueError('Invalid path') - fonc = a - - kw = self.get_args(request) - res = webob.Response() res.headers['Content-Type'] = 'application/json' - res.status = "200 OK" - res.body = self.encode_response(a(**kw)) + try: + func, funcdef = root._lookup_function(path) + kw = self.get_args(request) + result = func(**kw) + # TODO make sure result type == a._wsme_definition.return_type + res.body = self.encode_result(result, funcdef.return_type) + res.status = "200 OK" + except Exception, e: + res.status = "500 Error" + res.body = self.encode_error( + root._format_exception(sys.exc_info())) return res diff --git a/wsme/restjson.py b/wsme/restjson.py index 8902682..359b8d0 100644 --- a/wsme/restjson.py +++ b/wsme/restjson.py @@ -15,7 +15,10 @@ class RestJsonProtocol(RestProtocol): kw = json.loads(req.body) return kw - def encode_response(self, response): - return json.dumps(response) + def encode_result(self, result, return_type): + return json.dumps({'result': result}) + + def encode_error(self, errordetail): + return json.dumps(errordetail) register_protocol(RestJsonProtocol) diff --git a/wsme/tests/protocol.py b/wsme/tests/protocol.py index 96e02e8..69f48aa 100644 --- a/wsme/tests/protocol.py +++ b/wsme/tests/protocol.py @@ -12,6 +12,15 @@ from wsme import * warnings.filterwarnings('ignore', module='webob.dec') +class CallException(RuntimeError): + def __init__(self, faultcode, faultstring, debuginfo): + self.faultcode = faultcode + self.faultstring = faultstring + self.debuginfo = debuginfo + + def __str__(self): + return 'faultcode=%s, faultstring=%s, debuginfo=%s' % ( + self.faultcode, self.faultstring, self.debuginfo) class ReturnTypes(object): @expose(str) @@ -46,8 +55,15 @@ class ReturnTypes(object): def getdate(self): return datetime.datetime(1994, 1, 26, 12, 0, 0) + +class WithErrors(object): + @expose() + def divide_by_zero(self): + 1 / 0 + class WSTestRoot(WSRoot): returntypes = ReturnTypes() + witherrors = WithErrors() def reset(self): self.touched = False @@ -64,9 +80,23 @@ class ProtocolTestCase(unittest.TestCase): self.app = TestApp(wsgify(self.root._handle_request)) - def _call(self, fpath, **kw): - pass - + def test_invalid_path(self): + try: + res = self.call('invalid_function') + assert "No error raised" + except CallException, e: + assert e.faultcode == 'Client' + assert e.faultstring == u'Unknown function name: invalid_function' + + def test_serverside_error(self): + try: + res = self.call('witherrors/divide_by_zero') + assert "No error raised" + except CallException, e: + print e + assert e.faultcode == 'Server' + assert e.faultstring == u'integer division or modulo by zero' + def test_touch(self): assert self.call('touch') is None @@ -86,3 +116,4 @@ class ProtocolTestCase(unittest.TestCase): r = self.call('returntypes/getfloat') assert r == 3.14159265, r + diff --git a/wsme/tests/test_controller.py b/wsme/tests/test_controller.py index 59e8b44..daebc50 100644 --- a/wsme/tests/test_controller.py +++ b/wsme/tests/test_controller.py @@ -61,10 +61,9 @@ class TestController(unittest.TestCase): assert args[2].default == 0 def test_register_protocol(self): - p = DummyProtocol() import wsme.controller - wsme.controller.register_protocol(p) - assert wsme.controller.registered_protocols['dummy'] == p + wsme.controller.register_protocol(DummyProtocol) + assert wsme.controller.registered_protocols['dummy'] == DummyProtocol r = WSRoot() assert r.protocols['dummy'] diff --git a/wsme/tests/test_restjson.py b/wsme/tests/test_restjson.py index cbffb3c..0ac1dc8 100644 --- a/wsme/tests/test_restjson.py +++ b/wsme/tests/test_restjson.py @@ -15,6 +15,16 @@ class TestRestJson(wsme.tests.protocol.ProtocolTestCase): content, headers={ 'Content-Type': 'application/json', - }) + }, + expect_errors=True) + r = json.loads(res.body) + if 'result' in r: + return r['result'] + else: + raise wsme.tests.protocol.CallException( + r['faultcode'], + r['faultstring'], + r.get('debuginfo')) + return json.loads(res.body)