diff --git a/wsme/exc.py b/wsme/exc.py index 54e22f5..b3529ff 100644 --- a/wsme/exc.py +++ b/wsme/exc.py @@ -22,6 +22,18 @@ class InvalidInput(ClientSideError): return unicode(self).encode('utf8', 'ignore') +class MissingArgument(ClientSideError): + def __init__(self, argname, msg=''): + self.argname = argname + self.msg = msg + + def __unicode__(self): + return _(u"Missing argument: %s. %s") % (self.argname, self.msg) + + def __str__(self): + return unicode(self).encode('utf8', 'ignore') + + class UnknownFunction(ClientSideError): def __init__(self, name): self.name = name diff --git a/wsme/restjson.py b/wsme/restjson.py index 35ccb31..68dac20 100644 --- a/wsme/restjson.py +++ b/wsme/restjson.py @@ -49,13 +49,57 @@ def datetime_tojson(datatype, value): return base64.encodestring(value) +@generic +def fromjson(datatype, value): + if value is None: + return None + if wsme.types.isstructured(datatype): + obj = datatype() + for name, attrdef in wsme.types.list_attributes(datatype): + if name in value: + setattr(obj, name, fromjson(attrdef.datatype, value[name])) + return obj + return value + + +@fromjson.when_object(decimal.Decimal) +def decimal_fromjson(datatype, value): + return decimal.Decimal(value) + + +@fromjson.when_object(datetime.date) +def date_fromjson(datatype, value): + return datetime.datetime.strptime(value, '%Y-%m-%d').date() + + +@fromjson.when_object(datetime.time) +def time_fromjson(datatype, value): + return datetime.datetime.strptime(value, '%H:%M:%S').time() + + +@fromjson.when_object(datetime.datetime) +def time_fromjson(datatype, value): + return datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S') + + +@fromjson.when_object(wsme.types.binary) +def binary_fromjson(datatype, value): + return base64.decodestring(value) + + class RestJsonProtocol(RestProtocol): name = 'REST+Json' dataformat = 'json' content_types = ['application/json', 'text/json', None] def decode_args(self, req, arguments): - kw = json.loads(req.body) + raw_args = json.loads(req.body) + kw = {} + for farg in arguments: + if farg.mandatory and farg.name not in raw_args: + raise MissingArgument(farg.name) + value = raw_args[farg.name] + kw[farg.name] = fromjson(farg.datatype, value) return kw def encode_result(self, result, return_type): diff --git a/wsme/restxml.py b/wsme/restxml.py index 2bb1448..1b8ae41 100644 --- a/wsme/restxml.py +++ b/wsme/restxml.py @@ -12,6 +12,10 @@ from wsme.rest import RestProtocol from wsme.controller import register_protocol import wsme.types +import re + +time_re = re.compile(r'(?P[0-2][0-9]):(?P[0-5][0-9]):(?P[0-6][0-9])') + @generic def toxml(datatype, key, value): @@ -27,6 +31,30 @@ def toxml(datatype, key, value): return el +@generic +def fromxml(datatype, element): + if element.get('nil', False): + return None + if wsme.types.isstructured(datatype): + obj = datatype() + for key, attrdef in datatype._wsme_attributes: + sub = element.find(key) + if sub is not None: + setattr(obj, key, fromxml(attrdef.datatype, sub)) + return obj + return datatype(element.text) + + +@toxml.when_object(bool) +def bool_toxml(datatype, key, value): + el = et.Element(key) + if value is None: + el.set('nil', 'true') + else: + el.text = value and 'true' or 'false' + return el + + @toxml.when_object(datetime.date) def date_toxml(datatype, key, value): el = et.Element(key) @@ -57,6 +85,31 @@ def binary_toxml(datatype, key, value): return el +@fromxml.when_object(datetime.date) +def date_fromxml(datatype, element): + return datetime.datetime.strptime(element.text, '%Y-%m-%d').date() + + +@fromxml.when_object(datetime.time) +def time_fromxml(datatype, element): + m = time_re.match(element.text) + if m: + return datetime.time( + int(m.group('h')), + int(m.group('m')), + int(m.group('s'))) + + +@fromxml.when_object(datetime.datetime) +def datetime_fromxml(datatype, element): + return datetime.datetime.strptime(element.text, '%Y-%m-%dT%H:%M:%S') + + +@fromxml.when_object(wsme.types.binary) +def binary_fromxml(datatype, element): + return base64.decodestring(element.text) + + class RestXmlProtocol(RestProtocol): name = 'REST+XML' dataformat = 'xml' @@ -66,6 +119,12 @@ class RestXmlProtocol(RestProtocol): el = et.fromstring(req.body) assert el.tag == 'parameters' kw = {} + for farg in arguments: + sub = el.find(farg.name) + if farg.mandatory and sub is None: + raise MissingArgument(farg.name) + if sub is not None: + kw[farg.name] = fromxml(farg.datatype, sub) return kw def encode_result(self, result, return_type): diff --git a/wsme/tests/protocol.py b/wsme/tests/protocol.py index fb01fde..42178ee 100644 --- a/wsme/tests/protocol.py +++ b/wsme/tests/protocol.py @@ -85,6 +85,85 @@ class ReturnTypes(object): return n +class ArgTypes(object): + @expose(str) + @validate(str) + def setstr(self, value): + print repr(value) + assert type(value) == str + return value + + @expose(unicode) + @validate(unicode) + def setunicode(self, value): + print repr(value) + assert type(value) == unicode + return value + + @expose(bool) + @validate(bool) + def setbool(self, value): + print repr(value) + assert type(value) == bool + return value + + @expose(int) + @validate(int) + def setint(self, value): + print repr(value) + assert type(value) == int + return value + + @expose(float) + @validate(float) + def setfloat(self, value): + print repr(value) + assert type(value) == float + return value + + @expose(decimal.Decimal) + @validate(decimal.Decimal) + def setdecimal(self, value): + print repr(value) + assert type(value) == decimal.Decimal + return value + + @expose(datetime.date) + @validate(datetime.date) + def setdate(self, value): + print repr(value) + assert type(value) == datetime.date + return value + + @expose(datetime.time) + @validate(datetime.time) + def settime(self, value): + print repr(value) + assert type(value) == datetime.time + return value + + @expose(datetime.datetime) + @validate(datetime.datetime) + def setdatetime(self, value): + print repr(value) + assert type(value) == datetime.datetime + return value + + @expose(wsme.types.binary) + @validate(wsme.types.binary) + def setbinary(self, value): + print repr(value) + assert type(value) == str + return value + + @expose(NestedOuter) + @validate(NestedOuter) + def setnested(self, value): + print repr(value) + assert type(value) == NestedOuter + return value + + class WithErrors(object): @expose() def divide_by_zero(self): @@ -92,6 +171,7 @@ class WithErrors(object): class WSTestRoot(WSRoot): + argtypes = ArgTypes() returntypes = ReturnTypes() witherrors = WithErrors() @@ -167,10 +247,46 @@ class ProtocolTestCase(unittest.TestCase): r = self.call('returntypes/getbinary') assert r == binarysample or r == base64.encodestring(binarysample), r - def test_return_binary(self): - r = self.call('returntypes/getbinary') - assert r == binarysample or r == base64.encodestring(binarysample), r - def test_return_nested(self): r = self.call('returntypes/getnested') assert r == {'inner': {'aint': 0}} or r == {'inner': {'aint': '0'}}, r + + def test_setstr(self): + assert self.call('argtypes/setstr', value='astring') in ('astring',) + + def test_setunicode(self): + assert self.call('argtypes/setunicode', value=u'の') in (u'の',) + + def test_setint(self): + assert self.call('argtypes/setint', value=3) in (3, '3') + + def test_setfloat(self): + return self.call('argtypes/setfloat', value=3.54) in (3.54, '3.54') + + def test_setdecimal(self): + return self.call('argtypes/setdecimal', value='3.14') in ('3.14', decimal.Decimal('3.14')) + + def test_setdate(self): + return self.call('argtypes/setdate', value='2008-04-06') in ( + datetime.date(2008, 4, 6), '2008-04-06') + + def test_settime(self): + return self.call('argtypes/settime', value='12:12:15') \ + in ('12:12:15', datetime.time(12, 12, 15)) + + def test_setdatetime(self): + return self.call('argtypes/setdatetime', value='2008-04-06T12:12:15') \ + in ('2008-04-06T12:12:15', + datetime.datetime(2008, 4, 6, 12, 12, 15)) + + def test_setbinary(self): + r = self.call('argtypes/setbinary', + value=base64.encodestring(binarysample)) + assert r == binarysample or r == base64.encodestring(binarysample), r + + def test_setnested(self): + return self.call('argtypes/setnested', + value={'inner': {'aint': 54}}) in ( + {'inner': {'aint': 54}}, + {'inner': {'aint': '54'}} + ) diff --git a/wsme/tests/test_restxml.py b/wsme/tests/test_restxml.py index 7125a65..84c87b2 100644 --- a/wsme/tests/test_restxml.py +++ b/wsme/tests/test_restxml.py @@ -14,14 +14,14 @@ import wsme.restxml def dumpxml(key, obj): el = et.Element(key) if isinstance(obj, basestring): - el(obj) + el.text = obj elif type(obj) in (int, float, decimal.Decimal): - el(str(obj)) + el.text = str(obj) elif type(obj) in (datetime.date, datetime.time, datetime.datetime): - el(obj.isoformat()) + el.text = obj.isoformat() elif type(obj) == dict: for key, obj in obj.items(): - e.append(dumpxml(key, obj)) + el.append(dumpxml(key, obj)) return el @@ -35,7 +35,7 @@ def loadxml(el): return el.text -class TestRestJson(wsme.tests.protocol.ProtocolTestCase): +class TestRestXML(wsme.tests.protocol.ProtocolTestCase): protocol = 'REST+XML' def call(self, fpath, **kw):