Merge branch 'PYTHON-231'
Conflicts: cassandra/cqltypes.py cassandra/util.py
This commit is contained in:
@@ -784,12 +784,12 @@ class MapType(_ParameterizedType):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, val):
|
def validate(cls, val):
|
||||||
subkeytype, subvaltype = cls.subtypes
|
key_type, value_type = cls.subtypes
|
||||||
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
|
return dict((key_type.validate(k), value_type.validate(v)) for (k, v) in six.iteritems(val))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize_safe(cls, byts, protocol_version):
|
def deserialize_safe(cls, byts, protocol_version):
|
||||||
subkeytype, subvaltype = cls.subtypes
|
key_type, value_type = cls.subtypes
|
||||||
if protocol_version >= 3:
|
if protocol_version >= 3:
|
||||||
unpack = int32_unpack
|
unpack = int32_unpack
|
||||||
length = 4
|
length = 4
|
||||||
@@ -798,7 +798,7 @@ class MapType(_ParameterizedType):
|
|||||||
length = 2
|
length = 2
|
||||||
numelements = unpack(byts[:length])
|
numelements = unpack(byts[:length])
|
||||||
p = length
|
p = length
|
||||||
themap = util.OrderedMap()
|
themap = util.OrderedMapSerializedKey(key_type, protocol_version)
|
||||||
for _ in range(numelements):
|
for _ in range(numelements):
|
||||||
key_len = unpack(byts[p:p + length])
|
key_len = unpack(byts[p:p + length])
|
||||||
p += length
|
p += length
|
||||||
@@ -808,14 +808,14 @@ class MapType(_ParameterizedType):
|
|||||||
p += length
|
p += length
|
||||||
valbytes = byts[p:p + val_len]
|
valbytes = byts[p:p + val_len]
|
||||||
p += val_len
|
p += val_len
|
||||||
key = subkeytype.from_binary(keybytes, protocol_version)
|
key = key_type.from_binary(keybytes, protocol_version)
|
||||||
val = subvaltype.from_binary(valbytes, protocol_version)
|
val = value_type.from_binary(valbytes, protocol_version)
|
||||||
themap._insert(key, val)
|
themap._insert_unchecked(key, keybytes, val)
|
||||||
return themap
|
return themap
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def serialize_safe(cls, themap, protocol_version):
|
def serialize_safe(cls, themap, protocol_version):
|
||||||
subkeytype, subvaltype = cls.subtypes
|
key_type, value_type = cls.subtypes
|
||||||
pack = int32_pack if protocol_version >= 3 else uint16_pack
|
pack = int32_pack if protocol_version >= 3 else uint16_pack
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
buf.write(pack(len(themap)))
|
buf.write(pack(len(themap)))
|
||||||
@@ -824,8 +824,8 @@ class MapType(_ParameterizedType):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise TypeError("Got a non-map object for a map value")
|
raise TypeError("Got a non-map object for a map value")
|
||||||
for key, val in items:
|
for key, val in items:
|
||||||
keybytes = subkeytype.to_binary(key, protocol_version)
|
keybytes = key_type.to_binary(key, protocol_version)
|
||||||
valbytes = subvaltype.to_binary(val, protocol_version)
|
valbytes = value_type.to_binary(val, protocol_version)
|
||||||
buf.write(pack(len(keybytes)))
|
buf.write(pack(len(keybytes)))
|
||||||
buf.write(keybytes)
|
buf.write(keybytes)
|
||||||
buf.write(pack(len(valbytes)))
|
buf.write(pack(len(valbytes)))
|
||||||
|
|||||||
@@ -680,6 +680,7 @@ except ImportError:
|
|||||||
isect.add(item)
|
isect.add(item)
|
||||||
return isect
|
return isect
|
||||||
|
|
||||||
|
|
||||||
from collections import Mapping
|
from collections import Mapping
|
||||||
from six.moves import cPickle
|
from six.moves import cPickle
|
||||||
|
|
||||||
@@ -715,6 +716,7 @@ class OrderedMap(Mapping):
|
|||||||
or higher.
|
or higher.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
if len(args) > 1:
|
if len(args) > 1:
|
||||||
raise TypeError('expected at most 1 arguments, got %d' % len(args))
|
raise TypeError('expected at most 1 arguments, got %d' % len(args))
|
||||||
@@ -776,11 +778,25 @@ class OrderedMap(Mapping):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)
|
return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)
|
||||||
|
|
||||||
@staticmethod
|
def _serialize_key(self, key):
|
||||||
def _serialize_key(key):
|
|
||||||
return cPickle.dumps(key)
|
return cPickle.dumps(key)
|
||||||
|
|
||||||
|
|
||||||
|
class OrderedMapSerializedKey(OrderedMap):
|
||||||
|
|
||||||
|
def __init__(self, cass_type, protocol_version):
|
||||||
|
super(OrderedMapSerializedKey, self).__init__()
|
||||||
|
self.cass_key_type = cass_type
|
||||||
|
self.protocol_version = protocol_version
|
||||||
|
|
||||||
|
def _insert_unchecked(self, key, flat_key, value):
|
||||||
|
self._items.append((key, value))
|
||||||
|
self._index[flat_key] = len(self._items) - 1
|
||||||
|
|
||||||
|
def _serialize_key(self, key):
|
||||||
|
return self.cass_key_type.serialize(key, self.protocol_version)
|
||||||
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ from datetime import datetime, date
|
|||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cassandra.cqltypes import lookup_casstype
|
from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type
|
||||||
from cassandra.util import OrderedMap, sortedset, Time
|
from cassandra.util import OrderedMap, OrderedMapSerializedKey, sortedset, Time
|
||||||
|
|
||||||
marshalled_value_pairs = (
|
marshalled_value_pairs = (
|
||||||
# binary form, type, python native type
|
# binary form, type, python native type
|
||||||
@@ -75,7 +75,7 @@ marshalled_value_pairs = (
|
|||||||
(b'', 'MapType(AsciiType, BooleanType)', None),
|
(b'', 'MapType(AsciiType, BooleanType)', None),
|
||||||
(b'', 'ListType(FloatType)', None),
|
(b'', 'ListType(FloatType)', None),
|
||||||
(b'', 'SetType(LongType)', None),
|
(b'', 'SetType(LongType)', None),
|
||||||
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()),
|
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType, 0)),
|
||||||
(b'\x00\x00', 'ListType(FloatType)', []),
|
(b'\x00\x00', 'ListType(FloatType)', []),
|
||||||
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
|
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
|
||||||
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
|
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
|
||||||
@@ -84,9 +84,10 @@ marshalled_value_pairs = (
|
|||||||
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', Time(1))
|
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', Time(1))
|
||||||
)
|
)
|
||||||
|
|
||||||
ordered_map_value = OrderedMap([(u'\u307fbob', 199),
|
ordered_map_value = OrderedMapSerializedKey(UTF8Type, 2)
|
||||||
(u'', -1),
|
ordered_map_value._insert(u'\u307fbob', 199)
|
||||||
(u'\\', 0)])
|
ordered_map_value._insert(u'', -1)
|
||||||
|
ordered_map_value._insert(u'\\', 0)
|
||||||
|
|
||||||
# these following entries work for me right now, but they're dependent on
|
# these following entries work for me right now, but they're dependent on
|
||||||
# vagaries of internal python ordering for unordered types
|
# vagaries of internal python ordering for unordered types
|
||||||
|
|||||||
@@ -139,21 +139,22 @@ class RoundRobinPolicyTest(unittest.TestCase):
|
|||||||
threads.append(Thread(target=host_down))
|
threads.append(Thread(target=host_down))
|
||||||
|
|
||||||
# make the GIL switch after every instruction, maximizing
|
# make the GIL switch after every instruction, maximizing
|
||||||
# the chace of race conditions
|
# the chance of race conditions
|
||||||
if six.PY2:
|
check = six.PY2 or '__pypy__' in sys.builtin_module_names
|
||||||
|
if check:
|
||||||
original_interval = sys.getcheckinterval()
|
original_interval = sys.getcheckinterval()
|
||||||
else:
|
else:
|
||||||
original_interval = sys.getswitchinterval()
|
original_interval = sys.getswitchinterval()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if six.PY2:
|
if check:
|
||||||
sys.setcheckinterval(0)
|
sys.setcheckinterval(0)
|
||||||
else:
|
else:
|
||||||
sys.setswitchinterval(0.0001)
|
sys.setswitchinterval(0.0001)
|
||||||
map(lambda t: t.start(), threads)
|
map(lambda t: t.start(), threads)
|
||||||
map(lambda t: t.join(), threads)
|
map(lambda t: t.join(), threads)
|
||||||
finally:
|
finally:
|
||||||
if six.PY2:
|
if check:
|
||||||
sys.setcheckinterval(original_interval)
|
sys.setcheckinterval(original_interval)
|
||||||
else:
|
else:
|
||||||
sys.setswitchinterval(original_interval)
|
sys.setswitchinterval(original_interval)
|
||||||
@@ -362,6 +363,7 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase):
|
|||||||
policy.on_add(host_remote)
|
policy.on_add(host_remote)
|
||||||
self.assertFalse(policy.local_dc)
|
self.assertFalse(policy.local_dc)
|
||||||
|
|
||||||
|
|
||||||
class TokenAwarePolicyTest(unittest.TestCase):
|
class TokenAwarePolicyTest(unittest.TestCase):
|
||||||
|
|
||||||
def test_wrap_round_robin(self):
|
def test_wrap_round_robin(self):
|
||||||
@@ -519,7 +521,6 @@ class TokenAwarePolicyTest(unittest.TestCase):
|
|||||||
qplan = list(policy.make_query_plan())
|
qplan = list(policy.make_query_plan())
|
||||||
self.assertEqual(qplan, [])
|
self.assertEqual(qplan, [])
|
||||||
|
|
||||||
|
|
||||||
def test_statement_keyspace(self):
|
def test_statement_keyspace(self):
|
||||||
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
|
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
@@ -666,6 +667,7 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase):
|
|||||||
|
|
||||||
ONE = ConsistencyLevel.ONE
|
ONE = ConsistencyLevel.ONE
|
||||||
|
|
||||||
|
|
||||||
class RetryPolicyTest(unittest.TestCase):
|
class RetryPolicyTest(unittest.TestCase):
|
||||||
|
|
||||||
def test_read_timeout(self):
|
def test_read_timeout(self):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from binascii import unhexlify
|
|||||||
import calendar
|
import calendar
|
||||||
import datetime
|
import datetime
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import six
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import cassandra
|
import cassandra
|
||||||
@@ -228,7 +229,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply_parameters(cls, subtypes, names):
|
def apply_parameters(cls, subtypes, names):
|
||||||
return cls(subtypes, [unhexlify(name) if name is not None else name for name in names])
|
return cls(subtypes, [unhexlify(six.b(name)) if name is not None else name for name in names])
|
||||||
|
|
||||||
class BarType(FooType):
|
class BarType(FooType):
|
||||||
typename = 'org.apache.cassandra.db.marshal.BarType'
|
typename = 'org.apache.cassandra.db.marshal.BarType'
|
||||||
|
|||||||
Reference in New Issue
Block a user