From 00f322a20451a49ad389d53df7ea0f25e4ac2e75 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Wed, 14 Jan 2015 17:42:39 -0600 Subject: [PATCH] New OrderedMap type in support of nested collections PYTHON-186 Updated existing tests, still need to add tests for nested collections. --- cassandra/cqltypes.py | 6 +-- cassandra/encoder.py | 3 +- cassandra/util.py | 70 ++++++++++++++++++++++++++++++++++ tests/unit/test_marshalling.py | 13 +++---- 4 files changed, 81 insertions(+), 11 deletions(-) diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 100c2968..ad65a4db 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -47,7 +47,7 @@ from cassandra.marshal import (int8_pack, int8_unpack, uint16_pack, uint16_unpac int32_pack, int32_unpack, int64_pack, int64_unpack, float_pack, float_unpack, double_pack, double_unpack, varint_pack, varint_unpack) -from cassandra.util import OrderedDict, sortedset +from cassandra.util import OrderedMap, sortedset apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' 
class OrderedMap(Mapping):
    '''
    An ordered map that accepts non-hashable types for keys.

    Implemented in support of Cassandra nested collections.

    Entries are stored in insertion order in ``_items`` as ``(key, value)``
    pairs. ``_index`` maps the pickled form of every key to the position of
    its entry in ``_items``; this is what lets unhashable keys (lists, dicts,
    other maps) be looked up.

    Keys are matched by their pickled bytes, not by ``==``, so keys that
    compare equal but pickle differently (for example ``1`` and ``1.0``) are
    treated as distinct keys.
    '''

    def __init__(self, *args, **kwargs):
        '''
        Build the map from at most one positional source followed by any
        keyword arguments.

        The positional source may be an object exposing a callable ``keys``
        attribute (read via ``source[key]``) or an iterable of
        ``(key, value)`` pairs. A key that occurs more than once keeps its
        original position and takes the last value seen.

        Raises TypeError when more than one positional argument is given.
        '''
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))

        self._items = []
        self._index = {}
        if args:
            source = args[0]
            if callable(getattr(source, 'keys', None)):
                # Go through _insert so that _index is populated as well;
                # appending to _items directly would leave these entries
                # impossible to look up.
                for k in source.keys():
                    self._insert(k, source[k])
            else:
                for k, v in source:
                    self._insert(k, v)

        for k, v in kwargs.items():
            self._insert(k, v)

    def _insert(self, key, value):
        '''
        Add ``key`` -> ``value``, or overwrite the value in place when the
        key is already present (its position in the ordering is kept).
        '''
        flat_key = self._serialize_key(key)
        position = self._index.get(flat_key)
        if position is None:
            self._index[flat_key] = len(self._items)
            self._items.append((key, value))
        else:
            self._items[position] = (key, value)

    def __getitem__(self, key):
        '''
        Return the value stored for ``key``.

        Raises KeyError (carrying the caller's key) when it is absent.
        '''
        flat_key = self._serialize_key(key)
        try:
            position = self._index[flat_key]
        except KeyError:
            # Report the key the caller passed rather than its pickled bytes.
            raise KeyError(key)
        return self._items[position][1]

    def __iter__(self):
        '''Yield the keys in insertion order.'''
        for item in self._items:
            yield item[0]

    def __len__(self):
        '''Return the number of entries.'''
        return len(self._items)

    def __eq__(self, other):
        '''
        Compare with another OrderedMap (same pairs in the same order) or
        with anything ``dict()`` accepts (same pairs, order ignored).

        Returns NotImplemented when ``other`` cannot be turned into a dict
        or when one of this map's keys is unhashable.
        '''
        if isinstance(other, OrderedMap):
            return self._items == other._items
        try:
            d = dict(other)
            return len(d) == len(self._items) and all(v == d[k] for k, v in self._items)
        except KeyError:
            # One of our keys does not exist in ``other``.
            return False
        except TypeError:
            # ``other`` is not dict-convertible, or a key of ours is
            # unhashable and so cannot be looked up in a plain dict.
            pass
        return NotImplemented

    def __repr__(self):
        return '%s[%s]' % (
            self.__class__.__name__,
            ', '.join("(%r, %r)" % (k, v) for k, v in self._items))

    @staticmethod
    def _serialize_key(key):
        '''Return the pickled bytes of ``key``, used as the ``_index`` key.'''
        return cPickle.dumps(key)
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedDict()), + (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()), (b'\x00\x00', 'ListType(FloatType)', []), (b'\x00\x00', 'SetType(IntegerType)', sortedset()), (b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]), ) -ordered_dict_value = OrderedDict() -ordered_dict_value[u'\u307fbob'] = 199 -ordered_dict_value[u''] = -1 -ordered_dict_value[u'\\'] = 0 +ordered_map_value = OrderedMap([(u'\u307fbob', 199), + (u'', -1), + (u'\\', 0)]) # these following entries work for me right now, but they're dependent on # vagaries of internal python ordering for unordered types marshalled_value_pairs_unsafe = ( - (b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_dict_value), + (b'\x00\x03\x00\x06\xe3\x81\xbfbob\x00\x04\x00\x00\x00\xc7\x00\x00\x00\x04\xff\xff\xff\xff\x00\x01\\\x00\x04\x00\x00\x00\x00', 'MapType(UTF8Type, Int32Type)', ordered_map_value), (b'\x00\x02\x00\x08@\x01\x99\x99\x99\x99\x99\x9a\x00\x08@\x14\x00\x00\x00\x00\x00\x00', 'SetType(DoubleType)', sortedset([2.2, 5.0])), (b'\x00', 'IntegerType', 0), )