From e2319577ec56bc36437cedca07c2f35e298155f8 Mon Sep 17 00:00:00 2001 From: GregBestland Date: Mon, 16 Nov 2015 13:13:48 -0600 Subject: [PATCH] Added tests for Python-215, 226, 212, 439. Refactoring tests to use common setup and teardown. Fixing broken testcases. --- tests/integration/__init__.py | 44 ++++- .../integration/cqlengine/query/test_named.py | 21 ++- .../standard/test_cython_protocol_handlers.py | 1 + tests/integration/standard/test_query.py | 71 ++++---- tests/integration/standard/test_types.py | 162 +++++++++++++++--- tests/integration/standard/test_udts.py | 102 ++++++----- 6 files changed, 284 insertions(+), 117 deletions(-) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 8a873813..1fd1a670 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -127,6 +127,12 @@ else: PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version)) notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported') +lessthenprotocolv4 = unittest.skipUnless(PROTOCOL_VERSION < 4, 'Protocol versions 4 or greater no supported') +greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported') + +greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= '2.1', 'Cassandra version 2.1 or greater required') +greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= '3.0', 'Cassandra version 3.0 or greater required') +lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < '3.0', 'Cassandra version less then 3.0 required') def get_cluster(): @@ -171,6 +177,7 @@ def remove_cluster(): raise RuntimeError("Failed to remove cluster after 100 attempts") + def is_current_cluster(cluster_name, node_counts): global CCM_CLUSTER if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name: @@ -395,12 +402,13 @@ class BasicKeyspaceUnitTestCase(unittest.TestCase): execute_with_long_wait_retry(cls.session, ddl) @classmethod - def common_setup(cls, rf, create_class_table=False, skip_if_cass_version_less_than=None): + def common_setup(cls, rf, keyspace_creation=True, create_class_table=False): cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) cls.session = cls.cluster.connect() cls.ks_name = cls.__name__.lower() - cls.create_keyspace(rf) - cls.cass_version = get_server_versions() + if keyspace_creation: + cls.create_keyspace(rf) + cls.cass_version, cls.cql_version = get_server_versions() if create_class_table: @@ -422,6 +430,19 @@ class BasicKeyspaceUnitTestCase(unittest.TestCase): execute_until_pass(self.session, ddl) +class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This is basic unit test defines class level teardown and setup methods. It assumes that keyspace is already defined, or created as part of the test. + """ + @classmethod + def setUpClass(cls): + cls.common_setup(1, keyspace_creation=False) + + @classmethod + def tearDownClass(cls): + cls.cluster.shutdown() + + class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): """ This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. @@ -433,7 +454,7 @@ class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): @classmethod def tearDownClass(cls): - drop_keyspace_shutdown_cluster(cls.keyspace_name, cls.session, cls.cluster) + drop_keyspace_shutdown_cluster(cls.ks_name, cls.session, cls.cluster) class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase): @@ -510,4 +531,17 @@ class BasicSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): self.common_setup(1) def tearDown(self): - drop_keyspace_shutdown_cluster(self.keyspace_name, self.session, self.cluster) + drop_keyspace_shutdown_cluster(self.ks_name, self.session, self.cluster) + + +class BasicExistingSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This unit test will create and teardown or each individual unit tests. It assumes that keyspace is existing + or created as part of a test. + It has some overhead and should only be used when sharing cluster/session is not feasible. + """ + def setUp(self): + self.common_setup(1, keyspace_creation=False) + + def tearDown(self): + self.cluster.shutdown() diff --git a/tests/integration/cqlengine/query/test_named.py b/tests/integration/cqlengine/query/test_named.py index 9f744cfd..12b02003 100644 --- a/tests/integration/cqlengine/query/test_named.py +++ b/tests/integration/cqlengine/query/test_named.py @@ -29,7 +29,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine.query.test_queryset import BaseQuerySetUsage -from tests.integration import BasicSharedKeyspaceUnitTestCase, get_server_versions +from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthanorequalcass30 class TestQuerySetOperation(BaseCassEngTestCase): @@ -270,18 +270,21 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): self.table.objects.get(test_id=1) -class TestNamedWithMV(BaseCassEngTestCase): +class TestNamedWithMV(BasicSharedKeyspaceUnitTestCase): - def setUp(self): - cass_version = get_server_versions()[0] - if cass_version < (3, 0): - raise unittest.SkipTest("Materialized views require Cassandra 3.0+") - super(TestNamedWithMV, self).setUp() + @classmethod + def setUpClass(cls): + super(TestNamedWithMV, cls).setUpClass() + cls.default_keyspace = models.DEFAULT_KEYSPACE + models.DEFAULT_KEYSPACE = cls.ks_name - def tearDown(self): - models.DEFAULT_KEYSPACE = self.default_keyspace + @classmethod + def tearDownClass(cls): + models.DEFAULT_KEYSPACE = cls.default_keyspace setup_connection(models.DEFAULT_KEYSPACE) + super(TestNamedWithMV, cls).tearDownClass() + @greaterthanorequalcass30 def test_named_table_with_mv(self): """ Test NamedTable access to materialized views diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 7532426e..dc24d0a3 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -18,6 +18,7 @@ from tests.integration.standard.utils import ( from tests.unit.cython.utils import cythontest, numpytest + def setup_module(): use_singledc() update_datatypes() diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 99add1a9..ade54bcf 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -26,48 +26,42 @@ from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement, from cassandra.cluster import Cluster from cassandra.policies import HostDistance -from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions +from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3 import re def setup_module(): + print("Setting up module") use_singledc() global CASS_SERVER_VERSION CASS_SERVER_VERSION = get_server_versions()[0] -class QueryTests(unittest.TestCase): +class QueryTests(BasicSharedKeyspaceUnitTestCase): def test_query(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - prepared = session.prepare( + prepared = self.session.prepare( """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) - """) + """.format(self.keyspace_name)) self.assertIsInstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) self.assertIsInstance(bound, BoundStatement) self.assertEqual(2, len(bound.values)) - session.execute(bound) + self.session.execute(bound) self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01') - cluster.shutdown() - def test_trace_prints_okay(self): """ Code coverage to ensure trace prints to string without error """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - query = "SELECT * FROM system.local" statement = SimpleStatement(query) - rs = session.execute(statement, trace=True) + rs = self.session.execute(statement, trace=True) # Ensure this does not throw an exception trace = rs.get_query_trace() @@ -76,13 +70,9 @@ class QueryTests(unittest.TestCase): for event in trace.events: str(event) - cluster.shutdown() - def test_trace_id_to_resultset(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - future = session.execute_async("SELECT * FROM system.local", trace=True) + future = self.session.execute_async("SELECT * FROM system.local", trace=True) # future should have the current trace rs = future.result() @@ -96,16 +86,12 @@ class QueryTests(unittest.TestCase): self.assertListEqual([rs_trace], rs.get_all_query_traces()) - cluster.shutdown() - def test_trace_ignores_row_factory(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - session.row_factory = dict_factory + self.session.row_factory = dict_factory query = "SELECT * FROM system.local" statement = SimpleStatement(query) - rs = session.execute(statement, trace=True) + rs = self.session.execute(statement, trace=True) # Ensure this does not throw an exception trace = rs.get_query_trace() @@ -114,8 +100,7 @@ class QueryTests(unittest.TestCase): for event in trace.events: str(event) - cluster.shutdown() - + @greaterthanprotocolv3 def test_client_ip_in_trace(self): """ Test to validate that client trace contains client ip information. @@ -136,18 +121,10 @@ class QueryTests(unittest.TestCase): # raise unittest.SkipTest("Client IP was not present in trace until C* 2.2") """ - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest( - "Protocol 4+ is required for client ip tracing, currently testing against %r" - % (PROTOCOL_VERSION,)) - - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - # Make simple query with trace enabled query = "SELECT * FROM system.local" statement = SimpleStatement(query) - response_future = session.execute_async(statement, trace=True) + response_future = self.session.execute_async(statement, trace=True) response_future.result() # Fetch the client_ip from the trace. @@ -161,7 +138,29 @@ class QueryTests(unittest.TestCase): self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2") self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value") - cluster.shutdown() + def test_column_names(self): + """ + Test to validate the columns are present on the result set. + Preforms a simple query against a table then checks to ensure column names are correct and present and correct. + + @since 3.0.0 + @jira_ticket PYTHON-439 + @expected_result column_names should be preset. + + @test_category queries basic + """ + create_table = """CREATE TABLE {0}.{1}( + user TEXT, + game TEXT, + year INT, + month INT, + day INT, + score INT, + PRIMARY KEY (user, game, year, month, day) + )""".format(self.keyspace_name, self.function_table_name) + self.session.execute(create_table) + result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.assertEqual(result_set.column_names, [u'user', u'game', u'year', u'month', u'day', u'score']) class PreparedStatementTests(unittest.TestCase): diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 73fbc895..5d14f4be 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -27,8 +27,10 @@ from cassandra.concurrent import execute_concurrent_with_args from cassandra.cqltypes import Int32Type, EMPTY from cassandra.query import dict_factory, ordered_dict_factory from cassandra.util import sortedset +from tests.unit.cython.utils import cythontest -from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1 +from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \ + BasicSharedKeyspaceUnitTestCase, greaterthancass20, lessthancass30 from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \ get_sample, get_collection_sample @@ -38,21 +40,14 @@ def setup_module(): update_datatypes() -class TypeTests(unittest.TestCase): +class TypeTests(BasicSharedKeyspaceUnitTestCase): @classmethod def setUpClass(cls): - cls._cass_version, cls._cql_version = get_server_versions() - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) - cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") - cls.session.set_keyspace("typetests") + # cls._cass_version, cls. = get_server_versions() + super(TypeTests, cls).setUpClass() + cls.session.set_keyspace(cls.ks_name) - @classmethod - def tearDownClass(cls): - execute_until_pass(cls.session, "DROP KEYSPACE typetests") - cls.cluster.shutdown() - def test_can_insert_blob_type_as_string(self): """ Tests that byte strings in Python maps to blob type in Cassandra @@ -66,10 +61,10 @@ class TypeTests(unittest.TestCase): # In python2, with Cassandra > 2.0, we don't treat the 'byte str' type as a blob, so we'll encode it # as a string literal and have the following failure. - if six.PY2 and self._cql_version >= (3, 1, 0): + if six.PY2 and self.cql_version >= (3, 1, 0): # Blob values can't be specified using string notation in CQL 3.1.0 and # above which is used by default in Cassandra 2.0. - if self._cass_version >= (2, 1, 0): + if self.cass_version >= (2, 1, 0): msg = r'.*Invalid STRING constant \(.*?\) for "b" of type blob.*' else: msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*' @@ -108,7 +103,7 @@ class TypeTests(unittest.TestCase): Test insertion of all datatype primitives """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("typetests") + s = c.connect(self.keyspace_name) # create table alpha_type_list = ["zz int PRIMARY KEY"] @@ -167,7 +162,7 @@ class TypeTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("typetests") + s = c.connect(self.keyspace_name) # use tuple encoding, to convert native python tuple into raw CQL s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple @@ -394,11 +389,11 @@ class TypeTests(unittest.TestCase): Basic test of tuple functionality """ - if self._cass_version < (2, 1, 0): + if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("typetests") + s = c.connect(self.keyspace_name) # use this encoder in order to insert tuples s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple @@ -446,11 +441,11 @@ class TypeTests(unittest.TestCase): as expected. """ - if self._cass_version < (2, 1, 0): + if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("typetests") + s = c.connect(self.keyspace_name) # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples @@ -485,11 +480,11 @@ class TypeTests(unittest.TestCase): Ensure tuple subtypes are appropriately handled. """ - if self._cass_version < (2, 1, 0): + if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("typetests") + s = c.connect(self.keyspace_name) s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple s.execute("CREATE TABLE tuple_primitive (" @@ -513,11 +508,11 @@ class TypeTests(unittest.TestCase): Ensure tuple subtypes are appropriately handled for maps, sets, and lists. """ - if self._cass_version < (2, 1, 0): + if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("typetests") + s = c.connect(self.keyspace_name) # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples @@ -612,11 +607,11 @@ class TypeTests(unittest.TestCase): Ensure nested are appropriately handled. """ - if self._cass_version < (2, 1, 0): + if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("typetests") + s = c.connect(self.keyspace_name) # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples @@ -652,7 +647,7 @@ class TypeTests(unittest.TestCase): Test tuples with null and empty string fields. """ - if self._cass_version < (2, 1, 0): + if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") s = self.session @@ -746,3 +741,116 @@ class TypeTests(unittest.TestCase): # prepared binding verify_insert_select(s.prepare('INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)'), s.prepare('SELECT * FROM float_cql_encoding WHERE f=?')) + + @cythontest + def test_cython_decimal(self): + """ + Test to validate that decimal deserialization works correctly in with our cython extensions + + @since 3.0.0 + @jira_ticket PYTHON-212 + @expected_result no exceptions are thrown, decimal is decoded correctly + + @test_category data_types serialization + """ + + self.session.execute("CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name)) + try: + self.session.execute("INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format(self.function_table_name)) + results = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) + self.assertTrue(str(results[0].dc) == '-1.08430792318105707') + finally: + self.session.execute("DROP TABLE {0}".format(self.function_table_name)) + + +class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): + + @greaterthancass20 + @lessthancass30 + def test_nested_types_with_protocol_version(self): + """ + Test to validate that nested type serialization works on various protocol versions. Provided + the version of cassandra is greater the 2.1.3 we would expect to nested to types to work at all protocol versions. + + @since 3.0.0 + @jira_ticket PYTHON-215 + @expected_result no exceptions are thrown + + @test_category data_types serialization + """ + ddl = '''CREATE TABLE {0}.t ( + k int PRIMARY KEY, + v list>>)'''.format(self.keyspace_name) + + self.session.execute(ddl) + ddl = '''CREATE TABLE {0}.u ( + k int PRIMARY KEY, + v set>>)'''.format(self.keyspace_name) + self.session.execute(ddl) + ddl = '''CREATE TABLE {0}.v ( + k int PRIMARY KEY, + v map>, frozen>>, + v1 frozen>)'''.format(self.keyspace_name) + self.session.execute(ddl) + + self.session.execute("CREATE TYPE {0}.typ (v0 frozen>>>, v1 frozen>)".format(self.keyspace_name)) + + ddl = '''CREATE TABLE {0}.w ( + k int PRIMARY KEY, + v frozen)'''.format(self.keyspace_name) + + self.session.execute(ddl) + + for pvi in range(1, 5): + self.run_inserts_at_version(pvi) + for pvr in range(1, 5): + self.read_inserts_at_level(pvr) + + def print_results(self, results): + print("printing results") + print(str(results.v)) + + def read_inserts_at_level(self, proto_ver): + session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name) + try: + print("reading at version {0}".format(proto_ver)) + results = session.execute('select * from t')[0] + self.print_results(results) + self.assertEqual("[SortedSet([1, 2]), SortedSet([3, 5])]", str(results.v)) + + results = session.execute('select * from u')[0] + self.print_results(results) + self.assertEqual("SortedSet([[1, 2], [3, 5]])", str(results.v)) + + results = session.execute('select * from v')[0] + self.print_results(results) + self.assertEqual("{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}", str(results.v)) + + results = session.execute('select * from w')[0] + self.print_results(results) + self.assertEqual("typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])", str(results.v)) + + finally: + session.cluster.shutdown() + + def run_inserts_at_version(self, proto_ver): + session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name) + try: + print("running inserts with protocol version {0}".format(str(proto_ver))) + p = session.prepare('insert into t (k, v) values (?, ?)') + session.execute(p, (0, [{1, 2}, {3, 5}])) + + p = session.prepare('insert into u (k, v) values (?, ?)') + session.execute(p, (0, {(1, 2), (3, 5)})) + + p = session.prepare('insert into v (k, v, v1) values (?, ?, ?)') + session.execute(p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, 'four'))) + + p = session.prepare('insert into w (k, v) values (?, ?)') + session.execute(p, (0, ({1: [1, 2, 3], 2: [4, 5, 6]}, [7, 8, 9]))) + + finally: + session.cluster.shutdown() + + + diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index 6cebf97b..5c6b0d66 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -26,7 +26,7 @@ from cassandra.cluster import Cluster, UserTypeDoesNotExist from cassandra.query import dict_factory from cassandra.util import OrderedMap -from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass +from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass, BasicSegregatedKeyspaceUnitTestCase, greaterthancass20 from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \ get_sample, get_collection_sample @@ -39,27 +39,16 @@ def setup_module(): update_datatypes() -class UDTTests(unittest.TestCase): +@greaterthancass20 +class UDTTests(BasicSegregatedKeyspaceUnitTestCase): @property def table_name(self): return self._testMethodName.lower() def setUp(self): - self._cass_version, self._cql_version = get_server_versions() - - if self._cass_version < (2, 1, 0): - raise unittest.SkipTest("User Defined Types were introduced in Cassandra 2.1") - - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) - self.session = self.cluster.connect() - execute_until_pass(self.session, - "CREATE KEYSPACE udttests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") - self.session.set_keyspace("udttests") - - def tearDown(self): - execute_until_pass(self.session, "DROP KEYSPACE udttests") - self.cluster.shutdown() + super(UDTTests, self).setUp() + self.session.set_keyspace(self.keyspace_name) def test_can_insert_unprepared_registered_udts(self): """ @@ -67,13 +56,13 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.execute("CREATE TYPE user (age int, name text)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") User = namedtuple('user', ('age', 'name')) - c.register_user_type("udttests", "user", User) + c.register_user_type(self.keyspace_name, "user", User) s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob'))) result = s.execute("SELECT b FROM mytable WHERE a=0") @@ -169,7 +158,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.execute("CREATE TYPE user (age int, name text)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") @@ -213,11 +202,11 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.execute("CREATE TYPE user (age int, name text)") User = namedtuple('user', ('age', 'name')) - c.register_user_type("udttests", "user", User) + c.register_user_type(self.keyspace_name, "user", User) s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") @@ -263,11 +252,11 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)") User = namedtuple('user', ('a', 'b', 'c', 'd')) - c.register_user_type("udttests", "user", User) + c.register_user_type(self.keyspace_name, "user", User) s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") @@ -293,7 +282,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) MAX_TEST_LENGTH = 254 @@ -310,7 +299,7 @@ class UDTTests(unittest.TestCase): # create and register the seed udt type udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in range(MAX_TEST_LENGTH)])) - c.register_user_type("udttests", "lengthy_udt", udt) + c.register_user_type(self.keyspace_name, "lengthy_udt", udt) # verify inserts and reads for i in (0, 1, 2, 3, MAX_TEST_LENGTH): @@ -377,7 +366,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.row_factory = dict_factory MAX_NESTING_DEPTH = 16 @@ -389,13 +378,13 @@ class UDTTests(unittest.TestCase): udts = [] udt = namedtuple('depth_0', ('age', 'name')) udts.append(udt) - c.register_user_type("udttests", "depth_0", udts[0]) + c.register_user_type(self.keyspace_name, "depth_0", udts[0]) # create and register the nested udt types for i in range(MAX_NESTING_DEPTH): udt = namedtuple('depth_{0}'.format(i + 1), ('value')) udts.append(udt) - c.register_user_type("udttests", "depth_{0}".format(i + 1), udts[i + 1]) + c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) # insert udts and verify inserts with reads self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts) @@ -408,7 +397,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.row_factory = dict_factory MAX_NESTING_DEPTH = 16 @@ -448,7 +437,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.row_factory = dict_factory MAX_NESTING_DEPTH = 16 @@ -460,13 +449,13 @@ class UDTTests(unittest.TestCase): udts = [] udt = namedtuple('level_0', ('age', 'name')) udts.append(udt) - c.register_user_type("udttests", "depth_0", udts[0]) + c.register_user_type(self.keyspace_name, "depth_0", udts[0]) # create and register the nested udt types for i in range(MAX_NESTING_DEPTH): udt = namedtuple('level_{0}'.format(i + 1), ('value')) udts.append(udt) - c.register_user_type("udttests", "depth_{0}".format(i + 1), udts[i + 1]) + c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) # insert udts and verify inserts with reads self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts) @@ -479,7 +468,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) User = namedtuple('user', ('age', 'name')) with self.assertRaises(UserTypeDoesNotExist): @@ -499,7 +488,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) # create UDT alpha_type_list = [] @@ -510,7 +499,7 @@ class UDTTests(unittest.TestCase): s.execute(""" CREATE TYPE alldatatypes ({0}) """.format(', '.join(alpha_type_list)) - ) + ) s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") @@ -519,7 +508,7 @@ class UDTTests(unittest.TestCase): for i in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES)): alphabet_list.append('{0}'.format(chr(i))) Alldatatypes = namedtuple("alldatatypes", alphabet_list) - c.register_user_type("udttests", "alldatatypes", Alldatatypes) + c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes) # insert UDT data params = [] @@ -544,7 +533,7 @@ class UDTTests(unittest.TestCase): """ c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) # create UDT alpha_type_list = [] @@ -576,7 +565,7 @@ class UDTTests(unittest.TestCase): alphabet_list.append('{0}_{1}'.format(chr(i), chr(j))) Alldatatypes = namedtuple("alldatatypes", alphabet_list) - c.register_user_type("udttests", "alldatatypes", Alldatatypes) + c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes) # insert UDT data params = [] @@ -607,11 +596,11 @@ class UDTTests(unittest.TestCase): Test for inserting various types of nested COLLECTION_TYPES into tables and UDTs """ - if self._cass_version < (2, 1, 3): + if self.cass_version < (2, 1, 3): raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3") c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect("udttests") + s = c.connect(self.keyspace_name) s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple name = self._testMethodName @@ -710,3 +699,36 @@ class UDTTests(unittest.TestCase): val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] self.assertEqual(val['v0'], 3) self.assertEqual(val['v1'], six.b('\xde\xad\xbe\xef')) + + def test_alter_udt(self): + """ + Test to ensure that altered UDT's are properly surfaced without needing to restart the underlying session. + + @since 3.0.0 + @jira_ticket PYTHON-226 + @expected_result UDT's will reflect added columns without a session restart. + + @test_category data_types, udt + """ + + # Create udt ensure it has the proper column names. + self.session.set_keyspace(self.keyspace_name) + self.session.execute("CREATE TYPE typetoalter (a int)") + typetoalter = namedtuple('typetoalter', ('a')) + self.session.execute("CREATE TABLE {0} (pk int primary key, typetoalter frozen)".format(self.function_table_name)) + insert_statement = self.session.prepare("INSERT INTO {0} (pk, typetoalter) VALUES (?, ?)".format(self.function_table_name)) + self.session.execute(insert_statement, [1, typetoalter(1)]) + results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) + for result in results: + self.assertTrue(hasattr(result.typetoalter, 'a')) + self.assertFalse(hasattr(result.typetoalter, 'b')) + + # Alter UDT and ensure the alter is honored in results + self.session.execute("ALTER TYPE typetoalter add b int") + typetoalter = namedtuple('typetoalter', ('a', 'b')) + self.session.execute(insert_statement, [2, typetoalter(2, 2)]) + results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) + for result in results: + self.assertTrue(hasattr(result.typetoalter, 'a')) + self.assertTrue(hasattr(result.typetoalter, 'b')) +