From 6d806ee4c855ed3e19cc0111ee7617e32f5dcb6b Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Wed, 18 Jun 2014 18:45:39 -0500 Subject: [PATCH] Integration tests for UDTs, related fixes --- cassandra/cluster.py | 37 ++-- cassandra/connection.py | 2 +- cassandra/cqltypes.py | 4 +- cassandra/metadata.py | 5 +- cassandra/protocol.py | 10 +- tests/integration/standard/test_udts.py | 229 ++++++++++++++++++++++++ 6 files changed, 261 insertions(+), 26 deletions(-) create mode 100644 tests/integration/standard/test_udts.py diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 2869ff97..da3733c3 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -459,7 +459,7 @@ class Cluster(object): def register_user_type(self, keyspace, user_type, klass): self._user_types[keyspace][user_type] = klass for session in self.sessions: - self.session.user_type_registered(keyspace, user_type, klass) + session.user_type_registered(keyspace, user_type, klass) def get_min_requests_per_connection(self, host_distance): return self._min_requests_per_connection[host_distance] @@ -1076,17 +1076,6 @@ class Session(object): encoders = None - def user_type_registered(self, keyspace, user_type, klass): - type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type] - - def encode(val): - return '{ %s }' % ' , '.join('%s : %s' % ( - field_name, - cql_encode_all_types(getattr(val, field_name)) - ) for field_name in type_meta.field_names) - - self._encoders[klass] = encode - def __init__(self, cluster, hosts): self.cluster = cluster self.hosts = hosts @@ -1457,6 +1446,22 @@ class Session(object): for pool in self._pools.values(): pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) + def user_type_registered(self, keyspace, user_type, klass): + """ + Called by the parent Cluster instance when the user registers a new + mapping from a user-defined type to a class. Intended for internal + use only. + """ + type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type] + + def encode(val): + return '{ %s }' % ' , '.join('%s : %s' % ( + field_name, + cql_encode_all_types(getattr(val, field_name)) + ) for field_name in type_meta.field_names) + + self._encoders[klass] = encode + def submit(self, fn, *args, **kwargs): """ Internal """ if not self.is_shutdown: @@ -1521,7 +1526,7 @@ class ControlConnection(object): _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" - _SELECT_TYPES = "SELECT * FROM system.schema_types" + _SELECT_USERTYPES = "SELECT * FROM system.schema_usertypes" _SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address, schema_version FROM system.peers" _SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner, schema_version FROM system.local WHERE key='local'" @@ -1725,7 +1730,7 @@ class ControlConnection(object): cf_result, col_result = connection.wait_for_responses( cf_query, col_query) - log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", (keyspace, table)) + log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", keyspace, table) cf_result = dict_factory(*cf_result.results) if cf_result else {} col_result = dict_factory(*col_result.results) if col_result else {} self._cluster.metadata.table_changed(keyspace, table, cf_result, col_result) @@ -1734,7 +1739,7 @@ class ControlConnection(object): where_clause = " WHERE keyspace_name = '%s' AND type_name = '%s'" % (keyspace, usertype) types_query = QueryMessage(query=self._SELECT_USERTYPES + where_clause, consistency_level=cl) types_result = connection.wait_for_response(types_query) - log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", (keyspace, usertype)) + log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", keyspace, usertype) types_result = dict_factory(*types_result.results) if types_result.results else {} self._cluster.metadata.usertype_changed(keyspace, usertype, types_result) elif keyspace: @@ -1742,7 +1747,7 @@ class ControlConnection(object): where_clause = " WHERE keyspace_name = '%s'" % (keyspace,) ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl) ks_result = connection.wait_for_response(ks_query) - log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", (keyspace,)) + log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", keyspace) ks_result = dict_factory(*ks_result.results) if ks_result.results else {} self._cluster.metadata.keyspace_changed(keyspace, ks_result) else: diff --git a/cassandra/connection.py b/cassandra/connection.py index 6f90da26..c94cf972 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -344,7 +344,7 @@ class Connection(object): % (opcode, stream_id)) if body_len > 0: - body = msg[8:] + body = msg[self._full_header_length:] elif body_len == 0: body = six.binary_type() else: diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index c42ad33f..a09eddbe 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -297,8 +297,7 @@ class _CassandraType(object): if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes: raise ValueError("%s types require %d subtypes (%d given)" % (cls.typename, cls.num_subtypes, len(subtypes))) - # newname = cls.cass_parameterized_type_with(subtypes).encode('utf8') - newname = cls.cass_parameterized_type_with(subtypes) + newname = cls.cass_parameterized_type_with(subtypes).encode('utf8') return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname}) @classmethod @@ -782,6 +781,7 @@ class UserDefinedType(_ParameterizedType): @classmethod def apply_parameters(cls, keyspace, udt_name, names_and_types, mapped_class): + udt_name = udt_name.encode('utf-8') try: return cls._cache[(keyspace, udt_name)] except KeyError: diff --git a/cassandra/metadata.py b/cassandra/metadata.py index dceccb9f..a9bda180 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -119,7 +119,7 @@ class Metadata(object): keyspace_meta.tables[table_meta.name] = table_meta for usertype_row in usertype_rows.get(keyspace_meta.name, []): - usertype = self._build_usertype(usertype_row) + usertype = self._build_usertype(keyspace_meta.name, usertype_row) keyspace_meta.user_types[usertype.name] = usertype current_keyspaces.add(keyspace_meta.name) @@ -149,6 +149,7 @@ class Metadata(object): old_keyspace_meta = self.keyspaces.get(keyspace, None) self.keyspaces[keyspace] = keyspace_meta if old_keyspace_meta: + keyspace_meta.user_types = old_keyspace_meta.user_types if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy): self._keyspace_updated(keyspace) else: @@ -156,7 +157,7 @@ class Metadata(object): def usertype_changed(self, keyspace, name, type_results): new_usertype = self._build_usertype(keyspace, type_results[0]) - self.user_types[name] = new_usertype + self.keyspaces[keyspace].user_types[name] = new_usertype def table_changed(self, keyspace, table, cf_results, col_results): try: diff --git a/cassandra/protocol.py b/cassandra/protocol.py index e0dad3ef..25095992 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -632,11 +632,11 @@ class ResultMessage(_MessageType): valsubtype = cls.read_type(f, user_type_map) typeclass = typeclass.apply_parameters((keysubtype, valsubtype)) elif typeclass == UserDefinedType: - ks = cls.read_string(f) - udt_name = cls.read_string(f) - num_fields = cls.read_short(f) - names_and_types = ((cls.read_string(f), cls.read_type(f, user_type_map)) - for _ in xrange(num_fields)) + ks = read_string(f) + udt_name = read_string(f) + num_fields = read_short(f) + names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map)) + for _ in xrange(num_fields)) mapped_class = user_type_map.get(ks, {}).get(udt_name) typeclass = typeclass.apply_parameters( ks, udt_name, names_and_types, mapped_class) diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py new file mode 100644 index 00000000..df2bddfc --- /dev/null +++ b/tests/integration/standard/test_udts.py @@ -0,0 +1,229 @@ +# Copyright 2013-2014 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import logging +log = logging.getLogger(__name__) + +from collections import namedtuple + +from cassandra.cluster import Cluster + +from tests.integration import get_server_versions, PROTOCOL_VERSION + +import time + + +class TypeTests(unittest.TestCase): + + def setUp(self): + if PROTOCOL_VERSION < 3: + raise unittest.SkipTest("v3 protocol is required for UDT tests") + + self._cass_version, self._cql_version = get_server_versions() + + def test_unprepared_registered_udts(self): + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect() + + s.execute(""" + CREATE KEYSPACE udt_test_unprepared_registered + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_unprepared_registered") + s.execute("CREATE TYPE user (age int, name text)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + time.sleep(1) + + User = namedtuple('user', ('age', 'name')) + c.register_user_type("udt_test_unprepared_registered", "user", User) + + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob'))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + self.assertTrue(type(row.b) is User) + + # use the same UDT name in a different keyspace + s.execute(""" + CREATE KEYSPACE udt_test_unprepared_registered2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_unprepared_registered2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + time.sleep(1) + + User = namedtuple('user', ('state', 'is_cool')) + c.register_user_type("udt_test_unprepared_registered2", "user", User) + + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User('Texas', True))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + self.assertTrue(type(row.b) is User) + + def test_register_before_connecting(self): + User1 = namedtuple('user', ('age', 'name')) + User2 = namedtuple('user', ('state', 'is_cool')) + + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect() + + s.execute(""" + CREATE KEYSPACE udt_test_register_before_connecting + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_register_before_connecting") + s.execute("CREATE TYPE user (age int, name text)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + s.execute(""" + CREATE KEYSPACE udt_test_register_before_connecting2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_register_before_connecting2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + time.sleep(1) + + # now that types are defined, shutdown and re-create Cluster + c.shutdown() + c = Cluster(protocol_version=PROTOCOL_VERSION) + c.register_user_type("udt_test_register_before_connecting", "user", User1) + c.register_user_type("udt_test_register_before_connecting2", "user", User2) + + s = c.connect() + + s.set_keyspace("udt_test_register_before_connecting") + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + self.assertTrue(type(row.b) is User1) + + # use the same UDT name in a different keyspace + s.set_keyspace("udt_test_register_before_connecting2") + s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User2('Texas', True))) + result = s.execute("SELECT b FROM mytable WHERE a=0") + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + self.assertTrue(type(row.b) is User2) + + def test_prepared_unregistered_udts(self): + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect() + + s.execute(""" + CREATE KEYSPACE udt_test_prepared_unregistered + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_prepared_unregistered") + s.execute("CREATE TYPE user (age int, name text)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + User = namedtuple('user', ('age', 'name')) + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User(42, 'bob'))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + + # use the same UDT name in a different keyspace + s.execute(""" + CREATE KEYSPACE udt_test_prepared_unregistered2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_prepared_unregistered2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + User = namedtuple('user', ('state', 'is_cool')) + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User('Texas', True))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + + def test_prepared_registered_udts(self): + c = Cluster(protocol_version=PROTOCOL_VERSION) + s = c.connect() + + s.execute(""" + CREATE KEYSPACE udt_test_prepared_registered + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_prepared_registered") + s.execute("CREATE TYPE user (age int, name text)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + User = namedtuple('user', ('age', 'name')) + c.register_user_type("udt_test_prepared_registered", "user", User) + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User(42, 'bob'))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual(42, row.b.age) + self.assertEqual('bob', row.b.name) + self.assertTrue(type(row.b) is User) + + # use the same UDT name in a different keyspace + s.execute(""" + CREATE KEYSPACE udt_test_prepared_registered2 + WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } + """) + s.set_keyspace("udt_test_prepared_registered2") + s.execute("CREATE TYPE user (state text, is_cool boolean)") + s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)") + + User = namedtuple('user', ('state', 'is_cool')) + c.register_user_type("udt_test_prepared_registered2", "user", User) + + insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") + s.execute(insert, (0, User('Texas', True))) + + select = s.prepare("SELECT b FROM mytable WHERE a=?") + result = s.execute(select, (0,)) + self.assertEqual(1, len(result)) + row = result[0] + self.assertEqual('Texas', row.b.state) + self.assertEqual(True, row.b.is_cool) + self.assertTrue(type(row.b) is User)