Integration tests for UDTs, related fixes

This commit is contained in:
Tyler Hobbs
2014-06-18 18:45:39 -05:00
parent 901d9a16e4
commit 6d806ee4c8
6 changed files with 261 additions and 26 deletions

View File

@@ -459,7 +459,7 @@ class Cluster(object):
def register_user_type(self, keyspace, user_type, klass):
self._user_types[keyspace][user_type] = klass
for session in self.sessions:
self.session.user_type_registered(keyspace, user_type, klass)
session.user_type_registered(keyspace, user_type, klass)
def get_min_requests_per_connection(self, host_distance):
return self._min_requests_per_connection[host_distance]
@@ -1076,17 +1076,6 @@ class Session(object):
encoders = None
def user_type_registered(self, keyspace, user_type, klass):
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
def encode(val):
return '{ %s }' % ' , '.join('%s : %s' % (
field_name,
cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names)
self._encoders[klass] = encode
def __init__(self, cluster, hosts):
self.cluster = cluster
self.hosts = hosts
@@ -1457,6 +1446,22 @@ class Session(object):
for pool in self._pools.values():
pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace)
def user_type_registered(self, keyspace, user_type, klass):
"""
Called by the parent Cluster instance when the user registers a new
mapping from a user-defined type to a class. Intended for internal
use only.
"""
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
def encode(val):
return '{ %s }' % ' , '.join('%s : %s' % (
field_name,
cql_encode_all_types(getattr(val, field_name))
) for field_name in type_meta.field_names)
self._encoders[klass] = encode
def submit(self, fn, *args, **kwargs):
""" Internal """
if not self.is_shutdown:
@@ -1521,7 +1526,7 @@ class ControlConnection(object):
_SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces"
_SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
_SELECT_COLUMNS = "SELECT * FROM system.schema_columns"
_SELECT_TYPES = "SELECT * FROM system.schema_types"
_SELECT_USERTYPES = "SELECT * FROM system.schema_usertypes"
_SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address, schema_version FROM system.peers"
_SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner, schema_version FROM system.local WHERE key='local'"
@@ -1725,7 +1730,7 @@ class ControlConnection(object):
cf_result, col_result = connection.wait_for_responses(
cf_query, col_query)
log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", (keyspace, table))
log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", keyspace, table)
cf_result = dict_factory(*cf_result.results) if cf_result else {}
col_result = dict_factory(*col_result.results) if col_result else {}
self._cluster.metadata.table_changed(keyspace, table, cf_result, col_result)
@@ -1734,7 +1739,7 @@ class ControlConnection(object):
where_clause = " WHERE keyspace_name = '%s' AND type_name = '%s'" % (keyspace, usertype)
types_query = QueryMessage(query=self._SELECT_USERTYPES + where_clause, consistency_level=cl)
types_result = connection.wait_for_response(types_query)
log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", (keyspace, usertype))
log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", keyspace, usertype)
types_result = dict_factory(*types_result.results) if types_result.results else {}
self._cluster.metadata.usertype_changed(keyspace, usertype, types_result)
elif keyspace:
@@ -1742,7 +1747,7 @@ class ControlConnection(object):
where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
ks_result = connection.wait_for_response(ks_query)
log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", (keyspace,))
log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", keyspace)
ks_result = dict_factory(*ks_result.results) if ks_result.results else {}
self._cluster.metadata.keyspace_changed(keyspace, ks_result)
else:

View File

@@ -344,7 +344,7 @@ class Connection(object):
% (opcode, stream_id))
if body_len > 0:
body = msg[8:]
body = msg[self._full_header_length:]
elif body_len == 0:
body = six.binary_type()
else:

View File

@@ -297,8 +297,7 @@ class _CassandraType(object):
if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
raise ValueError("%s types require %d subtypes (%d given)"
% (cls.typename, cls.num_subtypes, len(subtypes)))
# newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
newname = cls.cass_parameterized_type_with(subtypes)
newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname})
@classmethod
@@ -782,6 +781,7 @@ class UserDefinedType(_ParameterizedType):
@classmethod
def apply_parameters(cls, keyspace, udt_name, names_and_types, mapped_class):
udt_name = udt_name.encode('utf-8')
try:
return cls._cache[(keyspace, udt_name)]
except KeyError:

View File

@@ -119,7 +119,7 @@ class Metadata(object):
keyspace_meta.tables[table_meta.name] = table_meta
for usertype_row in usertype_rows.get(keyspace_meta.name, []):
usertype = self._build_usertype(usertype_row)
usertype = self._build_usertype(keyspace_meta.name, usertype_row)
keyspace_meta.user_types[usertype.name] = usertype
current_keyspaces.add(keyspace_meta.name)
@@ -149,6 +149,7 @@ class Metadata(object):
old_keyspace_meta = self.keyspaces.get(keyspace, None)
self.keyspaces[keyspace] = keyspace_meta
if old_keyspace_meta:
keyspace_meta.user_types = old_keyspace_meta.user_types
if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy):
self._keyspace_updated(keyspace)
else:
@@ -156,7 +157,7 @@ class Metadata(object):
def usertype_changed(self, keyspace, name, type_results):
new_usertype = self._build_usertype(keyspace, type_results[0])
self.user_types[name] = new_usertype
self.keyspaces[keyspace].user_types[name] = new_usertype
def table_changed(self, keyspace, table, cf_results, col_results):
try:

View File

@@ -632,11 +632,11 @@ class ResultMessage(_MessageType):
valsubtype = cls.read_type(f, user_type_map)
typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
elif typeclass == UserDefinedType:
ks = cls.read_string(f)
udt_name = cls.read_string(f)
num_fields = cls.read_short(f)
names_and_types = ((cls.read_string(f), cls.read_type(f, user_type_map))
for _ in xrange(num_fields))
ks = read_string(f)
udt_name = read_string(f)
num_fields = read_short(f)
names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map))
for _ in xrange(num_fields))
mapped_class = user_type_map.get(ks, {}).get(udt_name)
typeclass = typeclass.apply_parameters(
ks, udt_name, names_and_types, mapped_class)

View File

@@ -0,0 +1,229 @@
# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import logging
log = logging.getLogger(__name__)
from collections import namedtuple
from cassandra.cluster import Cluster
from tests.integration import get_server_versions, PROTOCOL_VERSION
import time
class TypeTests(unittest.TestCase):
def setUp(self):
if PROTOCOL_VERSION < 3:
raise unittest.SkipTest("v3 protocol is required for UDT tests")
self._cass_version, self._cql_version = get_server_versions()
def test_unprepared_registered_udts(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE udt_test_unprepared_registered
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_unprepared_registered")
s.execute("CREATE TYPE user (age int, name text)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
time.sleep(1)
User = namedtuple('user', ('age', 'name'))
c.register_user_type("udt_test_unprepared_registered", "user", User)
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob')))
result = s.execute("SELECT b FROM mytable WHERE a=0")
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual(42, row.b.age)
self.assertEqual('bob', row.b.name)
self.assertTrue(type(row.b) is User)
# use the same UDT name in a different keyspace
s.execute("""
CREATE KEYSPACE udt_test_unprepared_registered2
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_unprepared_registered2")
s.execute("CREATE TYPE user (state text, is_cool boolean)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
time.sleep(1)
User = namedtuple('user', ('state', 'is_cool'))
c.register_user_type("udt_test_unprepared_registered2", "user", User)
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User('Texas', True)))
result = s.execute("SELECT b FROM mytable WHERE a=0")
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual('Texas', row.b.state)
self.assertEqual(True, row.b.is_cool)
self.assertTrue(type(row.b) is User)
def test_register_before_connecting(self):
User1 = namedtuple('user', ('age', 'name'))
User2 = namedtuple('user', ('state', 'is_cool'))
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE udt_test_register_before_connecting
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_register_before_connecting")
s.execute("CREATE TYPE user (age int, name text)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
s.execute("""
CREATE KEYSPACE udt_test_register_before_connecting2
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_register_before_connecting2")
s.execute("CREATE TYPE user (state text, is_cool boolean)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
time.sleep(1)
# now that types are defined, shutdown and re-create Cluster
c.shutdown()
c = Cluster(protocol_version=PROTOCOL_VERSION)
c.register_user_type("udt_test_register_before_connecting", "user", User1)
c.register_user_type("udt_test_register_before_connecting2", "user", User2)
s = c.connect()
s.set_keyspace("udt_test_register_before_connecting")
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob')))
result = s.execute("SELECT b FROM mytable WHERE a=0")
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual(42, row.b.age)
self.assertEqual('bob', row.b.name)
self.assertTrue(type(row.b) is User1)
# use the same UDT name in a different keyspace
s.set_keyspace("udt_test_register_before_connecting2")
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User2('Texas', True)))
result = s.execute("SELECT b FROM mytable WHERE a=0")
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual('Texas', row.b.state)
self.assertEqual(True, row.b.is_cool)
self.assertTrue(type(row.b) is User2)
def test_prepared_unregistered_udts(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE udt_test_prepared_unregistered
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_prepared_unregistered")
s.execute("CREATE TYPE user (age int, name text)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
User = namedtuple('user', ('age', 'name'))
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(insert, (0, User(42, 'bob')))
select = s.prepare("SELECT b FROM mytable WHERE a=?")
result = s.execute(select, (0,))
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual(42, row.b.age)
self.assertEqual('bob', row.b.name)
# use the same UDT name in a different keyspace
s.execute("""
CREATE KEYSPACE udt_test_prepared_unregistered2
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_prepared_unregistered2")
s.execute("CREATE TYPE user (state text, is_cool boolean)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
User = namedtuple('user', ('state', 'is_cool'))
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(insert, (0, User('Texas', True)))
select = s.prepare("SELECT b FROM mytable WHERE a=?")
result = s.execute(select, (0,))
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual('Texas', row.b.state)
self.assertEqual(True, row.b.is_cool)
def test_prepared_registered_udts(self):
c = Cluster(protocol_version=PROTOCOL_VERSION)
s = c.connect()
s.execute("""
CREATE KEYSPACE udt_test_prepared_registered
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_prepared_registered")
s.execute("CREATE TYPE user (age int, name text)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
User = namedtuple('user', ('age', 'name'))
c.register_user_type("udt_test_prepared_registered", "user", User)
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(insert, (0, User(42, 'bob')))
select = s.prepare("SELECT b FROM mytable WHERE a=?")
result = s.execute(select, (0,))
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual(42, row.b.age)
self.assertEqual('bob', row.b.name)
self.assertTrue(type(row.b) is User)
# use the same UDT name in a different keyspace
s.execute("""
CREATE KEYSPACE udt_test_prepared_registered2
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
""")
s.set_keyspace("udt_test_prepared_registered2")
s.execute("CREATE TYPE user (state text, is_cool boolean)")
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
User = namedtuple('user', ('state', 'is_cool'))
c.register_user_type("udt_test_prepared_registered2", "user", User)
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
s.execute(insert, (0, User('Texas', True)))
select = s.prepare("SELECT b FROM mytable WHERE a=?")
result = s.execute(select, (0,))
self.assertEqual(1, len(result))
row = result[0]
self.assertEqual('Texas', row.b.state)
self.assertEqual(True, row.b.is_cool)
self.assertTrue(type(row.b) is User)