Rework the rest implementation. We now have a single protocol that can manupulate different dataformat thanks to the helpers provided by the xml, json and args modules (which will be used by the adapters too). Some corner cases still don't pass the unittest, and some code cleaning is required.
This commit is contained in:
parent
eaa6cc8083
commit
c3891c477e
@ -30,8 +30,9 @@ classifier =
|
||||
|
||||
[entry_points]
|
||||
wsme.protocols =
|
||||
restjson = wsme.rest.json:RestJsonProtocol
|
||||
restxml = wsme.rest.xml:RestXmlProtocol
|
||||
rest = wsme.rest.protocol:RestProtocol
|
||||
restjson = wsme.rest.protocol:RestProtocol
|
||||
restxml = wsme.rest.protocol:RestProtocol
|
||||
|
||||
[files]
|
||||
packages =
|
||||
|
@ -55,7 +55,6 @@ class ObjectDict(object):
|
||||
class Protocol(object):
|
||||
name = None
|
||||
displayname = None
|
||||
dataformat = None
|
||||
content_types = []
|
||||
|
||||
def resolve_path(self, path):
|
||||
@ -73,8 +72,6 @@ class Protocol(object):
|
||||
yield self.resolve_path(path), attr
|
||||
|
||||
def accept(self, request):
|
||||
if request.path.endswith('.' + self.dataformat):
|
||||
return True
|
||||
return request.headers.get('Content-Type') in self.content_types
|
||||
|
||||
def iter_calls(self, request):
|
||||
|
@ -144,39 +144,32 @@ def args_from_body(funcdef, body, mimetype):
|
||||
from wsme.rest import json as restjson
|
||||
from wsme.rest import xml as restxml
|
||||
|
||||
kw = {}
|
||||
|
||||
if funcdef.body_type is not None:
|
||||
bodydata = None
|
||||
if mimetype in restjson.RestJsonProtocol.content_types:
|
||||
if hasattr(body, 'read'):
|
||||
jsonbody = restjson.json.load(body)
|
||||
else:
|
||||
jsonbody = restjson.json.loads(body)
|
||||
bodydata = restjson.fromjson(funcdef.body_type, jsonbody)
|
||||
elif mimetype in restxml.RestXmlProtocol.content_types:
|
||||
if hasattr(body, 'read'):
|
||||
xmlbody = restxml.et.parse(body)
|
||||
else:
|
||||
xmlbody = restxml.et.fromstring(body)
|
||||
bodydata = restxml.fromxml(funcdef.body_type, xmlbody)
|
||||
if bodydata:
|
||||
kw[funcdef.arguments[-1].name] = bodydata
|
||||
datatypes = {funcdef.arguments[-1].name: funcdef.body_type}
|
||||
else:
|
||||
datatypes = dict(((a.name, a.datatype) for a in funcdef.arguments))
|
||||
|
||||
if mimetype in restjson.accept_content_types:
|
||||
dataformat = restjson
|
||||
elif mimetype in restxml.accept_content_types:
|
||||
dataformat = restxml
|
||||
else:
|
||||
raise ValueError("Unknow mimetype: %s" % mimetype)
|
||||
|
||||
kw = dataformat.parse(
|
||||
body, datatypes, bodyarg=funcdef.body_type is not None
|
||||
)
|
||||
|
||||
return (), kw
|
||||
|
||||
|
||||
def combine_args(funcdef, *akw):
|
||||
newargs, newkwargs = [], {}
|
||||
argindexes = {}
|
||||
for i, arg in enumerate(funcdef.arguments):
|
||||
argindexes[arg.name] = i
|
||||
newargs.append(arg.default)
|
||||
for args, kwargs in akw:
|
||||
for i, arg in enumerate(args):
|
||||
newargs[i] = arg
|
||||
newkwargs[funcdef.arguments[i].name] = arg
|
||||
for name, value in kwargs.iteritems():
|
||||
newargs[argindexes[name]] = value
|
||||
newkwargs[name] = value
|
||||
return newargs, newkwargs
|
||||
|
||||
|
||||
|
@ -9,7 +9,6 @@ import six
|
||||
|
||||
from simplegeneric import generic
|
||||
|
||||
from wsme.rest.protocol import RestProtocol
|
||||
from wsme.types import Unset
|
||||
import wsme.types
|
||||
|
||||
@ -19,6 +18,14 @@ except ImportError:
|
||||
import json # noqa
|
||||
|
||||
|
||||
content_type = 'application/json'
|
||||
accept_content_types = [
|
||||
content_type,
|
||||
'text/javascript',
|
||||
'application/javascript'
|
||||
]
|
||||
|
||||
|
||||
@generic
|
||||
def tojson(datatype, value):
|
||||
"""
|
||||
@ -184,7 +191,7 @@ def datetime_fromjson(datatype, value):
|
||||
return datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S')
|
||||
|
||||
|
||||
class RestJsonProtocol(RestProtocol):
|
||||
class RestJson(object):
|
||||
"""
|
||||
REST+Json protocol.
|
||||
|
||||
@ -193,18 +200,12 @@ class RestJsonProtocol(RestProtocol):
|
||||
.. autoattribute:: content_types
|
||||
"""
|
||||
|
||||
name = 'restjson'
|
||||
displayname = 'REST+Json'
|
||||
dataformat = 'json'
|
||||
content_types = [
|
||||
'application/json',
|
||||
'application/javascript',
|
||||
'text/javascript',
|
||||
'']
|
||||
name = 'json'
|
||||
content_type = 'application/json'
|
||||
|
||||
def __init__(self, nest_result=False):
|
||||
super(RestJsonProtocol, self).__init__()
|
||||
self.nest_result = nest_result
|
||||
#def __init__(self, nest_result=False):
|
||||
# super(RestJsonProtocol, self).__init__()
|
||||
# self.nest_result = nest_result
|
||||
|
||||
def decode_arg(self, value, arg):
|
||||
return fromjson(arg.datatype, value)
|
||||
@ -222,9 +223,6 @@ class RestJsonProtocol(RestProtocol):
|
||||
r = {'result': r}
|
||||
return json.dumps(r)
|
||||
|
||||
def encode_error(self, context, errordetail):
|
||||
return json.dumps(errordetail)
|
||||
|
||||
def encode_sample_value(self, datatype, value, format=False):
|
||||
r = tojson(datatype, value)
|
||||
content = json.dumps(r, ensure_ascii=False,
|
||||
@ -249,3 +247,34 @@ class RestJsonProtocol(RestProtocol):
|
||||
indent=4 if format else 0,
|
||||
sort_keys=format)
|
||||
return ('javascript', content)
|
||||
|
||||
|
||||
def get_format():
|
||||
return RestJson()
|
||||
|
||||
|
||||
def parse(s, datatypes, bodyarg):
|
||||
if hasattr(s, 'read'):
|
||||
jdata = json.load(s)
|
||||
else:
|
||||
jdata = json.loads(s)
|
||||
if bodyarg:
|
||||
argname = list(datatypes.keys())[0]
|
||||
kw = {argname: fromjson(datatypes[argname], jdata)}
|
||||
else:
|
||||
kw = {}
|
||||
for key, datatype in datatypes.items():
|
||||
if key in jdata:
|
||||
kw[key] = fromjson(datatype, jdata[key])
|
||||
return kw
|
||||
|
||||
|
||||
def tostring(value, datatype, attrname=None):
|
||||
jsondata = tojson(datatype, value)
|
||||
if attrname is not None:
|
||||
jsondata = {attrname: jsondata}
|
||||
return json.dumps(tojson(datatype, value))
|
||||
|
||||
|
||||
def encode_error(context, errordetail):
|
||||
return json.dumps(errordetail)
|
||||
|
@ -1,19 +1,66 @@
|
||||
import collections
|
||||
import os.path
|
||||
import logging
|
||||
import six
|
||||
|
||||
from six import u
|
||||
|
||||
from wsme.exc import ClientSideError, UnknownArgument
|
||||
from wsme.exc import ClientSideError, UnknownArgument, MissingArgument
|
||||
from wsme.protocol import CallContext, Protocol
|
||||
from wsme.rest.args import from_params
|
||||
from wsme.types import Unset
|
||||
|
||||
import wsme.rest
|
||||
import wsme.rest.args
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RestProtocol(Protocol):
|
||||
name = 'rest'
|
||||
displayname = 'REST'
|
||||
dataformats = ['json', 'xml']
|
||||
content_types = ['application/json', 'text/xml']
|
||||
|
||||
def __init__(self, dataformats=None):
|
||||
if dataformats is None:
|
||||
dataformats = RestProtocol.dataformats
|
||||
|
||||
self.dataformats = collections.OrderedDict()
|
||||
self.content_types = []
|
||||
|
||||
for dataformat in dataformats:
|
||||
__import__('wsme.rest.' + dataformat)
|
||||
dfmod = getattr(wsme.rest, dataformat)
|
||||
self.dataformats[dataformat] = dfmod
|
||||
self.content_types.extend(dfmod.accept_content_types)
|
||||
|
||||
def accept(self, request):
|
||||
for dataformat in self.dataformats:
|
||||
if request.path.endswith('.' + dataformat):
|
||||
return True
|
||||
return request.headers.get('Content-Type') in self.content_types
|
||||
|
||||
def iter_calls(self, request):
|
||||
yield CallContext(request)
|
||||
context = CallContext(request)
|
||||
context.outformat = None
|
||||
ext = os.path.splitext(request.path.split('/')[-1])[1]
|
||||
inmime = request.content_type
|
||||
outmime = request.accept.best_match(self.content_types)
|
||||
|
||||
outformat = None
|
||||
for dfname, df in self.dataformats.items():
|
||||
if ext == '.' + dfname:
|
||||
outformat = df
|
||||
|
||||
if outformat is None and request.accept:
|
||||
for dfname, df in self.dataformats.items():
|
||||
if outmime in df.accept_content_types:
|
||||
outformat = df
|
||||
|
||||
if outformat is None:
|
||||
for dfname, df in self.dataformats.items():
|
||||
if inmime == df.content_type:
|
||||
outformat = df
|
||||
|
||||
context.outformat = outformat
|
||||
yield context
|
||||
|
||||
def extract_path(self, context):
|
||||
path = context.request.path
|
||||
@ -21,8 +68,9 @@ class RestProtocol(Protocol):
|
||||
path = path[len(self.root._webpath):]
|
||||
path = path.strip('/').split('/')
|
||||
|
||||
if path[-1].endswith('.' + self.dataformat):
|
||||
path[-1] = path[-1][:-len(self.dataformat) - 1]
|
||||
for dataformat in self.dataformats:
|
||||
if path[-1].endswith('.' + dataformat):
|
||||
path[-1] = path[-1][:-len(dataformat) - 1]
|
||||
|
||||
# Check if the path is actually a function, and if not
|
||||
# see if the http method make a difference
|
||||
@ -60,42 +108,55 @@ class RestProtocol(Protocol):
|
||||
raise ClientSideError(
|
||||
"Cannot read parameters from both a body and GET/POST params")
|
||||
|
||||
param_args = (), {}
|
||||
|
||||
body = None
|
||||
|
||||
if 'body' in request.params:
|
||||
body = request.params['body']
|
||||
body_mimetype = context.outformat.content_type
|
||||
if body is None:
|
||||
body = request.body
|
||||
body_mimetype = request.content_type
|
||||
param_args = wsme.rest.args.args_from_params(
|
||||
funcdef, request.params
|
||||
)
|
||||
if isinstance(body, six.binary_type):
|
||||
body = body.decode('utf8')
|
||||
|
||||
if body is None and len(request.params):
|
||||
kw = {}
|
||||
hit_paths = set()
|
||||
for argdef in funcdef.arguments:
|
||||
value = from_params(
|
||||
argdef.datatype, request.params, argdef.name, hit_paths)
|
||||
if value is not Unset:
|
||||
kw[argdef.name] = value
|
||||
paths = set(request.params.keys())
|
||||
unknown_paths = paths - hit_paths
|
||||
if unknown_paths:
|
||||
raise UnknownArgument(', '.join(unknown_paths))
|
||||
return kw
|
||||
if body and body_mimetype in self.content_types:
|
||||
body_args = wsme.rest.args.args_from_body(
|
||||
funcdef, body, body_mimetype
|
||||
)
|
||||
else:
|
||||
if body is None:
|
||||
body = request.body
|
||||
if isinstance(body, six.binary_type):
|
||||
body = body.decode('utf8')
|
||||
if body:
|
||||
parsed_args = self.parse_args(body)
|
||||
else:
|
||||
parsed_args = {}
|
||||
body_args = ((), {})
|
||||
|
||||
kw = {}
|
||||
args, kw = wsme.rest.args.combine_args(
|
||||
funcdef,
|
||||
param_args,
|
||||
body_args
|
||||
)
|
||||
|
||||
for arg in funcdef.arguments:
|
||||
if arg.name not in parsed_args:
|
||||
continue
|
||||
for a in funcdef.arguments:
|
||||
if a.mandatory and a.name not in kw:
|
||||
raise MissingArgument(a.name)
|
||||
|
||||
value = parsed_args.pop(arg.name)
|
||||
kw[arg.name] = self.decode_arg(value, arg)
|
||||
argnames = set((a.name for a in funcdef.arguments))
|
||||
|
||||
for k in kw:
|
||||
if k not in argnames:
|
||||
raise UnknownArgument(k)
|
||||
|
||||
if parsed_args:
|
||||
raise UnknownArgument(u(', ').join(parsed_args.keys()))
|
||||
return kw
|
||||
|
||||
def encode_result(self, context, result):
|
||||
out = context.outformat.tostring(
|
||||
result, context.funcdef.return_type
|
||||
)
|
||||
return out
|
||||
|
||||
def encode_error(self, context, errordetail):
|
||||
out = context.outformat.encode_error(
|
||||
context, errordetail
|
||||
)
|
||||
return out
|
||||
|
@ -11,9 +11,15 @@ from simplegeneric import generic
|
||||
|
||||
from wsme.rest.protocol import RestProtocol
|
||||
import wsme.types
|
||||
from wsme.exc import UnknownArgument
|
||||
|
||||
import re
|
||||
|
||||
content_type = 'text/xml'
|
||||
accept_content_types = [
|
||||
content_type,
|
||||
]
|
||||
|
||||
time_re = re.compile(r'(?P<h>[0-2][0-9]):(?P<m>[0-5][0-9]):(?P<s>[0-6][0-9])')
|
||||
|
||||
|
||||
@ -244,6 +250,11 @@ class RestXmlProtocol(RestProtocol):
|
||||
return et.tostring(
|
||||
toxml(context.funcdef.return_type, 'result', result))
|
||||
|
||||
|
||||
class RestXml(object):
|
||||
name = 'xml'
|
||||
content_type = 'text/xml'
|
||||
|
||||
def encode_error(self, context, errordetail):
|
||||
el = et.Element('error')
|
||||
et.SubElement(el, 'faultcode').text = errordetail['faultcode']
|
||||
@ -274,3 +285,37 @@ class RestXmlProtocol(RestProtocol):
|
||||
xml_indent(r)
|
||||
content = et.tostring(r)
|
||||
return ('xml', content)
|
||||
|
||||
|
||||
def get_format():
|
||||
return RestXml()
|
||||
|
||||
|
||||
def parse(s, datatypes, bodyarg):
|
||||
if hasattr(s, 'read'):
|
||||
tree = et.parse(s)
|
||||
else:
|
||||
tree = et.fromstring(s)
|
||||
if bodyarg:
|
||||
name = list(datatypes.keys())[0]
|
||||
return fromxml(datatypes[name], tree)
|
||||
else:
|
||||
kw = {}
|
||||
for sub in tree:
|
||||
if sub.tag not in datatypes:
|
||||
raise UnknownArgument(sub.tag)
|
||||
kw[sub.tag] = fromxml(datatypes[sub.tag], sub)
|
||||
return kw
|
||||
|
||||
|
||||
def tostring(value, datatype, attrname='result'):
|
||||
return et.tostring(toxml(datatype, attrname, value))
|
||||
|
||||
|
||||
def encode_error(context, errordetail):
|
||||
el = et.Element('error')
|
||||
et.SubElement(el, 'faultcode').text = errordetail['faultcode']
|
||||
et.SubElement(el, 'faultstring').text = errordetail['faultstring']
|
||||
if 'debuginfo' in errordetail:
|
||||
et.SubElement(el, 'debuginfo').text = errordetail['debuginfo']
|
||||
return et.tostring(el)
|
||||
|
@ -3,7 +3,7 @@ from wsme import types
|
||||
try:
|
||||
import simplejson as json
|
||||
except ImportError:
|
||||
import json
|
||||
import json # noqa
|
||||
|
||||
|
||||
def getdesc(root, host_url=''):
|
||||
|
Loading…
Reference in New Issue
Block a user