diff --git a/wsme/protocols/restjson.py b/wsme/protocols/restjson.py index 9a59143..a47ccca 100644 --- a/wsme/protocols/restjson.py +++ b/wsme/protocols/restjson.py @@ -51,6 +51,18 @@ def array_tojson(datatype, value): return [tojson(datatype[0], item) for item in value] +@tojson.when_type(dict) +def dict_tojson(datatype, value): + if value is None: + return None + key_type = datatype.keys()[0] + value_type = datatype.values()[0] + return dict(( + (tojson(key_type, item[0]), tojson(value_type, item[1])) + for item in value.items() + )) + + @tojson.when_object(decimal.Decimal) def decimal_tojson(datatype, value): if value is None: @@ -118,6 +130,17 @@ def array_fromjson(datatype, value): return [fromjson(datatype[0], item) for item in value] +@fromjson.when_type(dict) +def dict_fromjson(datatype, value): + if value is None: + return None + key_type = datatype.keys()[0] + value_type = datatype.values()[0] + return dict(( + (fromjson(key_type, item[0]), fromjson(value_type, item[1])) + for item in value.items())) + + @fromjson.when_object(str) def str_fromjson(datatype, value): if value is None: diff --git a/wsme/protocols/restxml.py b/wsme/protocols/restxml.py index cd0395a..2e90875 100644 --- a/wsme/protocols/restxml.py +++ b/wsme/protocols/restxml.py @@ -97,6 +97,24 @@ def array_toxml(datatype, key, value): return el +@toxml.when_type(dict) +def dict_toxml(datatype, key, value): + el = et.Element(key) + if value is None: + el.set('nil', 'true') + else: + key_type = datatype.keys()[0] + value_type = datatype.values()[0] + for item in value.items(): + key = toxml(key_type, 'key', item[0]) + value = toxml(value_type, 'value', item[1]) + node = et.Element('item') + node.append(key) + node.append(value) + el.append(node) + return el + + @toxml.when_object(bool) def bool_toxml(datatype, key, value): el = et.Element(key) @@ -134,6 +152,18 @@ def array_fromxml(datatype, element): return [fromxml(datatype[0], item) for item in element.findall('item')] +@fromxml.when_type(dict) +def dict_fromxml(datatype, element): + if element.get('nil') == 'true': + return None + key_type = datatype.keys()[0] + value_type = datatype.values()[0] + return dict(( + (fromxml(key_type, item.find('key')), + fromxml(value_type, item.find('value'))) + for item in element.findall('item'))) + + @fromxml.when_object(datetime.date) def date_fromxml(datatype, element): return datetime.datetime.strptime(element.text, '%Y-%m-%d').date() diff --git a/wsme/tests/protocol.py b/wsme/tests/protocol.py index 836ad4f..077241d 100644 --- a/wsme/tests/protocol.py +++ b/wsme/tests/protocol.py @@ -115,6 +115,10 @@ class ReturnTypes(object): def getnestedarray(self): return [NestedOuter(), NestedOuter()] + @expose({str: NestedOuter}) + def getnesteddict(self): + return {'a': NestedOuter(), 'b': NestedOuter()} + @expose(myenumtype) def getenum(self): return 'v2' @@ -226,6 +230,15 @@ class ArgTypes(object): assert type(value[0]) == NestedOuter return value + @expose({str: NestedOuter}) + @validate({str: NestedOuter}) + def setnesteddict(self, value): + print repr(value) + assert type(value) == dict + assert type(value.keys()[0]) == str + assert type(value.values()[0]) == NestedOuter + return value + @expose(myenumtype) @validate(myenumtype) def setenum(self, value): @@ -363,10 +376,14 @@ class ProtocolTestCase(unittest.TestCase): r = self.call('returntypes/getstrarray', _rt=[str]) assert r == ['A', 'B', 'C'], r - def test_return_strnested(self): + def test_return_nestedarray(self): r = self.call('returntypes/getnestedarray', _rt=[NestedOuter]) assert r == [{'inner': {'aint': 0}}, {'inner': {'aint': 0}}], r + def test_return_nesteddict(self): + r = self.call('returntypes/getnesteddict', _rt={str:NestedOuter}) + assert r == {'a': {'inner': {'aint': 0}}, 'b': {'inner': {'aint': 0}}} + def test_return_enum(self): r = self.call('returntypes/getenum', _rt=myenumtype) assert r == 'v2', r @@ -453,6 +470,17 @@ class ProtocolTestCase(unittest.TestCase): _rt=[NestedOuter]) assert r == value + def test_setnesteddict(self): + value = { + 'o1': {'inner': {'aint': 54}}, + 'o2': {'inner': {'aint': 55}}, + } + r = self.call('argtypes/setnesteddict', + value=(value, {str: NestedOuter}), + _rt={str: NestedOuter}) + print r + assert r == value + def test_setenum(self): value = 'v1' r = self.call('argtypes/setenum', value=value, diff --git a/wsme/tests/test_restjson.py b/wsme/tests/test_restjson.py index a853068..eeda730 100644 --- a/wsme/tests/test_restjson.py +++ b/wsme/tests/test_restjson.py @@ -18,6 +18,12 @@ from wsme.types import isusertype def prepare_value(value, datatype): if isinstance(datatype, list): return [prepare_value(item, datatype[0]) for item in value] + if isinstance(datatype, dict): + return dict(( + (prepare_value(item[0], datatype.keys()[0]), + prepare_value(item[1], datatype.values()[0])) + for item in value.items() + )) if datatype in (datetime.date, datetime.time, datetime.datetime): return value.isoformat() if datatype == decimal.Decimal: @@ -32,6 +38,12 @@ def prepare_result(value, datatype): datatype = datatype.basetype if isinstance(datatype, list): return [prepare_result(item, datatype[0]) for item in value] + if isinstance(datatype, dict): + return dict(( + (prepare_result(item[0], datatype.keys()[0]), + prepare_result(item[1], datatype.values()[0])) + for item in value.items() + )) if datatype == datetime.date: return parse_isodate(value) if datatype == datetime.time: diff --git a/wsme/tests/test_restxml.py b/wsme/tests/test_restxml.py index a941f41..3f16f81 100644 --- a/wsme/tests/test_restxml.py +++ b/wsme/tests/test_restxml.py @@ -3,7 +3,7 @@ import datetime import base64 import wsme.tests.protocol -from wsme.utils import * +from wsme.utils import parse_isodatetime, parse_isodate, parse_isotime from wsme.types import isusertype try: @@ -19,6 +19,11 @@ def dumpxml(key, obj, datatype=None): if isinstance(datatype, list): for item in obj: el.append(dumpxml('item', item, datatype[0])) + elif isinstance(datatype, dict): + for item in obj.items(): + node = et.SubElement(el, 'item') + node.append(dumpxml('key', item[0], datatype.keys()[0])) + node.append(dumpxml('value', item[1], datatype.values()[0])) elif datatype == wsme.types.binary: el.text = base64.encodestring(obj) elif isinstance(obj, basestring): @@ -46,6 +51,12 @@ def loadxml(el, datatype): return None if isinstance(datatype, list): return [loadxml(item, datatype[0]) for item in el.findall('item')] + elif isinstance(datatype, dict): + return dict(( + (loadxml(item.find('key'), datatype.keys()[0]), + loadxml(item.find('value'), datatype.values()[0])) + for item in el.findall('item') + )) elif len(el): d = {} for attr in datatype._wsme_attributes: diff --git a/wsme/types.py b/wsme/types.py index 1872676..94eb05b 100644 --- a/wsme/types.py +++ b/wsme/types.py @@ -75,6 +75,7 @@ native_types = pod_types + dt_types + extra_types complex_types = [] array_types = [] +dict_types = [] class UnsetType(object): @@ -93,6 +94,10 @@ def isarray(datatype): return isinstance(datatype, list) +def isdict(datatype): + return isinstance(datatype, dict) + + def validate_value(datatype, value): if hasattr(datatype, 'validate'): return datatype.validate(value) @@ -107,6 +112,16 @@ def validate_value(datatype, value): )) for item in value: validate_value(datatype[0], item) + elif isdict(datatype): + if not isinstance(value, dict): + raise ValueError("Wrong type. Expected '%s', got '%s'" % ( + datatype, type(value) + )) + key_type = datatype.keys()[0] + value_type = datatype.values()[0] + for key, v in value.items(): + validate_value(key_type, key) + validate_value(value_type, value) elif not isinstance(value, datatype): raise ValueError( "Wrong type. Expected '%s', got '%s'" % ( @@ -290,7 +305,19 @@ def register_type(class_): if len(class_) != 1: raise ValueError("Cannot register type %s" % repr(class_)) register_type(class_[0]) - array_types.append(class_[0]) + if class_[0] not in array_types: + array_types.append(class_[0]) + return + + if isinstance(class_, dict): + if len(class_) != 1: + raise ValueError("Cannot register type %s" % repr(class_)) + if class_.keys()[0] not in pod_types: + raise ValueError("Dictionnaries key can only be a pod type") + register_type(class_.values()[0]) + t = (class_.keys()[0], class_.values()[0]) + if t not in dict_types: + dict_types.append(t) return class_._wsme_attributes = None