Merge pull request #171 from datastax/PYTHON-110

(De)serialize null values in tuples & udts correctly
This commit is contained in:
aholmberg
2014-08-04 08:46:12 -05:00
4 changed files with 99 additions and 10 deletions

View File

@@ -1,3 +1,11 @@
2.1.0
=====
Bug Fixes
---------
* Correctly serialize and deserialize null values in tuples and
user-defined types (PYTHON-110)
2.1.0c1
=======
July 25, 2014

View File

@@ -792,8 +792,11 @@ class TupleType(_ParameterizedType):
break
itemlen = int32_unpack(byts[p:p + 4])
p += 4
item = byts[p:p + itemlen]
p += itemlen
if itemlen >= 0:
item = byts[p:p + itemlen]
p += itemlen
else:
item = None
# collections inside UDTs are always encoded with at least the
# version 3 format
values.append(col_type.from_binary(item, proto_version))
@@ -813,9 +816,12 @@ class TupleType(_ParameterizedType):
proto_version = max(3, protocol_version)
buf = io.BytesIO()
for item, subtype in zip(val, cls.subtypes):
packed_item = subtype.to_binary(item, proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
if item is not None:
packed_item = subtype.to_binary(item, proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
else:
buf.write(int32_pack(-1))
return buf.getvalue()
@@ -869,8 +875,11 @@ class UserType(TupleType):
break
itemlen = int32_unpack(byts[p:p + 4])
p += 4
item = byts[p:p + itemlen]
p += itemlen
if itemlen >= 0:
item = byts[p:p + itemlen]
p += itemlen
else:
item = None
# collections inside UDTs are always encoded with at least the
# version 3 format
values.append(col_type.from_binary(item, proto_version))
@@ -890,9 +899,13 @@ class UserType(TupleType):
proto_version = max(3, protocol_version)
buf = io.BytesIO()
for fieldname, subtype in zip(cls.fieldnames, cls.subtypes):
packed_item = subtype.to_binary(getattr(val, fieldname), proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
item = getattr(val, fieldname)
if item is not None:
packed_item = subtype.to_binary(getattr(val, fieldname), proto_version)
buf.write(int32_pack(len(packed_item)))
buf.write(packed_item)
else:
buf.write(int32_pack(-1))
return buf.getvalue()

View File

@@ -670,6 +670,39 @@ class TypeTests(unittest.TestCase):
result = s.execute("SELECT v_%s FROM mytable WHERE k=%s", (i, i))[0]
self.assertEqual(created_tuple, result['v_%s' % i])
def test_tuples_with_nulls(self):
"""
Test tuples with null and empty string fields.
"""
if self._cass_version < (2, 1, 0):
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""CREATE KEYSPACE test_tuples_with_nulls
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""")
s.set_keyspace("test_tuples_with_nulls")
s.execute("CREATE TABLE mytable (k int PRIMARY KEY, t tuple<text, int, uuid, blob>)")
insert = s.prepare("INSERT INTO mytable (k, t) VALUES (0, ?)")
s.execute(insert, [(None, None, None, None)])
result = s.execute("SELECT * FROM mytable WHERE k=0")
self.assertEquals((None, None, None, None), result[0].t)
read = s.prepare("SELECT * FROM mytable WHERE k=0")
self.assertEquals((None, None, None, None), s.execute(read)[0].t)
# also test empty strings where compatible
s.execute(insert, [('', None, None, '')])
result = s.execute("SELECT * FROM mytable WHERE k=0")
self.assertEquals(('', None, None, ''), result[0].t)
self.assertEquals(('', None, None, ''), s.execute(read)[0].t)
c.shutdown()
def test_unicode_query_string(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()

View File

@@ -231,6 +231,41 @@ class TypeTests(unittest.TestCase):
c.shutdown()
def test_udts_with_nulls(self):
"""
Test UDTs with null and empty string fields.
"""
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE test_udts_with_nulls
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("test_udts_with_nulls")
s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)")
User = namedtuple('user', ('a', 'b', 'c', 'd'))
c.register_user_type("test_udts_with_nulls", "user", User)
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (0, ?)")
s.execute(insert, [User(None, None, None, None)])
results = s.execute("SELECT b FROM mytable WHERE a=0")
self.assertEqual((None, None, None, None), results[0].b)
select = s.prepare("SELECT b FROM mytable WHERE a=0")
self.assertEqual((None, None, None, None), s.execute(select)[0].b)
# also test empty strings
s.execute(insert, [User('', None, None, '')])
results = s.execute("SELECT b FROM mytable WHERE a=0")
self.assertEqual(('', None, None, ''), results[0].b)
self.assertEqual(('', None, None, ''), s.execute(select)[0].b)
c.shutdown()
def test_udt_sizes(self):
"""
Test for ensuring extra-lengthy udts are handled correctly.