Merge branch 'PYTHON-231'
Conflicts: cassandra/cqltypes.py cassandra/util.py
This commit is contained in:
@@ -784,12 +784,12 @@ class MapType(_ParameterizedType):
|
||||
|
||||
@classmethod
|
||||
def validate(cls, val):
|
||||
subkeytype, subvaltype = cls.subtypes
|
||||
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
|
||||
key_type, value_type = cls.subtypes
|
||||
return dict((key_type.validate(k), value_type.validate(v)) for (k, v) in six.iteritems(val))
|
||||
|
||||
@classmethod
|
||||
def deserialize_safe(cls, byts, protocol_version):
|
||||
subkeytype, subvaltype = cls.subtypes
|
||||
key_type, value_type = cls.subtypes
|
||||
if protocol_version >= 3:
|
||||
unpack = int32_unpack
|
||||
length = 4
|
||||
@@ -798,7 +798,7 @@ class MapType(_ParameterizedType):
|
||||
length = 2
|
||||
numelements = unpack(byts[:length])
|
||||
p = length
|
||||
themap = util.OrderedMap()
|
||||
themap = util.OrderedMapSerializedKey(key_type, protocol_version)
|
||||
for _ in range(numelements):
|
||||
key_len = unpack(byts[p:p + length])
|
||||
p += length
|
||||
@@ -808,14 +808,14 @@ class MapType(_ParameterizedType):
|
||||
p += length
|
||||
valbytes = byts[p:p + val_len]
|
||||
p += val_len
|
||||
key = subkeytype.from_binary(keybytes, protocol_version)
|
||||
val = subvaltype.from_binary(valbytes, protocol_version)
|
||||
themap._insert(key, val)
|
||||
key = key_type.from_binary(keybytes, protocol_version)
|
||||
val = value_type.from_binary(valbytes, protocol_version)
|
||||
themap._insert_unchecked(key, keybytes, val)
|
||||
return themap
|
||||
|
||||
@classmethod
|
||||
def serialize_safe(cls, themap, protocol_version):
|
||||
subkeytype, subvaltype = cls.subtypes
|
||||
key_type, value_type = cls.subtypes
|
||||
pack = int32_pack if protocol_version >= 3 else uint16_pack
|
||||
buf = io.BytesIO()
|
||||
buf.write(pack(len(themap)))
|
||||
@@ -824,8 +824,8 @@ class MapType(_ParameterizedType):
|
||||
except AttributeError:
|
||||
raise TypeError("Got a non-map object for a map value")
|
||||
for key, val in items:
|
||||
keybytes = subkeytype.to_binary(key, protocol_version)
|
||||
valbytes = subvaltype.to_binary(val, protocol_version)
|
||||
keybytes = key_type.to_binary(key, protocol_version)
|
||||
valbytes = value_type.to_binary(val, protocol_version)
|
||||
buf.write(pack(len(keybytes)))
|
||||
buf.write(keybytes)
|
||||
buf.write(pack(len(valbytes)))
|
||||
|
||||
@@ -680,6 +680,7 @@ except ImportError:
|
||||
isect.add(item)
|
||||
return isect
|
||||
|
||||
|
||||
from collections import Mapping
|
||||
from six.moves import cPickle
|
||||
|
||||
@@ -715,6 +716,7 @@ class OrderedMap(Mapping):
|
||||
or higher.
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) > 1:
|
||||
raise TypeError('expected at most 1 arguments, got %d' % len(args))
|
||||
@@ -776,11 +778,25 @@ class OrderedMap(Mapping):
|
||||
def __str__(self):
|
||||
return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_key(key):
|
||||
def _serialize_key(self, key):
|
||||
return cPickle.dumps(key)
|
||||
|
||||
|
||||
class OrderedMapSerializedKey(OrderedMap):
|
||||
|
||||
def __init__(self, cass_type, protocol_version):
|
||||
super(OrderedMapSerializedKey, self).__init__()
|
||||
self.cass_key_type = cass_type
|
||||
self.protocol_version = protocol_version
|
||||
|
||||
def _insert_unchecked(self, key, flat_key, value):
|
||||
self._items.append((key, value))
|
||||
self._index[flat_key] = len(self._items) - 1
|
||||
|
||||
def _serialize_key(self, key):
|
||||
return self.cass_key_type.serialize(key, self.protocol_version)
|
||||
|
||||
|
||||
import datetime
|
||||
import time
|
||||
|
||||
|
||||
@@ -23,8 +23,8 @@ from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from cassandra.cqltypes import lookup_casstype
|
||||
from cassandra.util import OrderedMap, sortedset, Time
|
||||
from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type
|
||||
from cassandra.util import OrderedMap, OrderedMapSerializedKey, sortedset, Time
|
||||
|
||||
marshalled_value_pairs = (
|
||||
# binary form, type, python native type
|
||||
@@ -75,7 +75,7 @@ marshalled_value_pairs = (
|
||||
(b'', 'MapType(AsciiType, BooleanType)', None),
|
||||
(b'', 'ListType(FloatType)', None),
|
||||
(b'', 'SetType(LongType)', None),
|
||||
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()),
|
||||
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType, 0)),
|
||||
(b'\x00\x00', 'ListType(FloatType)', []),
|
||||
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
|
||||
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
|
||||
@@ -84,9 +84,10 @@ marshalled_value_pairs = (
|
||||
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', Time(1))
|
||||
)
|
||||
|
||||
ordered_map_value = OrderedMap([(u'\u307fbob', 199),
|
||||
(u'', -1),
|
||||
(u'\\', 0)])
|
||||
ordered_map_value = OrderedMapSerializedKey(UTF8Type, 2)
|
||||
ordered_map_value._insert(u'\u307fbob', 199)
|
||||
ordered_map_value._insert(u'', -1)
|
||||
ordered_map_value._insert(u'\\', 0)
|
||||
|
||||
# these following entries work for me right now, but they're dependent on
|
||||
# vagaries of internal python ordering for unordered types
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
import unittest # noqa
|
||||
|
||||
from itertools import islice, cycle
|
||||
from mock import Mock
|
||||
@@ -139,21 +139,22 @@ class RoundRobinPolicyTest(unittest.TestCase):
|
||||
threads.append(Thread(target=host_down))
|
||||
|
||||
# make the GIL switch after every instruction, maximizing
|
||||
# the chace of race conditions
|
||||
if six.PY2:
|
||||
# the chance of race conditions
|
||||
check = six.PY2 or '__pypy__' in sys.builtin_module_names
|
||||
if check:
|
||||
original_interval = sys.getcheckinterval()
|
||||
else:
|
||||
original_interval = sys.getswitchinterval()
|
||||
|
||||
try:
|
||||
if six.PY2:
|
||||
if check:
|
||||
sys.setcheckinterval(0)
|
||||
else:
|
||||
sys.setswitchinterval(0.0001)
|
||||
map(lambda t: t.start(), threads)
|
||||
map(lambda t: t.join(), threads)
|
||||
finally:
|
||||
if six.PY2:
|
||||
if check:
|
||||
sys.setcheckinterval(original_interval)
|
||||
else:
|
||||
sys.setswitchinterval(original_interval)
|
||||
@@ -362,6 +363,7 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase):
|
||||
policy.on_add(host_remote)
|
||||
self.assertFalse(policy.local_dc)
|
||||
|
||||
|
||||
class TokenAwarePolicyTest(unittest.TestCase):
|
||||
|
||||
def test_wrap_round_robin(self):
|
||||
@@ -519,7 +521,6 @@ class TokenAwarePolicyTest(unittest.TestCase):
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(qplan, [])
|
||||
|
||||
|
||||
def test_statement_keyspace(self):
|
||||
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
|
||||
for host in hosts:
|
||||
@@ -666,6 +667,7 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase):
|
||||
|
||||
ONE = ConsistencyLevel.ONE
|
||||
|
||||
|
||||
class RetryPolicyTest(unittest.TestCase):
|
||||
|
||||
def test_read_timeout(self):
|
||||
|
||||
@@ -20,6 +20,7 @@ from binascii import unhexlify
|
||||
import calendar
|
||||
import datetime
|
||||
import tempfile
|
||||
import six
|
||||
import time
|
||||
|
||||
import cassandra
|
||||
@@ -228,7 +229,7 @@ class TypeTests(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def apply_parameters(cls, subtypes, names):
|
||||
return cls(subtypes, [unhexlify(name) if name is not None else name for name in names])
|
||||
return cls(subtypes, [unhexlify(six.b(name)) if name is not None else name for name in names])
|
||||
|
||||
class BarType(FooType):
|
||||
typename = 'org.apache.cassandra.db.marshal.BarType'
|
||||
|
||||
Reference in New Issue
Block a user