Merge branch 'PYTHON-231'
Conflicts: cassandra/cqltypes.py cassandra/util.py
This commit is contained in:
		| @@ -784,12 +784,12 @@ class MapType(_ParameterizedType): | ||||
|  | ||||
|     @classmethod | ||||
|     def validate(cls, val): | ||||
|         subkeytype, subvaltype = cls.subtypes | ||||
|         return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val)) | ||||
|         key_type, value_type = cls.subtypes | ||||
|         return dict((key_type.validate(k), value_type.validate(v)) for (k, v) in six.iteritems(val)) | ||||
|  | ||||
|     @classmethod | ||||
|     def deserialize_safe(cls, byts, protocol_version): | ||||
|         subkeytype, subvaltype = cls.subtypes | ||||
|         key_type, value_type = cls.subtypes | ||||
|         if protocol_version >= 3: | ||||
|             unpack = int32_unpack | ||||
|             length = 4 | ||||
| @@ -798,7 +798,7 @@ class MapType(_ParameterizedType): | ||||
|             length = 2 | ||||
|         numelements = unpack(byts[:length]) | ||||
|         p = length | ||||
|         themap = util.OrderedMap() | ||||
|         themap = util.OrderedMapSerializedKey(key_type, protocol_version) | ||||
|         for _ in range(numelements): | ||||
|             key_len = unpack(byts[p:p + length]) | ||||
|             p += length | ||||
| @@ -808,14 +808,14 @@ class MapType(_ParameterizedType): | ||||
|             p += length | ||||
|             valbytes = byts[p:p + val_len] | ||||
|             p += val_len | ||||
|             key = subkeytype.from_binary(keybytes, protocol_version) | ||||
|             val = subvaltype.from_binary(valbytes, protocol_version) | ||||
|             themap._insert(key, val) | ||||
|             key = key_type.from_binary(keybytes, protocol_version) | ||||
|             val = value_type.from_binary(valbytes, protocol_version) | ||||
|             themap._insert_unchecked(key, keybytes, val) | ||||
|         return themap | ||||
|  | ||||
|     @classmethod | ||||
|     def serialize_safe(cls, themap, protocol_version): | ||||
|         subkeytype, subvaltype = cls.subtypes | ||||
|         key_type, value_type = cls.subtypes | ||||
|         pack = int32_pack if protocol_version >= 3 else uint16_pack | ||||
|         buf = io.BytesIO() | ||||
|         buf.write(pack(len(themap))) | ||||
| @@ -824,8 +824,8 @@ class MapType(_ParameterizedType): | ||||
|         except AttributeError: | ||||
|             raise TypeError("Got a non-map object for a map value") | ||||
|         for key, val in items: | ||||
|             keybytes = subkeytype.to_binary(key, protocol_version) | ||||
|             valbytes = subvaltype.to_binary(val, protocol_version) | ||||
|             keybytes = key_type.to_binary(key, protocol_version) | ||||
|             valbytes = value_type.to_binary(val, protocol_version) | ||||
|             buf.write(pack(len(keybytes))) | ||||
|             buf.write(keybytes) | ||||
|             buf.write(pack(len(valbytes))) | ||||
|   | ||||
| @@ -680,6 +680,7 @@ except ImportError: | ||||
|                         isect.add(item) | ||||
|             return isect | ||||
|  | ||||
|  | ||||
| from collections import Mapping | ||||
| from six.moves import cPickle | ||||
|  | ||||
| @@ -715,6 +716,7 @@ class OrderedMap(Mapping): | ||||
|     or higher. | ||||
|  | ||||
|     ''' | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         if len(args) > 1: | ||||
|             raise TypeError('expected at most 1 arguments, got %d' % len(args)) | ||||
| @@ -776,11 +778,25 @@ class OrderedMap(Mapping): | ||||
|     def __str__(self): | ||||
|         return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _serialize_key(key): | ||||
|     def _serialize_key(self, key): | ||||
|         return cPickle.dumps(key) | ||||
|  | ||||
|  | ||||
| class OrderedMapSerializedKey(OrderedMap): | ||||
|  | ||||
|     def __init__(self, cass_type, protocol_version): | ||||
|         super(OrderedMapSerializedKey, self).__init__() | ||||
|         self.cass_key_type = cass_type | ||||
|         self.protocol_version = protocol_version | ||||
|  | ||||
|     def _insert_unchecked(self, key, flat_key, value): | ||||
|         self._items.append((key, value)) | ||||
|         self._index[flat_key] = len(self._items) - 1 | ||||
|  | ||||
|     def _serialize_key(self, key): | ||||
|         return self.cass_key_type.serialize(key, self.protocol_version) | ||||
|  | ||||
|  | ||||
| import datetime | ||||
| import time | ||||
|  | ||||
|   | ||||
| @@ -23,8 +23,8 @@ from datetime import datetime, date | ||||
| from decimal import Decimal | ||||
| from uuid import UUID | ||||
|  | ||||
| from cassandra.cqltypes import lookup_casstype | ||||
| from cassandra.util import OrderedMap, sortedset, Time | ||||
| from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type | ||||
| from cassandra.util import OrderedMap, OrderedMapSerializedKey, sortedset, Time | ||||
|  | ||||
| marshalled_value_pairs = ( | ||||
|     # binary form, type, python native type | ||||
| @@ -75,7 +75,7 @@ marshalled_value_pairs = ( | ||||
|     (b'', 'MapType(AsciiType, BooleanType)', None), | ||||
|     (b'', 'ListType(FloatType)', None), | ||||
|     (b'', 'SetType(LongType)', None), | ||||
|     (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()), | ||||
|     (b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType, 0)), | ||||
|     (b'\x00\x00', 'ListType(FloatType)', []), | ||||
|     (b'\x00\x00', 'SetType(IntegerType)', sortedset()), | ||||
|     (b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]), | ||||
| @@ -84,9 +84,10 @@ marshalled_value_pairs = ( | ||||
|     (b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', Time(1)) | ||||
| ) | ||||
|  | ||||
| ordered_map_value = OrderedMap([(u'\u307fbob', 199), | ||||
|                                 (u'', -1), | ||||
|                                 (u'\\', 0)]) | ||||
| ordered_map_value = OrderedMapSerializedKey(UTF8Type, 2) | ||||
| ordered_map_value._insert(u'\u307fbob', 199) | ||||
| ordered_map_value._insert(u'', -1) | ||||
| ordered_map_value._insert(u'\\', 0) | ||||
|  | ||||
| # these following entries work for me right now, but they're dependent on | ||||
| # vagaries of internal python ordering for unordered types | ||||
|   | ||||
| @@ -139,21 +139,22 @@ class RoundRobinPolicyTest(unittest.TestCase): | ||||
|             threads.append(Thread(target=host_down)) | ||||
|  | ||||
|         # make the GIL switch after every instruction, maximizing | ||||
|         # the chace of race conditions | ||||
|         if six.PY2: | ||||
|         # the chance of race conditions | ||||
|         check = six.PY2 or '__pypy__' in sys.builtin_module_names | ||||
|         if check: | ||||
|             original_interval = sys.getcheckinterval() | ||||
|         else: | ||||
|             original_interval = sys.getswitchinterval() | ||||
|  | ||||
|         try: | ||||
|             if six.PY2: | ||||
|             if check: | ||||
|                 sys.setcheckinterval(0) | ||||
|             else: | ||||
|                 sys.setswitchinterval(0.0001) | ||||
|             map(lambda t: t.start(), threads) | ||||
|             map(lambda t: t.join(), threads) | ||||
|         finally: | ||||
|             if six.PY2: | ||||
|             if check: | ||||
|                 sys.setcheckinterval(original_interval) | ||||
|             else: | ||||
|                 sys.setswitchinterval(original_interval) | ||||
| @@ -362,6 +363,7 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase): | ||||
|         policy.on_add(host_remote) | ||||
|         self.assertFalse(policy.local_dc) | ||||
|  | ||||
|  | ||||
| class TokenAwarePolicyTest(unittest.TestCase): | ||||
|  | ||||
|     def test_wrap_round_robin(self): | ||||
| @@ -519,7 +521,6 @@ class TokenAwarePolicyTest(unittest.TestCase): | ||||
|         qplan = list(policy.make_query_plan()) | ||||
|         self.assertEqual(qplan, []) | ||||
|  | ||||
|  | ||||
|     def test_statement_keyspace(self): | ||||
|         hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] | ||||
|         for host in hosts: | ||||
| @@ -666,6 +667,7 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase): | ||||
|  | ||||
| ONE = ConsistencyLevel.ONE | ||||
|  | ||||
|  | ||||
| class RetryPolicyTest(unittest.TestCase): | ||||
|  | ||||
|     def test_read_timeout(self): | ||||
|   | ||||
| @@ -20,6 +20,7 @@ from binascii import unhexlify | ||||
| import calendar | ||||
| import datetime | ||||
| import tempfile | ||||
| import six | ||||
| import time | ||||
|  | ||||
| import cassandra | ||||
| @@ -228,7 +229,7 @@ class TypeTests(unittest.TestCase): | ||||
|  | ||||
|             @classmethod | ||||
|             def apply_parameters(cls, subtypes, names): | ||||
|                 return cls(subtypes, [unhexlify(name) if name is not None else name for name in names]) | ||||
|                 return cls(subtypes, [unhexlify(six.b(name)) if name is not None else name for name in names]) | ||||
|  | ||||
|         class BarType(FooType): | ||||
|             typename = 'org.apache.cassandra.db.marshal.BarType' | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Adam Holmberg
					Adam Holmberg