Rework the rest implementation. We now have a single protocol that can manupulate different dataformat thanks to the helpers provided by the xml, json and args modules (which will be used by the adapters too). Some corner cases still don't pass the unittest, and some code cleaning is required.

This commit is contained in:
Christophe de Vienne 2012-11-07 18:19:11 +01:00
parent eaa6cc8083
commit c3891c477e
7 changed files with 208 additions and 82 deletions

View File

@ -30,8 +30,9 @@ classifier =
[entry_points] [entry_points]
wsme.protocols = wsme.protocols =
restjson = wsme.rest.json:RestJsonProtocol rest = wsme.rest.protocol:RestProtocol
restxml = wsme.rest.xml:RestXmlProtocol restjson = wsme.rest.protocol:RestProtocol
restxml = wsme.rest.protocol:RestProtocol
[files] [files]
packages = packages =

View File

@ -55,7 +55,6 @@ class ObjectDict(object):
class Protocol(object): class Protocol(object):
name = None name = None
displayname = None displayname = None
dataformat = None
content_types = [] content_types = []
def resolve_path(self, path): def resolve_path(self, path):
@ -73,8 +72,6 @@ class Protocol(object):
yield self.resolve_path(path), attr yield self.resolve_path(path), attr
def accept(self, request): def accept(self, request):
if request.path.endswith('.' + self.dataformat):
return True
return request.headers.get('Content-Type') in self.content_types return request.headers.get('Content-Type') in self.content_types
def iter_calls(self, request): def iter_calls(self, request):

View File

@ -144,39 +144,32 @@ def args_from_body(funcdef, body, mimetype):
from wsme.rest import json as restjson from wsme.rest import json as restjson
from wsme.rest import xml as restxml from wsme.rest import xml as restxml
kw = {}
if funcdef.body_type is not None: if funcdef.body_type is not None:
bodydata = None datatypes = {funcdef.arguments[-1].name: funcdef.body_type}
if mimetype in restjson.RestJsonProtocol.content_types: else:
if hasattr(body, 'read'): datatypes = dict(((a.name, a.datatype) for a in funcdef.arguments))
jsonbody = restjson.json.load(body)
else: if mimetype in restjson.accept_content_types:
jsonbody = restjson.json.loads(body) dataformat = restjson
bodydata = restjson.fromjson(funcdef.body_type, jsonbody) elif mimetype in restxml.accept_content_types:
elif mimetype in restxml.RestXmlProtocol.content_types: dataformat = restxml
if hasattr(body, 'read'): else:
xmlbody = restxml.et.parse(body) raise ValueError("Unknow mimetype: %s" % mimetype)
else:
xmlbody = restxml.et.fromstring(body) kw = dataformat.parse(
bodydata = restxml.fromxml(funcdef.body_type, xmlbody) body, datatypes, bodyarg=funcdef.body_type is not None
if bodydata: )
kw[funcdef.arguments[-1].name] = bodydata
return (), kw return (), kw
def combine_args(funcdef, *akw): def combine_args(funcdef, *akw):
newargs, newkwargs = [], {} newargs, newkwargs = [], {}
argindexes = {}
for i, arg in enumerate(funcdef.arguments):
argindexes[arg.name] = i
newargs.append(arg.default)
for args, kwargs in akw: for args, kwargs in akw:
for i, arg in enumerate(args): for i, arg in enumerate(args):
newargs[i] = arg newkwargs[funcdef.arguments[i].name] = arg
for name, value in kwargs.iteritems(): for name, value in kwargs.iteritems():
newargs[argindexes[name]] = value newkwargs[name] = value
return newargs, newkwargs return newargs, newkwargs

View File

@ -9,7 +9,6 @@ import six
from simplegeneric import generic from simplegeneric import generic
from wsme.rest.protocol import RestProtocol
from wsme.types import Unset from wsme.types import Unset
import wsme.types import wsme.types
@ -19,6 +18,14 @@ except ImportError:
import json # noqa import json # noqa
content_type = 'application/json'
accept_content_types = [
content_type,
'text/javascript',
'application/javascript'
]
@generic @generic
def tojson(datatype, value): def tojson(datatype, value):
""" """
@ -184,7 +191,7 @@ def datetime_fromjson(datatype, value):
return datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S') return datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S')
class RestJsonProtocol(RestProtocol): class RestJson(object):
""" """
REST+Json protocol. REST+Json protocol.
@ -193,18 +200,12 @@ class RestJsonProtocol(RestProtocol):
.. autoattribute:: content_types .. autoattribute:: content_types
""" """
name = 'restjson' name = 'json'
displayname = 'REST+Json' content_type = 'application/json'
dataformat = 'json'
content_types = [
'application/json',
'application/javascript',
'text/javascript',
'']
def __init__(self, nest_result=False): #def __init__(self, nest_result=False):
super(RestJsonProtocol, self).__init__() # super(RestJsonProtocol, self).__init__()
self.nest_result = nest_result # self.nest_result = nest_result
def decode_arg(self, value, arg): def decode_arg(self, value, arg):
return fromjson(arg.datatype, value) return fromjson(arg.datatype, value)
@ -222,9 +223,6 @@ class RestJsonProtocol(RestProtocol):
r = {'result': r} r = {'result': r}
return json.dumps(r) return json.dumps(r)
def encode_error(self, context, errordetail):
return json.dumps(errordetail)
def encode_sample_value(self, datatype, value, format=False): def encode_sample_value(self, datatype, value, format=False):
r = tojson(datatype, value) r = tojson(datatype, value)
content = json.dumps(r, ensure_ascii=False, content = json.dumps(r, ensure_ascii=False,
@ -249,3 +247,34 @@ class RestJsonProtocol(RestProtocol):
indent=4 if format else 0, indent=4 if format else 0,
sort_keys=format) sort_keys=format)
return ('javascript', content) return ('javascript', content)
def get_format():
return RestJson()
def parse(s, datatypes, bodyarg):
if hasattr(s, 'read'):
jdata = json.load(s)
else:
jdata = json.loads(s)
if bodyarg:
argname = list(datatypes.keys())[0]
kw = {argname: fromjson(datatypes[argname], jdata)}
else:
kw = {}
for key, datatype in datatypes.items():
if key in jdata:
kw[key] = fromjson(datatype, jdata[key])
return kw
def tostring(value, datatype, attrname=None):
jsondata = tojson(datatype, value)
if attrname is not None:
jsondata = {attrname: jsondata}
return json.dumps(tojson(datatype, value))
def encode_error(context, errordetail):
return json.dumps(errordetail)

View File

@ -1,19 +1,66 @@
import collections
import os.path
import logging import logging
import six import six
from six import u from wsme.exc import ClientSideError, UnknownArgument, MissingArgument
from wsme.exc import ClientSideError, UnknownArgument
from wsme.protocol import CallContext, Protocol from wsme.protocol import CallContext, Protocol
from wsme.rest.args import from_params
from wsme.types import Unset import wsme.rest
import wsme.rest.args
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class RestProtocol(Protocol): class RestProtocol(Protocol):
name = 'rest'
displayname = 'REST'
dataformats = ['json', 'xml']
content_types = ['application/json', 'text/xml']
def __init__(self, dataformats=None):
if dataformats is None:
dataformats = RestProtocol.dataformats
self.dataformats = collections.OrderedDict()
self.content_types = []
for dataformat in dataformats:
__import__('wsme.rest.' + dataformat)
dfmod = getattr(wsme.rest, dataformat)
self.dataformats[dataformat] = dfmod
self.content_types.extend(dfmod.accept_content_types)
def accept(self, request):
for dataformat in self.dataformats:
if request.path.endswith('.' + dataformat):
return True
return request.headers.get('Content-Type') in self.content_types
def iter_calls(self, request): def iter_calls(self, request):
yield CallContext(request) context = CallContext(request)
context.outformat = None
ext = os.path.splitext(request.path.split('/')[-1])[1]
inmime = request.content_type
outmime = request.accept.best_match(self.content_types)
outformat = None
for dfname, df in self.dataformats.items():
if ext == '.' + dfname:
outformat = df
if outformat is None and request.accept:
for dfname, df in self.dataformats.items():
if outmime in df.accept_content_types:
outformat = df
if outformat is None:
for dfname, df in self.dataformats.items():
if inmime == df.content_type:
outformat = df
context.outformat = outformat
yield context
def extract_path(self, context): def extract_path(self, context):
path = context.request.path path = context.request.path
@ -21,8 +68,9 @@ class RestProtocol(Protocol):
path = path[len(self.root._webpath):] path = path[len(self.root._webpath):]
path = path.strip('/').split('/') path = path.strip('/').split('/')
if path[-1].endswith('.' + self.dataformat): for dataformat in self.dataformats:
path[-1] = path[-1][:-len(self.dataformat) - 1] if path[-1].endswith('.' + dataformat):
path[-1] = path[-1][:-len(dataformat) - 1]
# Check if the path is actually a function, and if not # Check if the path is actually a function, and if not
# see if the http method make a difference # see if the http method make a difference
@ -60,42 +108,55 @@ class RestProtocol(Protocol):
raise ClientSideError( raise ClientSideError(
"Cannot read parameters from both a body and GET/POST params") "Cannot read parameters from both a body and GET/POST params")
param_args = (), {}
body = None body = None
if 'body' in request.params: if 'body' in request.params:
body = request.params['body'] body = request.params['body']
body_mimetype = context.outformat.content_type
if body is None:
body = request.body
body_mimetype = request.content_type
param_args = wsme.rest.args.args_from_params(
funcdef, request.params
)
if isinstance(body, six.binary_type):
body = body.decode('utf8')
if body is None and len(request.params): if body and body_mimetype in self.content_types:
kw = {} body_args = wsme.rest.args.args_from_body(
hit_paths = set() funcdef, body, body_mimetype
for argdef in funcdef.arguments: )
value = from_params(
argdef.datatype, request.params, argdef.name, hit_paths)
if value is not Unset:
kw[argdef.name] = value
paths = set(request.params.keys())
unknown_paths = paths - hit_paths
if unknown_paths:
raise UnknownArgument(', '.join(unknown_paths))
return kw
else: else:
if body is None: body_args = ((), {})
body = request.body
if isinstance(body, six.binary_type):
body = body.decode('utf8')
if body:
parsed_args = self.parse_args(body)
else:
parsed_args = {}
kw = {} args, kw = wsme.rest.args.combine_args(
funcdef,
param_args,
body_args
)
for arg in funcdef.arguments: for a in funcdef.arguments:
if arg.name not in parsed_args: if a.mandatory and a.name not in kw:
continue raise MissingArgument(a.name)
value = parsed_args.pop(arg.name) argnames = set((a.name for a in funcdef.arguments))
kw[arg.name] = self.decode_arg(value, arg)
for k in kw:
if k not in argnames:
raise UnknownArgument(k)
if parsed_args:
raise UnknownArgument(u(', ').join(parsed_args.keys()))
return kw return kw
def encode_result(self, context, result):
out = context.outformat.tostring(
result, context.funcdef.return_type
)
return out
def encode_error(self, context, errordetail):
out = context.outformat.encode_error(
context, errordetail
)
return out

View File

@ -11,9 +11,15 @@ from simplegeneric import generic
from wsme.rest.protocol import RestProtocol from wsme.rest.protocol import RestProtocol
import wsme.types import wsme.types
from wsme.exc import UnknownArgument
import re import re
content_type = 'text/xml'
accept_content_types = [
content_type,
]
time_re = re.compile(r'(?P<h>[0-2][0-9]):(?P<m>[0-5][0-9]):(?P<s>[0-6][0-9])') time_re = re.compile(r'(?P<h>[0-2][0-9]):(?P<m>[0-5][0-9]):(?P<s>[0-6][0-9])')
@ -244,6 +250,11 @@ class RestXmlProtocol(RestProtocol):
return et.tostring( return et.tostring(
toxml(context.funcdef.return_type, 'result', result)) toxml(context.funcdef.return_type, 'result', result))
class RestXml(object):
name = 'xml'
content_type = 'text/xml'
def encode_error(self, context, errordetail): def encode_error(self, context, errordetail):
el = et.Element('error') el = et.Element('error')
et.SubElement(el, 'faultcode').text = errordetail['faultcode'] et.SubElement(el, 'faultcode').text = errordetail['faultcode']
@ -274,3 +285,37 @@ class RestXmlProtocol(RestProtocol):
xml_indent(r) xml_indent(r)
content = et.tostring(r) content = et.tostring(r)
return ('xml', content) return ('xml', content)
def get_format():
return RestXml()
def parse(s, datatypes, bodyarg):
if hasattr(s, 'read'):
tree = et.parse(s)
else:
tree = et.fromstring(s)
if bodyarg:
name = list(datatypes.keys())[0]
return fromxml(datatypes[name], tree)
else:
kw = {}
for sub in tree:
if sub.tag not in datatypes:
raise UnknownArgument(sub.tag)
kw[sub.tag] = fromxml(datatypes[sub.tag], sub)
return kw
def tostring(value, datatype, attrname='result'):
return et.tostring(toxml(datatype, attrname, value))
def encode_error(context, errordetail):
el = et.Element('error')
et.SubElement(el, 'faultcode').text = errordetail['faultcode']
et.SubElement(el, 'faultstring').text = errordetail['faultstring']
if 'debuginfo' in errordetail:
et.SubElement(el, 'debuginfo').text = errordetail['debuginfo']
return et.tostring(el)

View File

@ -3,7 +3,7 @@ from wsme import types
try: try:
import simplejson as json import simplejson as json
except ImportError: except ImportError:
import json import json # noqa
def getdesc(root, host_url=''): def getdesc(root, host_url=''):