Integration tests for UDTs, related fixes
This commit is contained in:
@@ -459,7 +459,7 @@ class Cluster(object):
|
||||
def register_user_type(self, keyspace, user_type, klass):
|
||||
self._user_types[keyspace][user_type] = klass
|
||||
for session in self.sessions:
|
||||
self.session.user_type_registered(keyspace, user_type, klass)
|
||||
session.user_type_registered(keyspace, user_type, klass)
|
||||
|
||||
def get_min_requests_per_connection(self, host_distance):
|
||||
return self._min_requests_per_connection[host_distance]
|
||||
@@ -1076,17 +1076,6 @@ class Session(object):
|
||||
|
||||
encoders = None
|
||||
|
||||
def user_type_registered(self, keyspace, user_type, klass):
|
||||
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
|
||||
|
||||
def encode(val):
|
||||
return '{ %s }' % ' , '.join('%s : %s' % (
|
||||
field_name,
|
||||
cql_encode_all_types(getattr(val, field_name))
|
||||
) for field_name in type_meta.field_names)
|
||||
|
||||
self._encoders[klass] = encode
|
||||
|
||||
def __init__(self, cluster, hosts):
|
||||
self.cluster = cluster
|
||||
self.hosts = hosts
|
||||
@@ -1457,6 +1446,22 @@ class Session(object):
|
||||
for pool in self._pools.values():
|
||||
pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace)
|
||||
|
||||
def user_type_registered(self, keyspace, user_type, klass):
|
||||
"""
|
||||
Called by the parent Cluster instance when the user registers a new
|
||||
mapping from a user-defined type to a class. Intended for internal
|
||||
use only.
|
||||
"""
|
||||
type_meta = self.cluster.metadata.keyspaces[keyspace].user_types[user_type]
|
||||
|
||||
def encode(val):
|
||||
return '{ %s }' % ' , '.join('%s : %s' % (
|
||||
field_name,
|
||||
cql_encode_all_types(getattr(val, field_name))
|
||||
) for field_name in type_meta.field_names)
|
||||
|
||||
self._encoders[klass] = encode
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
""" Internal """
|
||||
if not self.is_shutdown:
|
||||
@@ -1521,7 +1526,7 @@ class ControlConnection(object):
|
||||
_SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces"
|
||||
_SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
|
||||
_SELECT_COLUMNS = "SELECT * FROM system.schema_columns"
|
||||
_SELECT_TYPES = "SELECT * FROM system.schema_types"
|
||||
_SELECT_USERTYPES = "SELECT * FROM system.schema_usertypes"
|
||||
|
||||
_SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address, schema_version FROM system.peers"
|
||||
_SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner, schema_version FROM system.local WHERE key='local'"
|
||||
@@ -1725,7 +1730,7 @@ class ControlConnection(object):
|
||||
cf_result, col_result = connection.wait_for_responses(
|
||||
cf_query, col_query)
|
||||
|
||||
log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", (keyspace, table))
|
||||
log.debug("[control connection] Fetched table info for %s.%s, rebuilding metadata", keyspace, table)
|
||||
cf_result = dict_factory(*cf_result.results) if cf_result else {}
|
||||
col_result = dict_factory(*col_result.results) if col_result else {}
|
||||
self._cluster.metadata.table_changed(keyspace, table, cf_result, col_result)
|
||||
@@ -1734,7 +1739,7 @@ class ControlConnection(object):
|
||||
where_clause = " WHERE keyspace_name = '%s' AND type_name = '%s'" % (keyspace, usertype)
|
||||
types_query = QueryMessage(query=self._SELECT_USERTYPES + where_clause, consistency_level=cl)
|
||||
types_result = connection.wait_for_response(types_query)
|
||||
log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", (keyspace, usertype))
|
||||
log.debug("[control connection] Fetched user type info for %s.%s, rebuilding metadata", keyspace, usertype)
|
||||
types_result = dict_factory(*types_result.results) if types_result.results else {}
|
||||
self._cluster.metadata.usertype_changed(keyspace, usertype, types_result)
|
||||
elif keyspace:
|
||||
@@ -1742,7 +1747,7 @@ class ControlConnection(object):
|
||||
where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
|
||||
ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
|
||||
ks_result = connection.wait_for_response(ks_query)
|
||||
log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", (keyspace,))
|
||||
log.debug("[control connection] Fetched keyspace info for %s, rebuilding metadata", keyspace)
|
||||
ks_result = dict_factory(*ks_result.results) if ks_result.results else {}
|
||||
self._cluster.metadata.keyspace_changed(keyspace, ks_result)
|
||||
else:
|
||||
|
@@ -344,7 +344,7 @@ class Connection(object):
|
||||
% (opcode, stream_id))
|
||||
|
||||
if body_len > 0:
|
||||
body = msg[8:]
|
||||
body = msg[self._full_header_length:]
|
||||
elif body_len == 0:
|
||||
body = six.binary_type()
|
||||
else:
|
||||
|
@@ -297,8 +297,7 @@ class _CassandraType(object):
|
||||
if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes:
|
||||
raise ValueError("%s types require %d subtypes (%d given)"
|
||||
% (cls.typename, cls.num_subtypes, len(subtypes)))
|
||||
# newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
|
||||
newname = cls.cass_parameterized_type_with(subtypes)
|
||||
newname = cls.cass_parameterized_type_with(subtypes).encode('utf8')
|
||||
return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname})
|
||||
|
||||
@classmethod
|
||||
@@ -782,6 +781,7 @@ class UserDefinedType(_ParameterizedType):
|
||||
|
||||
@classmethod
|
||||
def apply_parameters(cls, keyspace, udt_name, names_and_types, mapped_class):
|
||||
udt_name = udt_name.encode('utf-8')
|
||||
try:
|
||||
return cls._cache[(keyspace, udt_name)]
|
||||
except KeyError:
|
||||
|
@@ -119,7 +119,7 @@ class Metadata(object):
|
||||
keyspace_meta.tables[table_meta.name] = table_meta
|
||||
|
||||
for usertype_row in usertype_rows.get(keyspace_meta.name, []):
|
||||
usertype = self._build_usertype(usertype_row)
|
||||
usertype = self._build_usertype(keyspace_meta.name, usertype_row)
|
||||
keyspace_meta.user_types[usertype.name] = usertype
|
||||
|
||||
current_keyspaces.add(keyspace_meta.name)
|
||||
@@ -149,6 +149,7 @@ class Metadata(object):
|
||||
old_keyspace_meta = self.keyspaces.get(keyspace, None)
|
||||
self.keyspaces[keyspace] = keyspace_meta
|
||||
if old_keyspace_meta:
|
||||
keyspace_meta.user_types = old_keyspace_meta.user_types
|
||||
if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy):
|
||||
self._keyspace_updated(keyspace)
|
||||
else:
|
||||
@@ -156,7 +157,7 @@ class Metadata(object):
|
||||
|
||||
def usertype_changed(self, keyspace, name, type_results):
|
||||
new_usertype = self._build_usertype(keyspace, type_results[0])
|
||||
self.user_types[name] = new_usertype
|
||||
self.keyspaces[keyspace].user_types[name] = new_usertype
|
||||
|
||||
def table_changed(self, keyspace, table, cf_results, col_results):
|
||||
try:
|
||||
|
@@ -632,11 +632,11 @@ class ResultMessage(_MessageType):
|
||||
valsubtype = cls.read_type(f, user_type_map)
|
||||
typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
|
||||
elif typeclass == UserDefinedType:
|
||||
ks = cls.read_string(f)
|
||||
udt_name = cls.read_string(f)
|
||||
num_fields = cls.read_short(f)
|
||||
names_and_types = ((cls.read_string(f), cls.read_type(f, user_type_map))
|
||||
for _ in xrange(num_fields))
|
||||
ks = read_string(f)
|
||||
udt_name = read_string(f)
|
||||
num_fields = read_short(f)
|
||||
names_and_types = tuple((read_string(f), cls.read_type(f, user_type_map))
|
||||
for _ in xrange(num_fields))
|
||||
mapped_class = user_type_map.get(ks, {}).get(udt_name)
|
||||
typeclass = typeclass.apply_parameters(
|
||||
ks, udt_name, names_and_types, mapped_class)
|
||||
|
229
tests/integration/standard/test_udts.py
Normal file
229
tests/integration/standard/test_udts.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# Copyright 2013-2014 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
from tests.integration import get_server_versions, PROTOCOL_VERSION
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class TypeTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if PROTOCOL_VERSION < 3:
|
||||
raise unittest.SkipTest("v3 protocol is required for UDT tests")
|
||||
|
||||
self._cass_version, self._cql_version = get_server_versions()
|
||||
|
||||
def test_unprepared_registered_udts(self):
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect()
|
||||
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_unprepared_registered
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_unprepared_registered")
|
||||
s.execute("CREATE TYPE user (age int, name text)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
User = namedtuple('user', ('age', 'name'))
|
||||
c.register_user_type("udt_test_unprepared_registered", "user", User)
|
||||
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob')))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a=0")
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual(42, row.b.age)
|
||||
self.assertEqual('bob', row.b.name)
|
||||
self.assertTrue(type(row.b) is User)
|
||||
|
||||
# use the same UDT name in a different keyspace
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_unprepared_registered2
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_unprepared_registered2")
|
||||
s.execute("CREATE TYPE user (state text, is_cool boolean)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
User = namedtuple('user', ('state', 'is_cool'))
|
||||
c.register_user_type("udt_test_unprepared_registered2", "user", User)
|
||||
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User('Texas', True)))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a=0")
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual('Texas', row.b.state)
|
||||
self.assertEqual(True, row.b.is_cool)
|
||||
self.assertTrue(type(row.b) is User)
|
||||
|
||||
def test_register_before_connecting(self):
|
||||
User1 = namedtuple('user', ('age', 'name'))
|
||||
User2 = namedtuple('user', ('state', 'is_cool'))
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect()
|
||||
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_register_before_connecting
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_register_before_connecting")
|
||||
s.execute("CREATE TYPE user (age int, name text)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_register_before_connecting2
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_register_before_connecting2")
|
||||
s.execute("CREATE TYPE user (state text, is_cool boolean)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
# now that types are defined, shutdown and re-create Cluster
|
||||
c.shutdown()
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
c.register_user_type("udt_test_register_before_connecting", "user", User1)
|
||||
c.register_user_type("udt_test_register_before_connecting2", "user", User2)
|
||||
|
||||
s = c.connect()
|
||||
|
||||
s.set_keyspace("udt_test_register_before_connecting")
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob')))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a=0")
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual(42, row.b.age)
|
||||
self.assertEqual('bob', row.b.name)
|
||||
self.assertTrue(type(row.b) is User1)
|
||||
|
||||
# use the same UDT name in a different keyspace
|
||||
s.set_keyspace("udt_test_register_before_connecting2")
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User2('Texas', True)))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a=0")
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual('Texas', row.b.state)
|
||||
self.assertEqual(True, row.b.is_cool)
|
||||
self.assertTrue(type(row.b) is User2)
|
||||
|
||||
def test_prepared_unregistered_udts(self):
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect()
|
||||
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_prepared_unregistered
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_prepared_unregistered")
|
||||
s.execute("CREATE TYPE user (age int, name text)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
User = namedtuple('user', ('age', 'name'))
|
||||
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
|
||||
s.execute(insert, (0, User(42, 'bob')))
|
||||
|
||||
select = s.prepare("SELECT b FROM mytable WHERE a=?")
|
||||
result = s.execute(select, (0,))
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual(42, row.b.age)
|
||||
self.assertEqual('bob', row.b.name)
|
||||
|
||||
# use the same UDT name in a different keyspace
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_prepared_unregistered2
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_prepared_unregistered2")
|
||||
s.execute("CREATE TYPE user (state text, is_cool boolean)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
User = namedtuple('user', ('state', 'is_cool'))
|
||||
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
|
||||
s.execute(insert, (0, User('Texas', True)))
|
||||
|
||||
select = s.prepare("SELECT b FROM mytable WHERE a=?")
|
||||
result = s.execute(select, (0,))
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual('Texas', row.b.state)
|
||||
self.assertEqual(True, row.b.is_cool)
|
||||
|
||||
def test_prepared_registered_udts(self):
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect()
|
||||
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_prepared_registered
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_prepared_registered")
|
||||
s.execute("CREATE TYPE user (age int, name text)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
User = namedtuple('user', ('age', 'name'))
|
||||
c.register_user_type("udt_test_prepared_registered", "user", User)
|
||||
|
||||
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
|
||||
s.execute(insert, (0, User(42, 'bob')))
|
||||
|
||||
select = s.prepare("SELECT b FROM mytable WHERE a=?")
|
||||
result = s.execute(select, (0,))
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual(42, row.b.age)
|
||||
self.assertEqual('bob', row.b.name)
|
||||
self.assertTrue(type(row.b) is User)
|
||||
|
||||
# use the same UDT name in a different keyspace
|
||||
s.execute("""
|
||||
CREATE KEYSPACE udt_test_prepared_registered2
|
||||
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' }
|
||||
""")
|
||||
s.set_keyspace("udt_test_prepared_registered2")
|
||||
s.execute("CREATE TYPE user (state text, is_cool boolean)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b user)")
|
||||
|
||||
User = namedtuple('user', ('state', 'is_cool'))
|
||||
c.register_user_type("udt_test_prepared_registered2", "user", User)
|
||||
|
||||
insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)")
|
||||
s.execute(insert, (0, User('Texas', True)))
|
||||
|
||||
select = s.prepare("SELECT b FROM mytable WHERE a=?")
|
||||
result = s.execute(select, (0,))
|
||||
self.assertEqual(1, len(result))
|
||||
row = result[0]
|
||||
self.assertEqual('Texas', row.b.state)
|
||||
self.assertEqual(True, row.b.is_cool)
|
||||
self.assertTrue(type(row.b) is User)
|
Reference in New Issue
Block a user