Merge branch 'PYTHON-231'

Conflicts:
	cassandra/cqltypes.py
	cassandra/util.py
This commit is contained in:
Adam Holmberg
2015-03-13 13:43:10 -05:00
5 changed files with 45 additions and 25 deletions

View File

@@ -784,12 +784,12 @@ class MapType(_ParameterizedType):
@classmethod
def validate(cls, val):
subkeytype, subvaltype = cls.subtypes
return dict((subkeytype.validate(k), subvaltype.validate(v)) for (k, v) in six.iteritems(val))
key_type, value_type = cls.subtypes
return dict((key_type.validate(k), value_type.validate(v)) for (k, v) in six.iteritems(val))
@classmethod
def deserialize_safe(cls, byts, protocol_version):
subkeytype, subvaltype = cls.subtypes
key_type, value_type = cls.subtypes
if protocol_version >= 3:
unpack = int32_unpack
length = 4
@@ -798,7 +798,7 @@ class MapType(_ParameterizedType):
length = 2
numelements = unpack(byts[:length])
p = length
themap = util.OrderedMap()
themap = util.OrderedMapSerializedKey(key_type, protocol_version)
for _ in range(numelements):
key_len = unpack(byts[p:p + length])
p += length
@@ -808,14 +808,14 @@ class MapType(_ParameterizedType):
p += length
valbytes = byts[p:p + val_len]
p += val_len
key = subkeytype.from_binary(keybytes, protocol_version)
val = subvaltype.from_binary(valbytes, protocol_version)
themap._insert(key, val)
key = key_type.from_binary(keybytes, protocol_version)
val = value_type.from_binary(valbytes, protocol_version)
themap._insert_unchecked(key, keybytes, val)
return themap
@classmethod
def serialize_safe(cls, themap, protocol_version):
subkeytype, subvaltype = cls.subtypes
key_type, value_type = cls.subtypes
pack = int32_pack if protocol_version >= 3 else uint16_pack
buf = io.BytesIO()
buf.write(pack(len(themap)))
@@ -824,8 +824,8 @@ class MapType(_ParameterizedType):
except AttributeError:
raise TypeError("Got a non-map object for a map value")
for key, val in items:
keybytes = subkeytype.to_binary(key, protocol_version)
valbytes = subvaltype.to_binary(val, protocol_version)
keybytes = key_type.to_binary(key, protocol_version)
valbytes = value_type.to_binary(val, protocol_version)
buf.write(pack(len(keybytes)))
buf.write(keybytes)
buf.write(pack(len(valbytes)))

View File

@@ -680,6 +680,7 @@ except ImportError:
isect.add(item)
return isect
from collections import Mapping
from six.moves import cPickle
@@ -715,6 +716,7 @@ class OrderedMap(Mapping):
or higher.
'''
def __init__(self, *args, **kwargs):
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
@@ -776,11 +778,25 @@ class OrderedMap(Mapping):
def __str__(self):
return '{%s}' % ', '.join("%s: %s" % (k, v) for k, v in self._items)
@staticmethod
def _serialize_key(key):
def _serialize_key(self, key):
return cPickle.dumps(key)
class OrderedMapSerializedKey(OrderedMap):
def __init__(self, cass_type, protocol_version):
super(OrderedMapSerializedKey, self).__init__()
self.cass_key_type = cass_type
self.protocol_version = protocol_version
def _insert_unchecked(self, key, flat_key, value):
self._items.append((key, value))
self._index[flat_key] = len(self._items) - 1
def _serialize_key(self, key):
return self.cass_key_type.serialize(key, self.protocol_version)
import datetime
import time

View File

@@ -23,8 +23,8 @@ from datetime import datetime, date
from decimal import Decimal
from uuid import UUID
from cassandra.cqltypes import lookup_casstype
from cassandra.util import OrderedMap, sortedset, Time
from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type
from cassandra.util import OrderedMap, OrderedMapSerializedKey, sortedset, Time
marshalled_value_pairs = (
# binary form, type, python native type
@@ -75,7 +75,7 @@ marshalled_value_pairs = (
(b'', 'MapType(AsciiType, BooleanType)', None),
(b'', 'ListType(FloatType)', None),
(b'', 'SetType(LongType)', None),
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMap()),
(b'\x00\x00', 'MapType(DecimalType, BooleanType)', OrderedMapSerializedKey(DecimalType, 0)),
(b'\x00\x00', 'ListType(FloatType)', []),
(b'\x00\x00', 'SetType(IntegerType)', sortedset()),
(b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]),
@@ -84,9 +84,10 @@ marshalled_value_pairs = (
(b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', Time(1))
)
ordered_map_value = OrderedMap([(u'\u307fbob', 199),
(u'', -1),
(u'\\', 0)])
ordered_map_value = OrderedMapSerializedKey(UTF8Type, 2)
ordered_map_value._insert(u'\u307fbob', 199)
ordered_map_value._insert(u'', -1)
ordered_map_value._insert(u'\\', 0)
# these following entries work for me right now, but they're dependent on
# vagaries of internal python ordering for unordered types

View File

@@ -15,7 +15,7 @@
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import unittest # noqa
from itertools import islice, cycle
from mock import Mock
@@ -139,21 +139,22 @@ class RoundRobinPolicyTest(unittest.TestCase):
threads.append(Thread(target=host_down))
# make the GIL switch after every instruction, maximizing
# the chace of race conditions
if six.PY2:
# the chance of race conditions
check = six.PY2 or '__pypy__' in sys.builtin_module_names
if check:
original_interval = sys.getcheckinterval()
else:
original_interval = sys.getswitchinterval()
try:
if six.PY2:
if check:
sys.setcheckinterval(0)
else:
sys.setswitchinterval(0.0001)
map(lambda t: t.start(), threads)
map(lambda t: t.join(), threads)
finally:
if six.PY2:
if check:
sys.setcheckinterval(original_interval)
else:
sys.setswitchinterval(original_interval)
@@ -362,6 +363,7 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase):
policy.on_add(host_remote)
self.assertFalse(policy.local_dc)
class TokenAwarePolicyTest(unittest.TestCase):
def test_wrap_round_robin(self):
@@ -519,7 +521,6 @@ class TokenAwarePolicyTest(unittest.TestCase):
qplan = list(policy.make_query_plan())
self.assertEqual(qplan, [])
def test_statement_keyspace(self):
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
for host in hosts:
@@ -666,6 +667,7 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase):
ONE = ConsistencyLevel.ONE
class RetryPolicyTest(unittest.TestCase):
def test_read_timeout(self):

View File

@@ -20,6 +20,7 @@ from binascii import unhexlify
import calendar
import datetime
import tempfile
import six
import time
import cassandra
@@ -228,7 +229,7 @@ class TypeTests(unittest.TestCase):
@classmethod
def apply_parameters(cls, subtypes, names):
return cls(subtypes, [unhexlify(name) if name is not None else name for name in names])
return cls(subtypes, [unhexlify(six.b(name)) if name is not None else name for name in names])
class BarType(FooType):
typename = 'org.apache.cassandra.db.marshal.BarType'