diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1c3b43f5..1375a51a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,11 @@ +2.1.0 +===== + +Bug Fixes +--------- +* Correctly serialize and deserialize null values in tuples and + user-defined types (PYTHON-110) + 2.1.0c1 ======= July 25, 2014 diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index f062678a..4b7022d6 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -792,8 +792,11 @@ class TupleType(_ParameterizedType): break itemlen = int32_unpack(byts[p:p + 4]) p += 4 - item = byts[p:p + itemlen] - p += itemlen + if itemlen >= 0: + item = byts[p:p + itemlen] + p += itemlen + else: + item = None # collections inside UDTs are always encoded with at least the # version 3 format values.append(col_type.from_binary(item, proto_version)) @@ -813,9 +816,12 @@ class TupleType(_ParameterizedType): proto_version = max(3, protocol_version) buf = io.BytesIO() for item, subtype in zip(val, cls.subtypes): - packed_item = subtype.to_binary(item, proto_version) - buf.write(int32_pack(len(packed_item))) - buf.write(packed_item) + if item is not None: + packed_item = subtype.to_binary(item, proto_version) + buf.write(int32_pack(len(packed_item))) + buf.write(packed_item) + else: + buf.write(int32_pack(-1)) return buf.getvalue() @@ -869,8 +875,11 @@ class UserType(TupleType): break itemlen = int32_unpack(byts[p:p + 4]) p += 4 - item = byts[p:p + itemlen] - p += itemlen + if itemlen >= 0: + item = byts[p:p + itemlen] + p += itemlen + else: + item = None # collections inside UDTs are always encoded with at least the # version 3 format values.append(col_type.from_binary(item, proto_version)) @@ -890,9 +899,13 @@ class UserType(TupleType): proto_version = max(3, protocol_version) buf = io.BytesIO() for fieldname, subtype in zip(cls.fieldnames, cls.subtypes): - packed_item = subtype.to_binary(getattr(val, fieldname), proto_version) - buf.write(int32_pack(len(packed_item))) - buf.write(packed_item) + item = getattr(val, fieldname) + if item is not None: + packed_item = subtype.to_binary(getattr(val, fieldname), proto_version) + buf.write(int32_pack(len(packed_item))) + buf.write(packed_item) + else: + buf.write(int32_pack(-1)) return buf.getvalue() diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 36b1a099..a3d8f2f0 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -670,6 +670,39 @@ class TypeTests(unittest.TestCase): result = s.execute("SELECT v_%s FROM mytable WHERE k=%s", (i, i))[0] self.assertEqual(created_tuple, result['v_%s' % i]) + def test_tuples_with_nulls(self): + """ + Test tuples with null and empty string fields. + """ + if self._cass_version < (2, 1, 0): + raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect() + + s.execute("""CREATE KEYSPACE test_tuples_with_nulls + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}""") + s.set_keyspace("test_tuples_with_nulls") + + s.execute("CREATE TABLE mytable (k int PRIMARY KEY, t tuple)") + + insert = s.prepare("INSERT INTO mytable (k, t) VALUES (0, ?)") + s.execute(insert, [(None, None, None, None)]) + + result = s.execute("SELECT * FROM mytable WHERE k=0") + self.assertEquals((None, None, None, None), result[0].t) + + read = s.prepare("SELECT * FROM mytable WHERE k=0") + self.assertEquals((None, None, None, None), s.execute(read)[0].t) + + # also test empty strings where compatible + s.execute(insert, [('', None, None, '')]) + result = s.execute("SELECT * FROM mytable WHERE k=0") + self.assertEquals(('', None, None, ''), result[0].t) + self.assertEquals(('', None, None, ''), s.execute(read)[0].t) + + c.shutdown() + def test_unicode_query_string(self): c = Cluster(protocol_version=PROTOCOL_VERSION) s = c.connect() diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index 0c5418f4..0a7a62ba 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -231,6 +231,41 @@ class TypeTests(unittest.TestCase): c.shutdown() + def test_udts_with_nulls(self): + """ + Test UDTs with null and empty string fields. + """ + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect() + + s.execute(""" + CREATE KEYSPACE test_udts_with_nulls + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("test_udts_with_nulls") + s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)") + User = namedtuple('user', ('a', 'b', 'c', 'd')) + c.register_user_type("test_udts_with_nulls", "user", User) + + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (0, ?)") + s.execute(insert, [User(None, None, None, None)]) + + results = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual((None, None, None, None), results[0].b) + + select = s.prepare("SELECT b FROM mytable WHERE a=0") + self.assertEqual((None, None, None, None), s.execute(select)[0].b) + + # also test empty strings + s.execute(insert, [User('', None, None, '')]) + results = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual(('', None, None, ''), results[0].b) + self.assertEqual(('', None, None, ''), s.execute(select)[0].b) + + c.shutdown() + def test_udt_sizes(self): """ Test for ensuring extra-lengthy udts are handled correctly.