We now have a CallContext that follows a function execution from the path extraction to the result encoding.
This commit is contained in:
@@ -213,6 +213,15 @@ class validate(object):
|
||||
return func
|
||||
|
||||
|
||||
class CallContext(object):
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
self.path = None
|
||||
|
||||
self.func = None
|
||||
self.funcdef = None
|
||||
|
||||
|
||||
class WSRoot(object):
|
||||
"""
|
||||
Root controller for webservices.
|
||||
@@ -287,44 +296,56 @@ class WSRoot(object):
|
||||
res.body = msg
|
||||
log.error(msg)
|
||||
return res
|
||||
path = protocol.extract_path(request)
|
||||
if path is None:
|
||||
|
||||
context = None
|
||||
calls = protocol.list_calls(request)
|
||||
|
||||
if isinstance(calls, CallContext):
|
||||
calls = [calls]
|
||||
|
||||
if len(calls) != 1:
|
||||
raise NotImplementedError("Batch calls are not yet supported")
|
||||
|
||||
context = calls[0]
|
||||
context.path = protocol.extract_path(context)
|
||||
if context.path is None:
|
||||
raise exc.ClientSideError(
|
||||
u'The %s protocol was unable to extract a function '
|
||||
u'path from the request' % protocol.name)
|
||||
func, funcdef = self._lookup_function(path)
|
||||
kw = protocol.read_arguments(funcdef, request)
|
||||
context.func, context.funcdef = self._lookup_function(context.path)
|
||||
kw = protocol.read_arguments(context)
|
||||
|
||||
for arg in funcdef.arguments:
|
||||
for arg in context.funcdef.arguments:
|
||||
if arg.mandatory and arg.name not in kw:
|
||||
raise exc.MissingArgument(arg.name)
|
||||
|
||||
result = func(**kw)
|
||||
result = context.func(**kw)
|
||||
|
||||
res.status = 200
|
||||
|
||||
if funcdef.protocol_specific and funcdef.return_type is None:
|
||||
if context.funcdef.protocol_specific and context.funcdef.return_type is None:
|
||||
if isinstance(result, unicode):
|
||||
res.unicode_body = result
|
||||
else:
|
||||
res.body = result
|
||||
else:
|
||||
# TODO make sure result type == a._wsme_definition.return_type
|
||||
res.body = protocol.encode_result(request, funcdef, result)
|
||||
res_content_type = funcdef.contenttype
|
||||
res.body = protocol.encode_result(context, result)
|
||||
res_content_type = context.funcdef.contenttype
|
||||
except Exception, e:
|
||||
infos = self._format_exception(sys.exc_info())
|
||||
if isinstance(e, exc.ClientSideError):
|
||||
res.status = 400
|
||||
else:
|
||||
res.status = 500
|
||||
res.body = protocol.encode_error(infos)
|
||||
res.body = protocol.encode_error(context, infos)
|
||||
|
||||
if res_content_type is None:
|
||||
# Attempt to correctly guess what content-type we should return.
|
||||
last_q = 0
|
||||
res_content_type = request.accept.best_match([
|
||||
ct for ct in protocol.content_types if ct])
|
||||
ctypes = [ct for ct in protocol.content_types if ct]
|
||||
if ctypes:
|
||||
res_content_type = request.accept.best_match(ctypes)
|
||||
|
||||
# If not we will attempt to convert the body to an accepted
|
||||
# output format.
|
||||
|
||||
@@ -155,9 +155,9 @@ class RestJsonProtocol(RestProtocol):
|
||||
raw_args = json.loads(body)
|
||||
return raw_args
|
||||
|
||||
def encode_result(self, funcdef, result):
|
||||
r = tojson(funcdef.return_type, result)
|
||||
def encode_result(self, context, result):
|
||||
r = tojson(context.funcdef.return_type, result)
|
||||
return json.dumps({'result': r}, ensure_ascii=False).encode('utf8')
|
||||
|
||||
def encode_error(self, errordetail):
|
||||
def encode_error(self, context, errordetail):
|
||||
return json.dumps(errordetail, encoding='utf-8')
|
||||
|
||||
@@ -185,10 +185,10 @@ class RestXmlProtocol(RestProtocol):
|
||||
def parse_args(self, body):
|
||||
return dict((sub.tag, sub) for sub in et.fromstring(body))
|
||||
|
||||
def encode_result(self, funcdef, result):
|
||||
return et.tostring(toxml(funcdef.return_type, 'result', result))
|
||||
def encode_result(self, context, result):
|
||||
return et.tostring(toxml(context.funcdef.return_type, 'result', result))
|
||||
|
||||
def encode_error(self, errordetail):
|
||||
def encode_error(self, context, errordetail):
|
||||
el = et.Element('error')
|
||||
et.SubElement(el, 'faultcode').text = errordetail['faultcode']
|
||||
et.SubElement(el, 'faultstring').text = errordetail['faultstring']
|
||||
|
||||
29
wsme/rest.py
29
wsme/rest.py
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from wsme.exc import UnknownFunction, MissingArgument, UnknownArgument
|
||||
from wsme.controller import CallContext
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,7 +16,23 @@ class RestProtocol(object):
|
||||
return True
|
||||
return request.headers.get('Content-Type') in self.content_types
|
||||
|
||||
def read_arguments(self, funcdef, request):
|
||||
def list_calls(self, request):
|
||||
return CallContext(request)
|
||||
|
||||
def extract_path(self, context):
|
||||
path = context.request.path
|
||||
assert path.startswith(self.root._webpath)
|
||||
path = path[len(self.root._webpath):]
|
||||
path = path.strip('/').split('/')
|
||||
|
||||
if path[-1].endswith('.' + self.dataformat):
|
||||
path[-1] = path[-1][:-len(self.dataformat) - 1]
|
||||
|
||||
return path
|
||||
|
||||
def read_arguments(self, context):
|
||||
request = context.request
|
||||
funcdef = context.funcdef
|
||||
if len(request.params) and request.body:
|
||||
raise ClientSideError(
|
||||
"Cannot read parameters from both a body and GET/POST params")
|
||||
@@ -49,13 +66,3 @@ class RestProtocol(object):
|
||||
raise UnknownArgument(parsed_args.keys()[0])
|
||||
return kw
|
||||
|
||||
def extract_path(self, request):
|
||||
path = request.path
|
||||
assert path.startswith(self.root._webpath)
|
||||
path = path[len(self.root._webpath):]
|
||||
path = path.strip('/').split('/')
|
||||
|
||||
if path[-1].endswith('.' + self.dataformat):
|
||||
path[-1] = path[-1][:-len(self.dataformat) - 1]
|
||||
|
||||
return path
|
||||
|
||||
@@ -7,7 +7,7 @@ import webtest
|
||||
|
||||
from wsme import *
|
||||
from wsme.controller import getprotocol, scan_api, pexpose
|
||||
from wsme.controller import FunctionArgument, FunctionDefinition
|
||||
from wsme.controller import FunctionArgument, FunctionDefinition, CallContext
|
||||
import wsme.wsgi
|
||||
|
||||
|
||||
@@ -21,18 +21,21 @@ class DummyProtocol(object):
|
||||
def accept(self, req):
|
||||
return True
|
||||
|
||||
def extract_path(self, request):
|
||||
def list_calls(self, req):
|
||||
return CallContext(req)
|
||||
|
||||
def extract_path(self, context):
|
||||
return ['touch']
|
||||
|
||||
def read_arguments(self, funcdef, request):
|
||||
self.lastreq = request
|
||||
def read_arguments(self, context):
|
||||
self.lastreq = context.request
|
||||
self.hits += 1
|
||||
return {}
|
||||
|
||||
def encode_result(self, funcdef, result):
|
||||
def encode_result(self, context, result):
|
||||
return str(result)
|
||||
|
||||
def encode_error(self, infos):
|
||||
def encode_error(self, context, infos):
|
||||
return str(infos)
|
||||
|
||||
|
||||
@@ -46,8 +49,8 @@ def test_getprotocol():
|
||||
|
||||
def test_pexpose():
|
||||
class Proto(DummyProtocol):
|
||||
def extract_path(self, request):
|
||||
if request.path.endswith('ufunc'):
|
||||
def extract_path(self, context):
|
||||
if context.request.path.endswith('ufunc'):
|
||||
return ['_protocol', 'dummy', 'ufunc']
|
||||
else:
|
||||
return ['_protocol', 'dummy', 'func']
|
||||
|
||||
Reference in New Issue
Block a user