@@ -127,6 +127,12 @@ else:
|
||||
PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version))
|
||||
|
||||
notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported')
|
||||
lessthenprotocolv4 = unittest.skipUnless(PROTOCOL_VERSION < 4, 'Protocol versions 4 or greater no supported')
|
||||
greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported')
|
||||
|
||||
greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= '2.1', 'Cassandra version 2.1 or greater required')
|
||||
greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= '3.0', 'Cassandra version 3.0 or greater required')
|
||||
lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < '3.0', 'Cassandra version less then 3.0 required')
|
||||
|
||||
|
||||
def get_cluster():
|
||||
@@ -171,6 +177,7 @@ def remove_cluster():
|
||||
|
||||
raise RuntimeError("Failed to remove cluster after 100 attempts")
|
||||
|
||||
|
||||
def is_current_cluster(cluster_name, node_counts):
|
||||
global CCM_CLUSTER
|
||||
if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name:
|
||||
@@ -395,12 +402,13 @@ class BasicKeyspaceUnitTestCase(unittest.TestCase):
|
||||
execute_with_long_wait_retry(cls.session, ddl)
|
||||
|
||||
@classmethod
|
||||
def common_setup(cls, rf, create_class_table=False, skip_if_cass_version_less_than=None):
|
||||
def common_setup(cls, rf, keyspace_creation=True, create_class_table=False):
|
||||
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
cls.session = cls.cluster.connect()
|
||||
cls.ks_name = cls.__name__.lower()
|
||||
cls.create_keyspace(rf)
|
||||
cls.cass_version = get_server_versions()
|
||||
if keyspace_creation:
|
||||
cls.create_keyspace(rf)
|
||||
cls.cass_version, cls.cql_version = get_server_versions()
|
||||
|
||||
if create_class_table:
|
||||
|
||||
@@ -422,6 +430,19 @@ class BasicKeyspaceUnitTestCase(unittest.TestCase):
|
||||
execute_until_pass(self.session, ddl)
|
||||
|
||||
|
||||
class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||
"""
|
||||
This is basic unit test defines class level teardown and setup methods. It assumes that keyspace is already defined, or created as part of the test.
|
||||
"""
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.common_setup(1, keyspace_creation=False)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.cluster.shutdown()
|
||||
|
||||
|
||||
class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||
"""
|
||||
This is basic unit test case that can be leveraged to scope a keyspace to a specific test class.
|
||||
@@ -433,7 +454,7 @@ class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
drop_keyspace_shutdown_cluster(cls.keyspace_name, cls.session, cls.cluster)
|
||||
drop_keyspace_shutdown_cluster(cls.ks_name, cls.session, cls.cluster)
|
||||
|
||||
|
||||
class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase):
|
||||
@@ -510,4 +531,17 @@ class BasicSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||
self.common_setup(1)
|
||||
|
||||
def tearDown(self):
|
||||
drop_keyspace_shutdown_cluster(self.keyspace_name, self.session, self.cluster)
|
||||
drop_keyspace_shutdown_cluster(self.ks_name, self.session, self.cluster)
|
||||
|
||||
|
||||
class BasicExistingSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||
"""
|
||||
This unit test will create and teardown or each individual unit tests. It assumes that keyspace is existing
|
||||
or created as part of a test.
|
||||
It has some overhead and should only be used when sharing cluster/session is not feasible.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.common_setup(1, keyspace_creation=False)
|
||||
|
||||
def tearDown(self):
|
||||
self.cluster.shutdown()
|
||||
|
||||
@@ -29,7 +29,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase
|
||||
from tests.integration.cqlengine.query.test_queryset import BaseQuerySetUsage
|
||||
|
||||
|
||||
from tests.integration import BasicSharedKeyspaceUnitTestCase, get_server_versions
|
||||
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthanorequalcass30
|
||||
|
||||
|
||||
class TestQuerySetOperation(BaseCassEngTestCase):
|
||||
@@ -270,18 +270,21 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
|
||||
self.table.objects.get(test_id=1)
|
||||
|
||||
|
||||
class TestNamedWithMV(BaseCassEngTestCase):
|
||||
class TestNamedWithMV(BasicSharedKeyspaceUnitTestCase):
|
||||
|
||||
def setUp(self):
|
||||
cass_version = get_server_versions()[0]
|
||||
if cass_version < (3, 0):
|
||||
raise unittest.SkipTest("Materialized views require Cassandra 3.0+")
|
||||
super(TestNamedWithMV, self).setUp()
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestNamedWithMV, cls).setUpClass()
|
||||
cls.default_keyspace = models.DEFAULT_KEYSPACE
|
||||
models.DEFAULT_KEYSPACE = cls.ks_name
|
||||
|
||||
def tearDown(self):
|
||||
models.DEFAULT_KEYSPACE = self.default_keyspace
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
models.DEFAULT_KEYSPACE = cls.default_keyspace
|
||||
setup_connection(models.DEFAULT_KEYSPACE)
|
||||
super(TestNamedWithMV, cls).tearDownClass()
|
||||
|
||||
@greaterthanorequalcass30
|
||||
def test_named_table_with_mv(self):
|
||||
"""
|
||||
Test NamedTable access to materialized views
|
||||
|
||||
@@ -18,6 +18,7 @@ from tests.integration.standard.utils import (
|
||||
|
||||
from tests.unit.cython.utils import cythontest, numpytest
|
||||
|
||||
|
||||
def setup_module():
|
||||
use_singledc()
|
||||
update_datatypes()
|
||||
|
||||
@@ -26,48 +26,42 @@ from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.policies import HostDistance
|
||||
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def setup_module():
|
||||
print("Setting up module")
|
||||
use_singledc()
|
||||
global CASS_SERVER_VERSION
|
||||
CASS_SERVER_VERSION = get_server_versions()[0]
|
||||
|
||||
|
||||
class QueryTests(unittest.TestCase):
|
||||
class QueryTests(BasicSharedKeyspaceUnitTestCase):
|
||||
|
||||
def test_query(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
prepared = session.prepare(
|
||||
prepared = self.session.prepare(
|
||||
"""
|
||||
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
|
||||
""")
|
||||
""".format(self.keyspace_name))
|
||||
|
||||
self.assertIsInstance(prepared, PreparedStatement)
|
||||
bound = prepared.bind((1, None))
|
||||
self.assertIsInstance(bound, BoundStatement)
|
||||
self.assertEqual(2, len(bound.values))
|
||||
session.execute(bound)
|
||||
self.session.execute(bound)
|
||||
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
def test_trace_prints_okay(self):
|
||||
"""
|
||||
Code coverage to ensure trace prints to string without error
|
||||
"""
|
||||
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
query = "SELECT * FROM system.local"
|
||||
statement = SimpleStatement(query)
|
||||
rs = session.execute(statement, trace=True)
|
||||
rs = self.session.execute(statement, trace=True)
|
||||
|
||||
# Ensure this does not throw an exception
|
||||
trace = rs.get_query_trace()
|
||||
@@ -76,13 +70,9 @@ class QueryTests(unittest.TestCase):
|
||||
for event in trace.events:
|
||||
str(event)
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
def test_trace_id_to_resultset(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
future = session.execute_async("SELECT * FROM system.local", trace=True)
|
||||
future = self.session.execute_async("SELECT * FROM system.local", trace=True)
|
||||
|
||||
# future should have the current trace
|
||||
rs = future.result()
|
||||
@@ -96,16 +86,12 @@ class QueryTests(unittest.TestCase):
|
||||
|
||||
self.assertListEqual([rs_trace], rs.get_all_query_traces())
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
def test_trace_ignores_row_factory(self):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
session.row_factory = dict_factory
|
||||
self.session.row_factory = dict_factory
|
||||
|
||||
query = "SELECT * FROM system.local"
|
||||
statement = SimpleStatement(query)
|
||||
rs = session.execute(statement, trace=True)
|
||||
rs = self.session.execute(statement, trace=True)
|
||||
|
||||
# Ensure this does not throw an exception
|
||||
trace = rs.get_query_trace()
|
||||
@@ -114,8 +100,7 @@ class QueryTests(unittest.TestCase):
|
||||
for event in trace.events:
|
||||
str(event)
|
||||
|
||||
cluster.shutdown()
|
||||
|
||||
@greaterthanprotocolv3
|
||||
def test_client_ip_in_trace(self):
|
||||
"""
|
||||
Test to validate that client trace contains client ip information.
|
||||
@@ -136,18 +121,10 @@ class QueryTests(unittest.TestCase):
|
||||
# raise unittest.SkipTest("Client IP was not present in trace until C* 2.2")
|
||||
"""
|
||||
|
||||
if PROTOCOL_VERSION < 4:
|
||||
raise unittest.SkipTest(
|
||||
"Protocol 4+ is required for client ip tracing, currently testing against %r"
|
||||
% (PROTOCOL_VERSION,))
|
||||
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
# Make simple query with trace enabled
|
||||
query = "SELECT * FROM system.local"
|
||||
statement = SimpleStatement(query)
|
||||
response_future = session.execute_async(statement, trace=True)
|
||||
response_future = self.session.execute_async(statement, trace=True)
|
||||
response_future.result()
|
||||
|
||||
# Fetch the client_ip from the trace.
|
||||
@@ -161,7 +138,29 @@ class QueryTests(unittest.TestCase):
|
||||
self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2")
|
||||
self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value")
|
||||
|
||||
cluster.shutdown()
|
||||
def test_column_names(self):
|
||||
"""
|
||||
Test to validate the columns are present on the result set.
|
||||
Preforms a simple query against a table then checks to ensure column names are correct and present and correct.
|
||||
|
||||
@since 3.0.0
|
||||
@jira_ticket PYTHON-439
|
||||
@expected_result column_names should be preset.
|
||||
|
||||
@test_category queries basic
|
||||
"""
|
||||
create_table = """CREATE TABLE {0}.{1}(
|
||||
user TEXT,
|
||||
game TEXT,
|
||||
year INT,
|
||||
month INT,
|
||||
day INT,
|
||||
score INT,
|
||||
PRIMARY KEY (user, game, year, month, day)
|
||||
)""".format(self.keyspace_name, self.function_table_name)
|
||||
self.session.execute(create_table)
|
||||
result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name))
|
||||
self.assertEqual(result_set.column_names, [u'user', u'game', u'year', u'month', u'day', u'score'])
|
||||
|
||||
|
||||
class PreparedStatementTests(unittest.TestCase):
|
||||
|
||||
@@ -27,8 +27,10 @@ from cassandra.concurrent import execute_concurrent_with_args
|
||||
from cassandra.cqltypes import Int32Type, EMPTY
|
||||
from cassandra.query import dict_factory, ordered_dict_factory
|
||||
from cassandra.util import sortedset
|
||||
from tests.unit.cython.utils import cythontest
|
||||
|
||||
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \
|
||||
BasicSharedKeyspaceUnitTestCase, greaterthancass20, lessthancass30
|
||||
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
|
||||
get_sample, get_collection_sample
|
||||
|
||||
@@ -38,21 +40,14 @@ def setup_module():
|
||||
update_datatypes()
|
||||
|
||||
|
||||
class TypeTests(unittest.TestCase):
|
||||
class TypeTests(BasicSharedKeyspaceUnitTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._cass_version, cls._cql_version = get_server_versions()
|
||||
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
cls.session = cls.cluster.connect()
|
||||
cls.session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
|
||||
cls.session.set_keyspace("typetests")
|
||||
# cls._cass_version, cls. = get_server_versions()
|
||||
super(TypeTests, cls).setUpClass()
|
||||
cls.session.set_keyspace(cls.ks_name)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
execute_until_pass(cls.session, "DROP KEYSPACE typetests")
|
||||
cls.cluster.shutdown()
|
||||
|
||||
def test_can_insert_blob_type_as_string(self):
|
||||
"""
|
||||
Tests that byte strings in Python maps to blob type in Cassandra
|
||||
@@ -66,10 +61,10 @@ class TypeTests(unittest.TestCase):
|
||||
|
||||
# In python2, with Cassandra > 2.0, we don't treat the 'byte str' type as a blob, so we'll encode it
|
||||
# as a string literal and have the following failure.
|
||||
if six.PY2 and self._cql_version >= (3, 1, 0):
|
||||
if six.PY2 and self.cql_version >= (3, 1, 0):
|
||||
# Blob values can't be specified using string notation in CQL 3.1.0 and
|
||||
# above which is used by default in Cassandra 2.0.
|
||||
if self._cass_version >= (2, 1, 0):
|
||||
if self.cass_version >= (2, 1, 0):
|
||||
msg = r'.*Invalid STRING constant \(.*?\) for "b" of type blob.*'
|
||||
else:
|
||||
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
|
||||
@@ -108,7 +103,7 @@ class TypeTests(unittest.TestCase):
|
||||
Test insertion of all datatype primitives
|
||||
"""
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("typetests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
# create table
|
||||
alpha_type_list = ["zz int PRIMARY KEY"]
|
||||
@@ -167,7 +162,7 @@ class TypeTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("typetests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
# use tuple encoding, to convert native python tuple into raw CQL
|
||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||
|
||||
@@ -394,11 +389,11 @@ class TypeTests(unittest.TestCase):
|
||||
Basic test of tuple functionality
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
if self.cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("typetests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
# use this encoder in order to insert tuples
|
||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||
@@ -446,11 +441,11 @@ class TypeTests(unittest.TestCase):
|
||||
as expected.
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
if self.cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("typetests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
# set the row_factory to dict_factory for programmatic access
|
||||
# set the encoder for tuples for the ability to write tuples
|
||||
@@ -485,11 +480,11 @@ class TypeTests(unittest.TestCase):
|
||||
Ensure tuple subtypes are appropriately handled.
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
if self.cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("typetests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||
|
||||
s.execute("CREATE TABLE tuple_primitive ("
|
||||
@@ -513,11 +508,11 @@ class TypeTests(unittest.TestCase):
|
||||
Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
if self.cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("typetests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
# set the row_factory to dict_factory for programmatic access
|
||||
# set the encoder for tuples for the ability to write tuples
|
||||
@@ -612,11 +607,11 @@ class TypeTests(unittest.TestCase):
|
||||
Ensure nested are appropriately handled.
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
if self.cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("typetests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
# set the row_factory to dict_factory for programmatic access
|
||||
# set the encoder for tuples for the ability to write tuples
|
||||
@@ -652,7 +647,7 @@ class TypeTests(unittest.TestCase):
|
||||
Test tuples with null and empty string fields.
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
if self.cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||
|
||||
s = self.session
|
||||
@@ -746,3 +741,116 @@ class TypeTests(unittest.TestCase):
|
||||
# prepared binding
|
||||
verify_insert_select(s.prepare('INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)'),
|
||||
s.prepare('SELECT * FROM float_cql_encoding WHERE f=?'))
|
||||
|
||||
@cythontest
|
||||
def test_cython_decimal(self):
|
||||
"""
|
||||
Test to validate that decimal deserialization works correctly in with our cython extensions
|
||||
|
||||
@since 3.0.0
|
||||
@jira_ticket PYTHON-212
|
||||
@expected_result no exceptions are thrown, decimal is decoded correctly
|
||||
|
||||
@test_category data_types serialization
|
||||
"""
|
||||
|
||||
self.session.execute("CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name))
|
||||
try:
|
||||
self.session.execute("INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format(self.function_table_name))
|
||||
results = self.session.execute("SELECT * FROM {0}".format(self.function_table_name))
|
||||
self.assertTrue(str(results[0].dc) == '-1.08430792318105707')
|
||||
finally:
|
||||
self.session.execute("DROP TABLE {0}".format(self.function_table_name))
|
||||
|
||||
|
||||
class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase):
|
||||
|
||||
@greaterthancass20
|
||||
@lessthancass30
|
||||
def test_nested_types_with_protocol_version(self):
|
||||
"""
|
||||
Test to validate that nested type serialization works on various protocol versions. Provided
|
||||
the version of cassandra is greater the 2.1.3 we would expect to nested to types to work at all protocol versions.
|
||||
|
||||
@since 3.0.0
|
||||
@jira_ticket PYTHON-215
|
||||
@expected_result no exceptions are thrown
|
||||
|
||||
@test_category data_types serialization
|
||||
"""
|
||||
ddl = '''CREATE TABLE {0}.t (
|
||||
k int PRIMARY KEY,
|
||||
v list<frozen<set<int>>>)'''.format(self.keyspace_name)
|
||||
|
||||
self.session.execute(ddl)
|
||||
ddl = '''CREATE TABLE {0}.u (
|
||||
k int PRIMARY KEY,
|
||||
v set<frozen<list<int>>>)'''.format(self.keyspace_name)
|
||||
self.session.execute(ddl)
|
||||
ddl = '''CREATE TABLE {0}.v (
|
||||
k int PRIMARY KEY,
|
||||
v map<frozen<set<int>>, frozen<list<int>>>,
|
||||
v1 frozen<tuple<int, text>>)'''.format(self.keyspace_name)
|
||||
self.session.execute(ddl)
|
||||
|
||||
self.session.execute("CREATE TYPE {0}.typ (v0 frozen<map<int, frozen<list<int>>>>, v1 frozen<list<int>>)".format(self.keyspace_name))
|
||||
|
||||
ddl = '''CREATE TABLE {0}.w (
|
||||
k int PRIMARY KEY,
|
||||
v frozen<typ>)'''.format(self.keyspace_name)
|
||||
|
||||
self.session.execute(ddl)
|
||||
|
||||
for pvi in range(1, 5):
|
||||
self.run_inserts_at_version(pvi)
|
||||
for pvr in range(1, 5):
|
||||
self.read_inserts_at_level(pvr)
|
||||
|
||||
def print_results(self, results):
|
||||
print("printing results")
|
||||
print(str(results.v))
|
||||
|
||||
def read_inserts_at_level(self, proto_ver):
|
||||
session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name)
|
||||
try:
|
||||
print("reading at version {0}".format(proto_ver))
|
||||
results = session.execute('select * from t')[0]
|
||||
self.print_results(results)
|
||||
self.assertEqual("[SortedSet([1, 2]), SortedSet([3, 5])]", str(results.v))
|
||||
|
||||
results = session.execute('select * from u')[0]
|
||||
self.print_results(results)
|
||||
self.assertEqual("SortedSet([[1, 2], [3, 5]])", str(results.v))
|
||||
|
||||
results = session.execute('select * from v')[0]
|
||||
self.print_results(results)
|
||||
self.assertEqual("{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}", str(results.v))
|
||||
|
||||
results = session.execute('select * from w')[0]
|
||||
self.print_results(results)
|
||||
self.assertEqual("typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])", str(results.v))
|
||||
|
||||
finally:
|
||||
session.cluster.shutdown()
|
||||
|
||||
def run_inserts_at_version(self, proto_ver):
|
||||
session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name)
|
||||
try:
|
||||
print("running inserts with protocol version {0}".format(str(proto_ver)))
|
||||
p = session.prepare('insert into t (k, v) values (?, ?)')
|
||||
session.execute(p, (0, [{1, 2}, {3, 5}]))
|
||||
|
||||
p = session.prepare('insert into u (k, v) values (?, ?)')
|
||||
session.execute(p, (0, {(1, 2), (3, 5)}))
|
||||
|
||||
p = session.prepare('insert into v (k, v, v1) values (?, ?, ?)')
|
||||
session.execute(p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, 'four')))
|
||||
|
||||
p = session.prepare('insert into w (k, v) values (?, ?)')
|
||||
session.execute(p, (0, ({1: [1, 2, 3], 2: [4, 5, 6]}, [7, 8, 9])))
|
||||
|
||||
finally:
|
||||
session.cluster.shutdown()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from cassandra.cluster import Cluster, UserTypeDoesNotExist
|
||||
from cassandra.query import dict_factory
|
||||
from cassandra.util import OrderedMap
|
||||
|
||||
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass
|
||||
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass, BasicSegregatedKeyspaceUnitTestCase, greaterthancass20
|
||||
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
|
||||
get_sample, get_collection_sample
|
||||
|
||||
@@ -39,27 +39,16 @@ def setup_module():
|
||||
update_datatypes()
|
||||
|
||||
|
||||
class UDTTests(unittest.TestCase):
|
||||
@greaterthancass20
|
||||
class UDTTests(BasicSegregatedKeyspaceUnitTestCase):
|
||||
|
||||
@property
|
||||
def table_name(self):
|
||||
return self._testMethodName.lower()
|
||||
|
||||
def setUp(self):
|
||||
self._cass_version, self._cql_version = get_server_versions()
|
||||
|
||||
if self._cass_version < (2, 1, 0):
|
||||
raise unittest.SkipTest("User Defined Types were introduced in Cassandra 2.1")
|
||||
|
||||
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
self.session = self.cluster.connect()
|
||||
execute_until_pass(self.session,
|
||||
"CREATE KEYSPACE udttests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
|
||||
self.session.set_keyspace("udttests")
|
||||
|
||||
def tearDown(self):
|
||||
execute_until_pass(self.session, "DROP KEYSPACE udttests")
|
||||
self.cluster.shutdown()
|
||||
super(UDTTests, self).setUp()
|
||||
self.session.set_keyspace(self.keyspace_name)
|
||||
|
||||
def test_can_insert_unprepared_registered_udts(self):
|
||||
"""
|
||||
@@ -67,13 +56,13 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
s.execute("CREATE TYPE user (age int, name text)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||
|
||||
User = namedtuple('user', ('age', 'name'))
|
||||
c.register_user_type("udttests", "user", User)
|
||||
c.register_user_type(self.keyspace_name, "user", User)
|
||||
|
||||
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob')))
|
||||
result = s.execute("SELECT b FROM mytable WHERE a=0")
|
||||
@@ -169,7 +158,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
s.execute("CREATE TYPE user (age int, name text)")
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||
@@ -213,11 +202,11 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
s.execute("CREATE TYPE user (age int, name text)")
|
||||
User = namedtuple('user', ('age', 'name'))
|
||||
c.register_user_type("udttests", "user", User)
|
||||
c.register_user_type(self.keyspace_name, "user", User)
|
||||
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||
|
||||
@@ -263,11 +252,11 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)")
|
||||
User = namedtuple('user', ('a', 'b', 'c', 'd'))
|
||||
c.register_user_type("udttests", "user", User)
|
||||
c.register_user_type(self.keyspace_name, "user", User)
|
||||
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||
|
||||
@@ -293,7 +282,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
MAX_TEST_LENGTH = 254
|
||||
|
||||
@@ -310,7 +299,7 @@ class UDTTests(unittest.TestCase):
|
||||
|
||||
# create and register the seed udt type
|
||||
udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in range(MAX_TEST_LENGTH)]))
|
||||
c.register_user_type("udttests", "lengthy_udt", udt)
|
||||
c.register_user_type(self.keyspace_name, "lengthy_udt", udt)
|
||||
|
||||
# verify inserts and reads
|
||||
for i in (0, 1, 2, 3, MAX_TEST_LENGTH):
|
||||
@@ -377,7 +366,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
s.row_factory = dict_factory
|
||||
|
||||
MAX_NESTING_DEPTH = 16
|
||||
@@ -389,13 +378,13 @@ class UDTTests(unittest.TestCase):
|
||||
udts = []
|
||||
udt = namedtuple('depth_0', ('age', 'name'))
|
||||
udts.append(udt)
|
||||
c.register_user_type("udttests", "depth_0", udts[0])
|
||||
c.register_user_type(self.keyspace_name, "depth_0", udts[0])
|
||||
|
||||
# create and register the nested udt types
|
||||
for i in range(MAX_NESTING_DEPTH):
|
||||
udt = namedtuple('depth_{0}'.format(i + 1), ('value'))
|
||||
udts.append(udt)
|
||||
c.register_user_type("udttests", "depth_{0}".format(i + 1), udts[i + 1])
|
||||
c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1])
|
||||
|
||||
# insert udts and verify inserts with reads
|
||||
self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts)
|
||||
@@ -408,7 +397,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
s.row_factory = dict_factory
|
||||
|
||||
MAX_NESTING_DEPTH = 16
|
||||
@@ -448,7 +437,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
s.row_factory = dict_factory
|
||||
|
||||
MAX_NESTING_DEPTH = 16
|
||||
@@ -460,13 +449,13 @@ class UDTTests(unittest.TestCase):
|
||||
udts = []
|
||||
udt = namedtuple('level_0', ('age', 'name'))
|
||||
udts.append(udt)
|
||||
c.register_user_type("udttests", "depth_0", udts[0])
|
||||
c.register_user_type(self.keyspace_name, "depth_0", udts[0])
|
||||
|
||||
# create and register the nested udt types
|
||||
for i in range(MAX_NESTING_DEPTH):
|
||||
udt = namedtuple('level_{0}'.format(i + 1), ('value'))
|
||||
udts.append(udt)
|
||||
c.register_user_type("udttests", "depth_{0}".format(i + 1), udts[i + 1])
|
||||
c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1])
|
||||
|
||||
# insert udts and verify inserts with reads
|
||||
self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts)
|
||||
@@ -479,7 +468,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
User = namedtuple('user', ('age', 'name'))
|
||||
|
||||
with self.assertRaises(UserTypeDoesNotExist):
|
||||
@@ -499,7 +488,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
# create UDT
|
||||
alpha_type_list = []
|
||||
@@ -510,7 +499,7 @@ class UDTTests(unittest.TestCase):
|
||||
s.execute("""
|
||||
CREATE TYPE alldatatypes ({0})
|
||||
""".format(', '.join(alpha_type_list))
|
||||
)
|
||||
)
|
||||
|
||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<alldatatypes>)")
|
||||
|
||||
@@ -519,7 +508,7 @@ class UDTTests(unittest.TestCase):
|
||||
for i in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES)):
|
||||
alphabet_list.append('{0}'.format(chr(i)))
|
||||
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
|
||||
c.register_user_type("udttests", "alldatatypes", Alldatatypes)
|
||||
c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes)
|
||||
|
||||
# insert UDT data
|
||||
params = []
|
||||
@@ -544,7 +533,7 @@ class UDTTests(unittest.TestCase):
|
||||
"""
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
|
||||
# create UDT
|
||||
alpha_type_list = []
|
||||
@@ -576,7 +565,7 @@ class UDTTests(unittest.TestCase):
|
||||
alphabet_list.append('{0}_{1}'.format(chr(i), chr(j)))
|
||||
|
||||
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
|
||||
c.register_user_type("udttests", "alldatatypes", Alldatatypes)
|
||||
c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes)
|
||||
|
||||
# insert UDT data
|
||||
params = []
|
||||
@@ -607,11 +596,11 @@ class UDTTests(unittest.TestCase):
|
||||
Test for inserting various types of nested COLLECTION_TYPES into tables and UDTs
|
||||
"""
|
||||
|
||||
if self._cass_version < (2, 1, 3):
|
||||
if self.cass_version < (2, 1, 3):
|
||||
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3")
|
||||
|
||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
s = c.connect("udttests")
|
||||
s = c.connect(self.keyspace_name)
|
||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||
|
||||
name = self._testMethodName
|
||||
@@ -710,3 +699,36 @@ class UDTTests(unittest.TestCase):
|
||||
val = s.execute('SELECT v FROM %s' % self.table_name)[0][0]
|
||||
self.assertEqual(val['v0'], 3)
|
||||
self.assertEqual(val['v1'], six.b('\xde\xad\xbe\xef'))
|
||||
|
||||
def test_alter_udt(self):
|
||||
"""
|
||||
Test to ensure that altered UDT's are properly surfaced without needing to restart the underlying session.
|
||||
|
||||
@since 3.0.0
|
||||
@jira_ticket PYTHON-226
|
||||
@expected_result UDT's will reflect added columns without a session restart.
|
||||
|
||||
@test_category data_types, udt
|
||||
"""
|
||||
|
||||
# Create udt ensure it has the proper column names.
|
||||
self.session.set_keyspace(self.keyspace_name)
|
||||
self.session.execute("CREATE TYPE typetoalter (a int)")
|
||||
typetoalter = namedtuple('typetoalter', ('a'))
|
||||
self.session.execute("CREATE TABLE {0} (pk int primary key, typetoalter frozen<typetoalter>)".format(self.function_table_name))
|
||||
insert_statement = self.session.prepare("INSERT INTO {0} (pk, typetoalter) VALUES (?, ?)".format(self.function_table_name))
|
||||
self.session.execute(insert_statement, [1, typetoalter(1)])
|
||||
results = self.session.execute("SELECT * from {0}".format(self.function_table_name))
|
||||
for result in results:
|
||||
self.assertTrue(hasattr(result.typetoalter, 'a'))
|
||||
self.assertFalse(hasattr(result.typetoalter, 'b'))
|
||||
|
||||
# Alter UDT and ensure the alter is honored in results
|
||||
self.session.execute("ALTER TYPE typetoalter add b int")
|
||||
typetoalter = namedtuple('typetoalter', ('a', 'b'))
|
||||
self.session.execute(insert_statement, [2, typetoalter(2, 2)])
|
||||
results = self.session.execute("SELECT * from {0}".format(self.function_table_name))
|
||||
for result in results:
|
||||
self.assertTrue(hasattr(result.typetoalter, 'a'))
|
||||
self.assertTrue(hasattr(result.typetoalter, 'b'))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user