diff --git a/oslo_serialization/msgpackutils.py b/oslo_serialization/msgpackutils.py index 70705f2..c891a36 100644 --- a/oslo_serialization/msgpackutils.py +++ b/oslo_serialization/msgpackutils.py @@ -41,6 +41,36 @@ import six.moves.xmlrpc_client as xmlrpclib netaddr = importutils.try_import("netaddr") + +class Interval(object): + """Small and/or simple immutable integer/float interval class. + + Interval checking is **inclusive** of the min/max boundaries. + """ + + def __init__(self, min_value, max_value): + if min_value > max_value: + raise ValueError("Minimum value %s must be less than" + " or equal to maximum value %s" % (min_value, + max_value)) + self._min_value = min_value + self._max_value = max_value + + @property + def min_value(self): + return self._min_value + + @property + def max_value(self): + return self._max_value + + def __contains__(self, value): + return value >= self.min_value and value <= self.max_value + + def __repr__(self): + return 'Interval(%s, %s)' % (self._min_value, self._max_value) + + # Expose these so that users don't have to import msgpack to gain these. PackException = msgpack.PackException @@ -61,48 +91,115 @@ class HandlerRegistry(object): .. versionadded:: 1.5 """ - # Applications can assign 0 to 127 to store - # application-specific type information... + reserved_extension_range = Interval(0, 32) + """ + These ranges are **always** reserved for use by ``oslo.serialization`` and + its own add-ons extensions (these extensions are meant to be generally + applicable to all of python). + """ + + non_reserved_extension_range = Interval(33, 127) + """ + These ranges are **always** reserved for use by applications building + their own type specific handlers (the meaning of extensions in this range + will typically vary depending on application). + """ + min_value = 0 + """ + Applications can assign 0 to 127 to store application (or library) + specific type handlers; see above ranges for what is reserved by this + library and what is not. + """ + max_value = 127 + """ + Applications can assign 0 to 127 to store application (or library) + specific type handlers; see above ranges for what is reserved by this + library and what is not. + """ def __init__(self): self._handlers = {} + self._num_handlers = 0 self.frozen = False def __iter__(self): - return six.itervalues(self._handlers) + """Iterates over **all** registered handlers.""" + for handlers in six.itervalues(self._handlers): + for h in handlers: + yield h - def register(self, handler): + def register(self, handler, reserved=False, override=False): """Register a extension handler to handle its associated type.""" if self.frozen: raise ValueError("Frozen handler registry can't be modified") + if reserved: + ok_interval = self.reserved_extension_range + else: + ok_interval = self.non_reserved_extension_range ident = handler.identity - if ident < self.min_value: + if ident < ok_interval.min_value: raise ValueError("Handler '%s' identity must be greater" - " or equal to %s" % (handler, self.min_value)) - if ident > self.max_value: + " or equal to %s" % (handler, + ok_interval.min_value)) + if ident > ok_interval.max_value: raise ValueError("Handler '%s' identity must be less than" - " or equal to %s" % (handler, self.max_value)) - if ident in self._handlers: - raise ValueError("Already registered handler with" + " or equal to %s" % (handler, + ok_interval.max_value)) + if ident in self._handlers and override: + existing_handlers = self._handlers[ident] + # Insert at the front so that overrides get selected before + # whatever existed before the override... + existing_handlers.insert(0, handler) + self._num_handlers += 1 + elif ident in self._handlers and not override: + raise ValueError("Already registered handler(s) with" " identity %s: %s" % (ident, self._handlers[ident])) else: - self._handlers[ident] = handler + self._handlers[ident] = [handler] + self._num_handlers += 1 def __len__(self): - return len(self._handlers) + """Return how many extension handlers are registered.""" + return self._num_handlers + + def __contains__(self, identity): + """Return if any handler exists for the given identity (number).""" + return identity in self._handlers + + def copy(self, unfreeze=False): + """Deep copy the given registry (and its handlers).""" + c = type(self)() + for ident, handlers in six.iteritems(self._handlers): + cloned_handlers = [] + for h in handlers: + if hasattr(h, 'copy'): + h = h.copy(c) + cloned_handlers.append(h) + c._handlers[ident] = cloned_handlers + c._num_handlers += len(cloned_handlers) + if not unfreeze and self.frozen: + c.frozen = True + return c def get(self, identity): - """Get the handle for the given numeric identity (or none).""" - return self._handlers.get(identity, None) + """Get the handler for the given numeric identity (or none).""" + maybe_handlers = self._handlers.get(identity) + if maybe_handlers: + # Prefer the first (if there are many) as this is how we + # override built-in extensions (for those that wish to do this). + return maybe_handlers[0] + else: + return None def match(self, obj): """Match the registries handlers to the given object (or none).""" - for handler in six.itervalues(self._handlers): - if isinstance(obj, handler.handles): - return handler + for possible_handlers in six.itervalues(self._handlers): + for h in possible_handlers: + if isinstance(obj, h.handles): + return h return None @@ -126,6 +223,9 @@ class DateTimeHandler(object): def __init__(self, registry): self._registry = registry + def copy(self, registry): + return type(self)(registry) + def serialize(self, dt): dct = { u'day': dt.day, @@ -222,6 +322,9 @@ class SetHandler(object): def __init__(self, registry): self._registry = registry + def copy(self, registry): + return type(self)(registry) + def serialize(self, obj): return dumps(list(obj), registry=self._registry) @@ -241,6 +344,9 @@ class XMLRPCDateTimeHandler(object): def __init__(self, registry): self._handler = DateTimeHandler(registry) + def copy(self, registry): + return type(self)(registry) + def serialize(self, obj): dt = datetime.datetime(*tuple(obj.timetuple())[:6]) return self._handler.serialize(dt) @@ -257,6 +363,9 @@ class DateHandler(object): def __init__(self, registry): self._registry = registry + def copy(self, registry): + return type(self)(registry) + def serialize(self, d): dct = { u'year': d.year, @@ -286,7 +395,7 @@ def _serializer(registry, obj): def _unserializer(registry, code, data): handler = registry.get(code) - if handler is None: + if not handler: return msgpack.ExtType(code, data) else: return handler.deserialize(data) @@ -294,15 +403,15 @@ def _unserializer(registry, code, data): def _create_default_registry(): registry = HandlerRegistry() - registry.register(DateTimeHandler(registry)) - registry.register(DateHandler(registry)) - registry.register(UUIDHandler()) - registry.register(CountHandler()) - registry.register(SetHandler(registry)) - registry.register(FrozenSetHandler(registry)) + registry.register(DateTimeHandler(registry), reserved=True) + registry.register(DateHandler(registry), reserved=True) + registry.register(UUIDHandler(), reserved=True) + registry.register(CountHandler(), reserved=True) + registry.register(SetHandler(registry), reserved=True) + registry.register(FrozenSetHandler(registry), reserved=True) if netaddr is not None: - registry.register(NetAddrIPHandler()) - registry.register(XMLRPCDateTimeHandler(registry)) + registry.register(NetAddrIPHandler(), reserved=True) + registry.register(XMLRPCDateTimeHandler(registry), reserved=True) registry.frozen = True return registry diff --git a/oslo_serialization/tests/test_msgpackutils.py b/oslo_serialization/tests/test_msgpackutils.py new file mode 100644 index 0000000..2374793 --- /dev/null +++ b/oslo_serialization/tests/test_msgpackutils.py @@ -0,0 +1,211 @@ +# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime +import itertools +import uuid + +import netaddr +from oslotest import base as test_base +from pytz import timezone +import six +import six.moves.xmlrpc_client as xmlrpclib + +from oslo_serialization import msgpackutils + + +_TZ_FMT = '%Y-%m-%d %H:%M:%S %Z%z' + + +class Color(object): + def __init__(self, r, g, b): + self.r = r + self.g = g + self.b = b + + +class ColorHandler(object): + handles = (Color,) + identity = ( + msgpackutils.HandlerRegistry.non_reserved_extension_range.min_value + 1 + ) + + @staticmethod + def serialize(obj): + blob = '%s, %s, %s' % (obj.r, obj.g, obj.b) + if six.PY3: + blob = blob.encode("ascii") + return blob + + @staticmethod + def deserialize(data): + chunks = [int(c.strip()) for c in data.split(b",")] + return Color(chunks[0], chunks[1], chunks[2]) + + +class MySpecialSetHandler(object): + handles = (set,) + identity = msgpackutils.SetHandler.identity + + +def _dumps_loads(obj): + obj = msgpackutils.dumps(obj) + return msgpackutils.loads(obj) + + +class MsgPackUtilsTest(test_base.BaseTestCase): + def test_list(self): + self.assertEqual(_dumps_loads([1, 2, 3]), [1, 2, 3]) + + def test_empty_list(self): + self.assertEqual(_dumps_loads([]), []) + + def test_tuple(self): + # Seems like we do lose whether it was a tuple or not... + # + # Maybe fixed someday: + # + # https://github.com/msgpack/msgpack-python/issues/98 + self.assertEqual(_dumps_loads((1, 2, 3)), [1, 2, 3]) + + def test_dict(self): + self.assertEqual(_dumps_loads(dict(a=1, b=2, c=3)), + dict(a=1, b=2, c=3)) + + def test_empty_dict(self): + self.assertEqual(_dumps_loads({}), {}) + + def test_complex_dict(self): + src = { + 'now': datetime.datetime(1920, 2, 3, 4, 5, 6, 7), + 'later': datetime.datetime(1921, 2, 3, 4, 5, 6, 9), + 'a': 1, + 'b': 2.0, + 'c': [], + 'd': set([1, 2, 3]), + 'zzz': uuid.uuid4(), + 'yyy': 'yyy', + 'ddd': b'bbb', + 'today': datetime.date.today(), + } + self.assertEqual(_dumps_loads(src), src) + + def test_itercount(self): + it = itertools.count(1) + six.next(it) + six.next(it) + it2 = _dumps_loads(it) + self.assertEqual(six.next(it), six.next(it2)) + + it = itertools.count(0) + it2 = _dumps_loads(it) + self.assertEqual(six.next(it), six.next(it2)) + + def test_itercount_step(self): + it = itertools.count(1, 3) + it2 = _dumps_loads(it) + self.assertEqual(six.next(it), six.next(it2)) + + def test_set(self): + self.assertEqual(_dumps_loads(set([1, 2])), set([1, 2])) + + def test_empty_set(self): + self.assertEqual(_dumps_loads(set([])), set([])) + + def test_frozenset(self): + self.assertEqual(_dumps_loads(frozenset([1, 2])), frozenset([1, 2])) + + def test_empty_frozenset(self): + self.assertEqual(_dumps_loads(frozenset([])), frozenset([])) + + def test_datetime_preserve(self): + x = datetime.datetime(1920, 2, 3, 4, 5, 6, 7) + self.assertEqual(_dumps_loads(x), x) + + def test_datetime(self): + x = xmlrpclib.DateTime() + x.decode("19710203T04:05:06") + self.assertEqual(_dumps_loads(x), x) + + def test_ipaddr(self): + thing = {'ip_addr': netaddr.IPAddress('1.2.3.4')} + self.assertEqual(_dumps_loads(thing), thing) + + def test_today(self): + today = datetime.date.today() + self.assertEqual(today, _dumps_loads(today)) + + def test_datetime_tz_clone(self): + eastern = timezone('US/Eastern') + now = datetime.datetime.now() + e_dt = eastern.localize(now) + e_dt2 = _dumps_loads(e_dt) + self.assertEqual(e_dt, e_dt2) + self.assertEqual(e_dt.strftime(_TZ_FMT), e_dt2.strftime(_TZ_FMT)) + + def test_datetime_tz_different(self): + eastern = timezone('US/Eastern') + pacific = timezone('US/Pacific') + now = datetime.datetime.now() + + e_dt = eastern.localize(now) + p_dt = pacific.localize(now) + + self.assertNotEqual(e_dt, p_dt) + self.assertNotEqual(e_dt.strftime(_TZ_FMT), p_dt.strftime(_TZ_FMT)) + + e_dt2 = _dumps_loads(e_dt) + p_dt2 = _dumps_loads(p_dt) + + self.assertNotEqual(e_dt2, p_dt2) + self.assertNotEqual(e_dt2.strftime(_TZ_FMT), p_dt2.strftime(_TZ_FMT)) + + self.assertEqual(e_dt, e_dt2) + self.assertEqual(p_dt, p_dt2) + + def test_copy_then_register(self): + registry = msgpackutils.default_registry + self.assertRaises(ValueError, + registry.register, MySpecialSetHandler(), + reserved=True, override=True) + registry = registry.copy(unfreeze=True) + registry.register(MySpecialSetHandler(), + reserved=True, override=True) + h = registry.match(set()) + self.assertIsInstance(h, MySpecialSetHandler) + + def test_bad_register(self): + registry = msgpackutils.default_registry + self.assertRaises(ValueError, + registry.register, MySpecialSetHandler(), + reserved=True, override=True) + self.assertRaises(ValueError, + registry.register, MySpecialSetHandler()) + registry = registry.copy(unfreeze=True) + registry.register(ColorHandler()) + + self.assertRaises(ValueError, + registry.register, ColorHandler()) + + def test_custom_register(self): + registry = msgpackutils.default_registry.copy(unfreeze=True) + registry.register(ColorHandler()) + + c = Color(255, 254, 253) + c_b = msgpackutils.dumps(c, registry=registry) + c = msgpackutils.loads(c_b, registry=registry) + + self.assertEqual(255, c.r) + self.assertEqual(254, c.g) + self.assertEqual(253, c.b)