Continuing the rest+json implementation

This commit is contained in:
Christophe de Vienne
2011-09-19 19:54:32 +02:00
parent 26d64e5f15
commit b58a916588
7 changed files with 131 additions and 23 deletions

View File

@@ -1,4 +1,8 @@
import inspect
import traceback
import weakref
from wsme import exc
__all__ = ['expose', 'validate', 'WSRoot']
@@ -26,8 +30,8 @@ class FunctionArgument(object):
class FunctionDefinition(object):
def __init__(self, name):
self.name = name
def __init__(self, func):
self.name = func.__name__
self.return_type = None
self.arguments = []
@@ -35,7 +39,7 @@ class FunctionDefinition(object):
def get(cls, func):
fd = getattr(func, '_wsme_definition', None)
if fd is None:
fd = FunctionDefinition(func.__name__)
fd = FunctionDefinition(func)
func._wsme_definition = fd
return fd
@@ -79,6 +83,7 @@ class validate(object):
class WSRoot(object):
def __init__(self, protocols=None):
self.debug = True
if protocols is None:
protocols = registered_protocols.keys()
self.protocols = {}
@@ -98,3 +103,30 @@ class WSRoot(object):
break
return protocol.handle(self, request)
def _format_exception(self, excinfo):
"""Extract informations that can be sent to the client."""
if isinstance(excinfo[1], exc.ClientSideError):
return dict(faultcode="Client",
faultstring=unicode(excinfo[1]))
else:
r = dict(faultcode="Server",
faultstring=str(excinfo[1]))
if self.debug:
r['debuginfo'] = ("Traceback:\n%s\n" %
"\n".join(traceback.format_exception(*excinfo)))
return r
def _lookup_function(self, path):
a = self
for name in path:
a = getattr(a, name, None)
if a is None:
break
if not hasattr(a, '_wsme_definition'):
raise exc.UnknownFunction('/'.join(path))
return a, a._wsme_definition

30
wsme/exc.py Normal file
View File

@@ -0,0 +1,30 @@
import __builtin__
if '_' not in __builtin__.__dict__:
__builtin__._ = lambda s: s
class ClientSideError(RuntimeError):
pass
class InvalidInput(ClientSideError):
def __init__(self, fieldname, value, msg=''):
self.fieldname = fieldname
self.value = value
self.msg = msg
def __unicode__(self):
return _(u"Invalid input for field/attribute %s. Value: '%s'. %s") % (
self.fieldname, self.value, self.msg)
def __str__(self):
return unicode(self).encode('utf8', 'ignore')
class UnknownFunction(ClientSideError):
def __init__(self, name):
self.name = name
def __unicode__(self):
return _(u"Unknown function name: %s") % (self.name)
def __str__(self):
return unicode(self).encode('utf8', 'ignore')

View File

@@ -1,4 +1,7 @@
import webob
import sys
from wsme.exc import UnknownFunction
class RestProtocol(object):
name = None
@@ -12,20 +15,20 @@ class RestProtocol(object):
def handle(self, root, request):
path = request.path.strip('/').split('/')
a = root
for name in path:
a = getattr(a, name)
if not hasattr(a, '_wsme_definition'):
raise ValueError('Invalid path')
fonc = a
kw = self.get_args(request)
res = webob.Response()
res.headers['Content-Type'] = 'application/json'
res.status = "200 OK"
res.body = self.encode_response(a(**kw))
try:
func, funcdef = root._lookup_function(path)
kw = self.get_args(request)
result = func(**kw)
# TODO make sure result type == a._wsme_definition.return_type
res.body = self.encode_result(result, funcdef.return_type)
res.status = "200 OK"
except Exception, e:
res.status = "500 Error"
res.body = self.encode_error(
root._format_exception(sys.exc_info()))
return res

View File

@@ -15,7 +15,10 @@ class RestJsonProtocol(RestProtocol):
kw = json.loads(req.body)
return kw
def encode_response(self, response):
return json.dumps(response)
def encode_result(self, result, return_type):
return json.dumps({'result': result})
def encode_error(self, errordetail):
return json.dumps(errordetail)
register_protocol(RestJsonProtocol)

View File

@@ -12,6 +12,15 @@ from wsme import *
warnings.filterwarnings('ignore', module='webob.dec')
class CallException(RuntimeError):
def __init__(self, faultcode, faultstring, debuginfo):
self.faultcode = faultcode
self.faultstring = faultstring
self.debuginfo = debuginfo
def __str__(self):
return 'faultcode=%s, faultstring=%s, debuginfo=%s' % (
self.faultcode, self.faultstring, self.debuginfo)
class ReturnTypes(object):
@expose(str)
@@ -46,8 +55,15 @@ class ReturnTypes(object):
def getdate(self):
return datetime.datetime(1994, 1, 26, 12, 0, 0)
class WithErrors(object):
@expose()
def divide_by_zero(self):
1 / 0
class WSTestRoot(WSRoot):
returntypes = ReturnTypes()
witherrors = WithErrors()
def reset(self):
self.touched = False
@@ -64,9 +80,23 @@ class ProtocolTestCase(unittest.TestCase):
self.app = TestApp(wsgify(self.root._handle_request))
def _call(self, fpath, **kw):
pass
def test_invalid_path(self):
try:
res = self.call('invalid_function')
assert "No error raised"
except CallException, e:
assert e.faultcode == 'Client'
assert e.faultstring == u'Unknown function name: invalid_function'
def test_serverside_error(self):
try:
res = self.call('witherrors/divide_by_zero')
assert "No error raised"
except CallException, e:
print e
assert e.faultcode == 'Server'
assert e.faultstring == u'integer division or modulo by zero'
def test_touch(self):
assert self.call('touch') is None
@@ -86,3 +116,4 @@ class ProtocolTestCase(unittest.TestCase):
r = self.call('returntypes/getfloat')
assert r == 3.14159265, r

View File

@@ -61,10 +61,9 @@ class TestController(unittest.TestCase):
assert args[2].default == 0
def test_register_protocol(self):
p = DummyProtocol()
import wsme.controller
wsme.controller.register_protocol(p)
assert wsme.controller.registered_protocols['dummy'] == p
wsme.controller.register_protocol(DummyProtocol)
assert wsme.controller.registered_protocols['dummy'] == DummyProtocol
r = WSRoot()
assert r.protocols['dummy']

View File

@@ -15,6 +15,16 @@ class TestRestJson(wsme.tests.protocol.ProtocolTestCase):
content,
headers={
'Content-Type': 'application/json',
})
},
expect_errors=True)
r = json.loads(res.body)
if 'result' in r:
return r['result']
else:
raise wsme.tests.protocol.CallException(
r['faultcode'],
r['faultstring'],
r.get('debuginfo'))
return json.loads(res.body)