@@ -127,6 +127,12 @@ else:
|
|||||||
PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version))
|
PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version))
|
||||||
|
|
||||||
notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported')
|
notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported')
|
||||||
|
lessthenprotocolv4 = unittest.skipUnless(PROTOCOL_VERSION < 4, 'Protocol versions 4 or greater no supported')
|
||||||
|
greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported')
|
||||||
|
|
||||||
|
greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= '2.1', 'Cassandra version 2.1 or greater required')
|
||||||
|
greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= '3.0', 'Cassandra version 3.0 or greater required')
|
||||||
|
lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < '3.0', 'Cassandra version less then 3.0 required')
|
||||||
|
|
||||||
|
|
||||||
def get_cluster():
|
def get_cluster():
|
||||||
@@ -171,6 +177,7 @@ def remove_cluster():
|
|||||||
|
|
||||||
raise RuntimeError("Failed to remove cluster after 100 attempts")
|
raise RuntimeError("Failed to remove cluster after 100 attempts")
|
||||||
|
|
||||||
|
|
||||||
def is_current_cluster(cluster_name, node_counts):
|
def is_current_cluster(cluster_name, node_counts):
|
||||||
global CCM_CLUSTER
|
global CCM_CLUSTER
|
||||||
if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name:
|
if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name:
|
||||||
@@ -395,12 +402,13 @@ class BasicKeyspaceUnitTestCase(unittest.TestCase):
|
|||||||
execute_with_long_wait_retry(cls.session, ddl)
|
execute_with_long_wait_retry(cls.session, ddl)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def common_setup(cls, rf, create_class_table=False, skip_if_cass_version_less_than=None):
|
def common_setup(cls, rf, keyspace_creation=True, create_class_table=False):
|
||||||
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
cls.session = cls.cluster.connect()
|
cls.session = cls.cluster.connect()
|
||||||
cls.ks_name = cls.__name__.lower()
|
cls.ks_name = cls.__name__.lower()
|
||||||
cls.create_keyspace(rf)
|
if keyspace_creation:
|
||||||
cls.cass_version = get_server_versions()
|
cls.create_keyspace(rf)
|
||||||
|
cls.cass_version, cls.cql_version = get_server_versions()
|
||||||
|
|
||||||
if create_class_table:
|
if create_class_table:
|
||||||
|
|
||||||
@@ -422,6 +430,19 @@ class BasicKeyspaceUnitTestCase(unittest.TestCase):
|
|||||||
execute_until_pass(self.session, ddl)
|
execute_until_pass(self.session, ddl)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||||
|
"""
|
||||||
|
This is basic unit test defines class level teardown and setup methods. It assumes that keyspace is already defined, or created as part of the test.
|
||||||
|
"""
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.common_setup(1, keyspace_creation=False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
cls.cluster.shutdown()
|
||||||
|
|
||||||
|
|
||||||
class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||||
"""
|
"""
|
||||||
This is basic unit test case that can be leveraged to scope a keyspace to a specific test class.
|
This is basic unit test case that can be leveraged to scope a keyspace to a specific test class.
|
||||||
@@ -433,7 +454,7 @@ class BasicSharedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
drop_keyspace_shutdown_cluster(cls.keyspace_name, cls.session, cls.cluster)
|
drop_keyspace_shutdown_cluster(cls.ks_name, cls.session, cls.cluster)
|
||||||
|
|
||||||
|
|
||||||
class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase):
|
class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase):
|
||||||
@@ -510,4 +531,17 @@ class BasicSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
|||||||
self.common_setup(1)
|
self.common_setup(1)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
drop_keyspace_shutdown_cluster(self.keyspace_name, self.session, self.cluster)
|
drop_keyspace_shutdown_cluster(self.ks_name, self.session, self.cluster)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicExistingSegregatedKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase):
|
||||||
|
"""
|
||||||
|
This unit test will create and teardown or each individual unit tests. It assumes that keyspace is existing
|
||||||
|
or created as part of a test.
|
||||||
|
It has some overhead and should only be used when sharing cluster/session is not feasible.
|
||||||
|
"""
|
||||||
|
def setUp(self):
|
||||||
|
self.common_setup(1, keyspace_creation=False)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.cluster.shutdown()
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase
|
|||||||
from tests.integration.cqlengine.query.test_queryset import BaseQuerySetUsage
|
from tests.integration.cqlengine.query.test_queryset import BaseQuerySetUsage
|
||||||
|
|
||||||
|
|
||||||
from tests.integration import BasicSharedKeyspaceUnitTestCase, get_server_versions
|
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthanorequalcass30
|
||||||
|
|
||||||
|
|
||||||
class TestQuerySetOperation(BaseCassEngTestCase):
|
class TestQuerySetOperation(BaseCassEngTestCase):
|
||||||
@@ -270,18 +270,21 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
|
|||||||
self.table.objects.get(test_id=1)
|
self.table.objects.get(test_id=1)
|
||||||
|
|
||||||
|
|
||||||
class TestNamedWithMV(BaseCassEngTestCase):
|
class TestNamedWithMV(BasicSharedKeyspaceUnitTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
@classmethod
|
||||||
cass_version = get_server_versions()[0]
|
def setUpClass(cls):
|
||||||
if cass_version < (3, 0):
|
super(TestNamedWithMV, cls).setUpClass()
|
||||||
raise unittest.SkipTest("Materialized views require Cassandra 3.0+")
|
cls.default_keyspace = models.DEFAULT_KEYSPACE
|
||||||
super(TestNamedWithMV, self).setUp()
|
models.DEFAULT_KEYSPACE = cls.ks_name
|
||||||
|
|
||||||
def tearDown(self):
|
@classmethod
|
||||||
models.DEFAULT_KEYSPACE = self.default_keyspace
|
def tearDownClass(cls):
|
||||||
|
models.DEFAULT_KEYSPACE = cls.default_keyspace
|
||||||
setup_connection(models.DEFAULT_KEYSPACE)
|
setup_connection(models.DEFAULT_KEYSPACE)
|
||||||
|
super(TestNamedWithMV, cls).tearDownClass()
|
||||||
|
|
||||||
|
@greaterthanorequalcass30
|
||||||
def test_named_table_with_mv(self):
|
def test_named_table_with_mv(self):
|
||||||
"""
|
"""
|
||||||
Test NamedTable access to materialized views
|
Test NamedTable access to materialized views
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from tests.integration.standard.utils import (
|
|||||||
|
|
||||||
from tests.unit.cython.utils import cythontest, numpytest
|
from tests.unit.cython.utils import cythontest, numpytest
|
||||||
|
|
||||||
|
|
||||||
def setup_module():
|
def setup_module():
|
||||||
use_singledc()
|
use_singledc()
|
||||||
update_datatypes()
|
update_datatypes()
|
||||||
|
|||||||
@@ -26,48 +26,42 @@ from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
|
|||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.policies import HostDistance
|
from cassandra.policies import HostDistance
|
||||||
|
|
||||||
from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions
|
from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
def setup_module():
|
def setup_module():
|
||||||
|
print("Setting up module")
|
||||||
use_singledc()
|
use_singledc()
|
||||||
global CASS_SERVER_VERSION
|
global CASS_SERVER_VERSION
|
||||||
CASS_SERVER_VERSION = get_server_versions()[0]
|
CASS_SERVER_VERSION = get_server_versions()[0]
|
||||||
|
|
||||||
|
|
||||||
class QueryTests(unittest.TestCase):
|
class QueryTests(BasicSharedKeyspaceUnitTestCase):
|
||||||
|
|
||||||
def test_query(self):
|
def test_query(self):
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
|
||||||
session = cluster.connect()
|
|
||||||
|
|
||||||
prepared = session.prepare(
|
prepared = self.session.prepare(
|
||||||
"""
|
"""
|
||||||
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
|
INSERT INTO test3rf.test (k, v) VALUES (?, ?)
|
||||||
""")
|
""".format(self.keyspace_name))
|
||||||
|
|
||||||
self.assertIsInstance(prepared, PreparedStatement)
|
self.assertIsInstance(prepared, PreparedStatement)
|
||||||
bound = prepared.bind((1, None))
|
bound = prepared.bind((1, None))
|
||||||
self.assertIsInstance(bound, BoundStatement)
|
self.assertIsInstance(bound, BoundStatement)
|
||||||
self.assertEqual(2, len(bound.values))
|
self.assertEqual(2, len(bound.values))
|
||||||
session.execute(bound)
|
self.session.execute(bound)
|
||||||
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
|
self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')
|
||||||
|
|
||||||
cluster.shutdown()
|
|
||||||
|
|
||||||
def test_trace_prints_okay(self):
|
def test_trace_prints_okay(self):
|
||||||
"""
|
"""
|
||||||
Code coverage to ensure trace prints to string without error
|
Code coverage to ensure trace prints to string without error
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
|
||||||
session = cluster.connect()
|
|
||||||
|
|
||||||
query = "SELECT * FROM system.local"
|
query = "SELECT * FROM system.local"
|
||||||
statement = SimpleStatement(query)
|
statement = SimpleStatement(query)
|
||||||
rs = session.execute(statement, trace=True)
|
rs = self.session.execute(statement, trace=True)
|
||||||
|
|
||||||
# Ensure this does not throw an exception
|
# Ensure this does not throw an exception
|
||||||
trace = rs.get_query_trace()
|
trace = rs.get_query_trace()
|
||||||
@@ -76,13 +70,9 @@ class QueryTests(unittest.TestCase):
|
|||||||
for event in trace.events:
|
for event in trace.events:
|
||||||
str(event)
|
str(event)
|
||||||
|
|
||||||
cluster.shutdown()
|
|
||||||
|
|
||||||
def test_trace_id_to_resultset(self):
|
def test_trace_id_to_resultset(self):
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
|
||||||
session = cluster.connect()
|
|
||||||
|
|
||||||
future = session.execute_async("SELECT * FROM system.local", trace=True)
|
future = self.session.execute_async("SELECT * FROM system.local", trace=True)
|
||||||
|
|
||||||
# future should have the current trace
|
# future should have the current trace
|
||||||
rs = future.result()
|
rs = future.result()
|
||||||
@@ -96,16 +86,12 @@ class QueryTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual([rs_trace], rs.get_all_query_traces())
|
self.assertListEqual([rs_trace], rs.get_all_query_traces())
|
||||||
|
|
||||||
cluster.shutdown()
|
|
||||||
|
|
||||||
def test_trace_ignores_row_factory(self):
|
def test_trace_ignores_row_factory(self):
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
self.session.row_factory = dict_factory
|
||||||
session = cluster.connect()
|
|
||||||
session.row_factory = dict_factory
|
|
||||||
|
|
||||||
query = "SELECT * FROM system.local"
|
query = "SELECT * FROM system.local"
|
||||||
statement = SimpleStatement(query)
|
statement = SimpleStatement(query)
|
||||||
rs = session.execute(statement, trace=True)
|
rs = self.session.execute(statement, trace=True)
|
||||||
|
|
||||||
# Ensure this does not throw an exception
|
# Ensure this does not throw an exception
|
||||||
trace = rs.get_query_trace()
|
trace = rs.get_query_trace()
|
||||||
@@ -114,8 +100,7 @@ class QueryTests(unittest.TestCase):
|
|||||||
for event in trace.events:
|
for event in trace.events:
|
||||||
str(event)
|
str(event)
|
||||||
|
|
||||||
cluster.shutdown()
|
@greaterthanprotocolv3
|
||||||
|
|
||||||
def test_client_ip_in_trace(self):
|
def test_client_ip_in_trace(self):
|
||||||
"""
|
"""
|
||||||
Test to validate that client trace contains client ip information.
|
Test to validate that client trace contains client ip information.
|
||||||
@@ -136,18 +121,10 @@ class QueryTests(unittest.TestCase):
|
|||||||
# raise unittest.SkipTest("Client IP was not present in trace until C* 2.2")
|
# raise unittest.SkipTest("Client IP was not present in trace until C* 2.2")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if PROTOCOL_VERSION < 4:
|
|
||||||
raise unittest.SkipTest(
|
|
||||||
"Protocol 4+ is required for client ip tracing, currently testing against %r"
|
|
||||||
% (PROTOCOL_VERSION,))
|
|
||||||
|
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
|
||||||
session = cluster.connect()
|
|
||||||
|
|
||||||
# Make simple query with trace enabled
|
# Make simple query with trace enabled
|
||||||
query = "SELECT * FROM system.local"
|
query = "SELECT * FROM system.local"
|
||||||
statement = SimpleStatement(query)
|
statement = SimpleStatement(query)
|
||||||
response_future = session.execute_async(statement, trace=True)
|
response_future = self.session.execute_async(statement, trace=True)
|
||||||
response_future.result()
|
response_future.result()
|
||||||
|
|
||||||
# Fetch the client_ip from the trace.
|
# Fetch the client_ip from the trace.
|
||||||
@@ -161,7 +138,29 @@ class QueryTests(unittest.TestCase):
|
|||||||
self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2")
|
self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2")
|
||||||
self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value")
|
self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value")
|
||||||
|
|
||||||
cluster.shutdown()
|
def test_column_names(self):
|
||||||
|
"""
|
||||||
|
Test to validate the columns are present on the result set.
|
||||||
|
Preforms a simple query against a table then checks to ensure column names are correct and present and correct.
|
||||||
|
|
||||||
|
@since 3.0.0
|
||||||
|
@jira_ticket PYTHON-439
|
||||||
|
@expected_result column_names should be preset.
|
||||||
|
|
||||||
|
@test_category queries basic
|
||||||
|
"""
|
||||||
|
create_table = """CREATE TABLE {0}.{1}(
|
||||||
|
user TEXT,
|
||||||
|
game TEXT,
|
||||||
|
year INT,
|
||||||
|
month INT,
|
||||||
|
day INT,
|
||||||
|
score INT,
|
||||||
|
PRIMARY KEY (user, game, year, month, day)
|
||||||
|
)""".format(self.keyspace_name, self.function_table_name)
|
||||||
|
self.session.execute(create_table)
|
||||||
|
result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name))
|
||||||
|
self.assertEqual(result_set.column_names, [u'user', u'game', u'year', u'month', u'day', u'score'])
|
||||||
|
|
||||||
|
|
||||||
class PreparedStatementTests(unittest.TestCase):
|
class PreparedStatementTests(unittest.TestCase):
|
||||||
|
|||||||
@@ -27,8 +27,10 @@ from cassandra.concurrent import execute_concurrent_with_args
|
|||||||
from cassandra.cqltypes import Int32Type, EMPTY
|
from cassandra.cqltypes import Int32Type, EMPTY
|
||||||
from cassandra.query import dict_factory, ordered_dict_factory
|
from cassandra.query import dict_factory, ordered_dict_factory
|
||||||
from cassandra.util import sortedset
|
from cassandra.util import sortedset
|
||||||
|
from tests.unit.cython.utils import cythontest
|
||||||
|
|
||||||
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1
|
from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \
|
||||||
|
BasicSharedKeyspaceUnitTestCase, greaterthancass20, lessthancass30
|
||||||
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
|
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
|
||||||
get_sample, get_collection_sample
|
get_sample, get_collection_sample
|
||||||
|
|
||||||
@@ -38,21 +40,14 @@ def setup_module():
|
|||||||
update_datatypes()
|
update_datatypes()
|
||||||
|
|
||||||
|
|
||||||
class TypeTests(unittest.TestCase):
|
class TypeTests(BasicSharedKeyspaceUnitTestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls._cass_version, cls._cql_version = get_server_versions()
|
# cls._cass_version, cls. = get_server_versions()
|
||||||
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
super(TypeTests, cls).setUpClass()
|
||||||
cls.session = cls.cluster.connect()
|
cls.session.set_keyspace(cls.ks_name)
|
||||||
cls.session.execute("CREATE KEYSPACE typetests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
|
|
||||||
cls.session.set_keyspace("typetests")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
execute_until_pass(cls.session, "DROP KEYSPACE typetests")
|
|
||||||
cls.cluster.shutdown()
|
|
||||||
|
|
||||||
def test_can_insert_blob_type_as_string(self):
|
def test_can_insert_blob_type_as_string(self):
|
||||||
"""
|
"""
|
||||||
Tests that byte strings in Python maps to blob type in Cassandra
|
Tests that byte strings in Python maps to blob type in Cassandra
|
||||||
@@ -66,10 +61,10 @@ class TypeTests(unittest.TestCase):
|
|||||||
|
|
||||||
# In python2, with Cassandra > 2.0, we don't treat the 'byte str' type as a blob, so we'll encode it
|
# In python2, with Cassandra > 2.0, we don't treat the 'byte str' type as a blob, so we'll encode it
|
||||||
# as a string literal and have the following failure.
|
# as a string literal and have the following failure.
|
||||||
if six.PY2 and self._cql_version >= (3, 1, 0):
|
if six.PY2 and self.cql_version >= (3, 1, 0):
|
||||||
# Blob values can't be specified using string notation in CQL 3.1.0 and
|
# Blob values can't be specified using string notation in CQL 3.1.0 and
|
||||||
# above which is used by default in Cassandra 2.0.
|
# above which is used by default in Cassandra 2.0.
|
||||||
if self._cass_version >= (2, 1, 0):
|
if self.cass_version >= (2, 1, 0):
|
||||||
msg = r'.*Invalid STRING constant \(.*?\) for "b" of type blob.*'
|
msg = r'.*Invalid STRING constant \(.*?\) for "b" of type blob.*'
|
||||||
else:
|
else:
|
||||||
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
|
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
|
||||||
@@ -108,7 +103,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
Test insertion of all datatype primitives
|
Test insertion of all datatype primitives
|
||||||
"""
|
"""
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("typetests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
# create table
|
# create table
|
||||||
alpha_type_list = ["zz int PRIMARY KEY"]
|
alpha_type_list = ["zz int PRIMARY KEY"]
|
||||||
@@ -167,7 +162,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("typetests")
|
s = c.connect(self.keyspace_name)
|
||||||
# use tuple encoding, to convert native python tuple into raw CQL
|
# use tuple encoding, to convert native python tuple into raw CQL
|
||||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||||
|
|
||||||
@@ -394,11 +389,11 @@ class TypeTests(unittest.TestCase):
|
|||||||
Basic test of tuple functionality
|
Basic test of tuple functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._cass_version < (2, 1, 0):
|
if self.cass_version < (2, 1, 0):
|
||||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("typetests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
# use this encoder in order to insert tuples
|
# use this encoder in order to insert tuples
|
||||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||||
@@ -446,11 +441,11 @@ class TypeTests(unittest.TestCase):
|
|||||||
as expected.
|
as expected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._cass_version < (2, 1, 0):
|
if self.cass_version < (2, 1, 0):
|
||||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("typetests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
# set the row_factory to dict_factory for programmatic access
|
# set the row_factory to dict_factory for programmatic access
|
||||||
# set the encoder for tuples for the ability to write tuples
|
# set the encoder for tuples for the ability to write tuples
|
||||||
@@ -485,11 +480,11 @@ class TypeTests(unittest.TestCase):
|
|||||||
Ensure tuple subtypes are appropriately handled.
|
Ensure tuple subtypes are appropriately handled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._cass_version < (2, 1, 0):
|
if self.cass_version < (2, 1, 0):
|
||||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("typetests")
|
s = c.connect(self.keyspace_name)
|
||||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||||
|
|
||||||
s.execute("CREATE TABLE tuple_primitive ("
|
s.execute("CREATE TABLE tuple_primitive ("
|
||||||
@@ -513,11 +508,11 @@ class TypeTests(unittest.TestCase):
|
|||||||
Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
|
Ensure tuple subtypes are appropriately handled for maps, sets, and lists.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._cass_version < (2, 1, 0):
|
if self.cass_version < (2, 1, 0):
|
||||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("typetests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
# set the row_factory to dict_factory for programmatic access
|
# set the row_factory to dict_factory for programmatic access
|
||||||
# set the encoder for tuples for the ability to write tuples
|
# set the encoder for tuples for the ability to write tuples
|
||||||
@@ -612,11 +607,11 @@ class TypeTests(unittest.TestCase):
|
|||||||
Ensure nested are appropriately handled.
|
Ensure nested are appropriately handled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._cass_version < (2, 1, 0):
|
if self.cass_version < (2, 1, 0):
|
||||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("typetests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
# set the row_factory to dict_factory for programmatic access
|
# set the row_factory to dict_factory for programmatic access
|
||||||
# set the encoder for tuples for the ability to write tuples
|
# set the encoder for tuples for the ability to write tuples
|
||||||
@@ -652,7 +647,7 @@ class TypeTests(unittest.TestCase):
|
|||||||
Test tuples with null and empty string fields.
|
Test tuples with null and empty string fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._cass_version < (2, 1, 0):
|
if self.cass_version < (2, 1, 0):
|
||||||
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
|
||||||
|
|
||||||
s = self.session
|
s = self.session
|
||||||
@@ -746,3 +741,116 @@ class TypeTests(unittest.TestCase):
|
|||||||
# prepared binding
|
# prepared binding
|
||||||
verify_insert_select(s.prepare('INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)'),
|
verify_insert_select(s.prepare('INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)'),
|
||||||
s.prepare('SELECT * FROM float_cql_encoding WHERE f=?'))
|
s.prepare('SELECT * FROM float_cql_encoding WHERE f=?'))
|
||||||
|
|
||||||
|
@cythontest
|
||||||
|
def test_cython_decimal(self):
|
||||||
|
"""
|
||||||
|
Test to validate that decimal deserialization works correctly in with our cython extensions
|
||||||
|
|
||||||
|
@since 3.0.0
|
||||||
|
@jira_ticket PYTHON-212
|
||||||
|
@expected_result no exceptions are thrown, decimal is decoded correctly
|
||||||
|
|
||||||
|
@test_category data_types serialization
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.session.execute("CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name))
|
||||||
|
try:
|
||||||
|
self.session.execute("INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format(self.function_table_name))
|
||||||
|
results = self.session.execute("SELECT * FROM {0}".format(self.function_table_name))
|
||||||
|
self.assertTrue(str(results[0].dc) == '-1.08430792318105707')
|
||||||
|
finally:
|
||||||
|
self.session.execute("DROP TABLE {0}".format(self.function_table_name))
|
||||||
|
|
||||||
|
|
||||||
|
class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase):
|
||||||
|
|
||||||
|
@greaterthancass20
|
||||||
|
@lessthancass30
|
||||||
|
def test_nested_types_with_protocol_version(self):
|
||||||
|
"""
|
||||||
|
Test to validate that nested type serialization works on various protocol versions. Provided
|
||||||
|
the version of cassandra is greater the 2.1.3 we would expect to nested to types to work at all protocol versions.
|
||||||
|
|
||||||
|
@since 3.0.0
|
||||||
|
@jira_ticket PYTHON-215
|
||||||
|
@expected_result no exceptions are thrown
|
||||||
|
|
||||||
|
@test_category data_types serialization
|
||||||
|
"""
|
||||||
|
ddl = '''CREATE TABLE {0}.t (
|
||||||
|
k int PRIMARY KEY,
|
||||||
|
v list<frozen<set<int>>>)'''.format(self.keyspace_name)
|
||||||
|
|
||||||
|
self.session.execute(ddl)
|
||||||
|
ddl = '''CREATE TABLE {0}.u (
|
||||||
|
k int PRIMARY KEY,
|
||||||
|
v set<frozen<list<int>>>)'''.format(self.keyspace_name)
|
||||||
|
self.session.execute(ddl)
|
||||||
|
ddl = '''CREATE TABLE {0}.v (
|
||||||
|
k int PRIMARY KEY,
|
||||||
|
v map<frozen<set<int>>, frozen<list<int>>>,
|
||||||
|
v1 frozen<tuple<int, text>>)'''.format(self.keyspace_name)
|
||||||
|
self.session.execute(ddl)
|
||||||
|
|
||||||
|
self.session.execute("CREATE TYPE {0}.typ (v0 frozen<map<int, frozen<list<int>>>>, v1 frozen<list<int>>)".format(self.keyspace_name))
|
||||||
|
|
||||||
|
ddl = '''CREATE TABLE {0}.w (
|
||||||
|
k int PRIMARY KEY,
|
||||||
|
v frozen<typ>)'''.format(self.keyspace_name)
|
||||||
|
|
||||||
|
self.session.execute(ddl)
|
||||||
|
|
||||||
|
for pvi in range(1, 5):
|
||||||
|
self.run_inserts_at_version(pvi)
|
||||||
|
for pvr in range(1, 5):
|
||||||
|
self.read_inserts_at_level(pvr)
|
||||||
|
|
||||||
|
def print_results(self, results):
|
||||||
|
print("printing results")
|
||||||
|
print(str(results.v))
|
||||||
|
|
||||||
|
def read_inserts_at_level(self, proto_ver):
|
||||||
|
session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name)
|
||||||
|
try:
|
||||||
|
print("reading at version {0}".format(proto_ver))
|
||||||
|
results = session.execute('select * from t')[0]
|
||||||
|
self.print_results(results)
|
||||||
|
self.assertEqual("[SortedSet([1, 2]), SortedSet([3, 5])]", str(results.v))
|
||||||
|
|
||||||
|
results = session.execute('select * from u')[0]
|
||||||
|
self.print_results(results)
|
||||||
|
self.assertEqual("SortedSet([[1, 2], [3, 5]])", str(results.v))
|
||||||
|
|
||||||
|
results = session.execute('select * from v')[0]
|
||||||
|
self.print_results(results)
|
||||||
|
self.assertEqual("{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}", str(results.v))
|
||||||
|
|
||||||
|
results = session.execute('select * from w')[0]
|
||||||
|
self.print_results(results)
|
||||||
|
self.assertEqual("typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])", str(results.v))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
session.cluster.shutdown()
|
||||||
|
|
||||||
|
def run_inserts_at_version(self, proto_ver):
|
||||||
|
session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name)
|
||||||
|
try:
|
||||||
|
print("running inserts with protocol version {0}".format(str(proto_ver)))
|
||||||
|
p = session.prepare('insert into t (k, v) values (?, ?)')
|
||||||
|
session.execute(p, (0, [{1, 2}, {3, 5}]))
|
||||||
|
|
||||||
|
p = session.prepare('insert into u (k, v) values (?, ?)')
|
||||||
|
session.execute(p, (0, {(1, 2), (3, 5)}))
|
||||||
|
|
||||||
|
p = session.prepare('insert into v (k, v, v1) values (?, ?, ?)')
|
||||||
|
session.execute(p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, 'four')))
|
||||||
|
|
||||||
|
p = session.prepare('insert into w (k, v) values (?, ?)')
|
||||||
|
session.execute(p, (0, ({1: [1, 2, 3], 2: [4, 5, 6]}, [7, 8, 9])))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
session.cluster.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from cassandra.cluster import Cluster, UserTypeDoesNotExist
|
|||||||
from cassandra.query import dict_factory
|
from cassandra.query import dict_factory
|
||||||
from cassandra.util import OrderedMap
|
from cassandra.util import OrderedMap
|
||||||
|
|
||||||
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass
|
from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass, BasicSegregatedKeyspaceUnitTestCase, greaterthancass20
|
||||||
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
|
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \
|
||||||
get_sample, get_collection_sample
|
get_sample, get_collection_sample
|
||||||
|
|
||||||
@@ -39,27 +39,16 @@ def setup_module():
|
|||||||
update_datatypes()
|
update_datatypes()
|
||||||
|
|
||||||
|
|
||||||
class UDTTests(unittest.TestCase):
|
@greaterthancass20
|
||||||
|
class UDTTests(BasicSegregatedKeyspaceUnitTestCase):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def table_name(self):
|
def table_name(self):
|
||||||
return self._testMethodName.lower()
|
return self._testMethodName.lower()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._cass_version, self._cql_version = get_server_versions()
|
super(UDTTests, self).setUp()
|
||||||
|
self.session.set_keyspace(self.keyspace_name)
|
||||||
if self._cass_version < (2, 1, 0):
|
|
||||||
raise unittest.SkipTest("User Defined Types were introduced in Cassandra 2.1")
|
|
||||||
|
|
||||||
self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
|
||||||
self.session = self.cluster.connect()
|
|
||||||
execute_until_pass(self.session,
|
|
||||||
"CREATE KEYSPACE udttests WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
|
|
||||||
self.session.set_keyspace("udttests")
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
execute_until_pass(self.session, "DROP KEYSPACE udttests")
|
|
||||||
self.cluster.shutdown()
|
|
||||||
|
|
||||||
def test_can_insert_unprepared_registered_udts(self):
|
def test_can_insert_unprepared_registered_udts(self):
|
||||||
"""
|
"""
|
||||||
@@ -67,13 +56,13 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
s.execute("CREATE TYPE user (age int, name text)")
|
s.execute("CREATE TYPE user (age int, name text)")
|
||||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||||
|
|
||||||
User = namedtuple('user', ('age', 'name'))
|
User = namedtuple('user', ('age', 'name'))
|
||||||
c.register_user_type("udttests", "user", User)
|
c.register_user_type(self.keyspace_name, "user", User)
|
||||||
|
|
||||||
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob')))
|
s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob')))
|
||||||
result = s.execute("SELECT b FROM mytable WHERE a=0")
|
result = s.execute("SELECT b FROM mytable WHERE a=0")
|
||||||
@@ -169,7 +158,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
s.execute("CREATE TYPE user (age int, name text)")
|
s.execute("CREATE TYPE user (age int, name text)")
|
||||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||||
@@ -213,11 +202,11 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
s.execute("CREATE TYPE user (age int, name text)")
|
s.execute("CREATE TYPE user (age int, name text)")
|
||||||
User = namedtuple('user', ('age', 'name'))
|
User = namedtuple('user', ('age', 'name'))
|
||||||
c.register_user_type("udttests", "user", User)
|
c.register_user_type(self.keyspace_name, "user", User)
|
||||||
|
|
||||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||||
|
|
||||||
@@ -263,11 +252,11 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)")
|
s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)")
|
||||||
User = namedtuple('user', ('a', 'b', 'c', 'd'))
|
User = namedtuple('user', ('a', 'b', 'c', 'd'))
|
||||||
c.register_user_type("udttests", "user", User)
|
c.register_user_type(self.keyspace_name, "user", User)
|
||||||
|
|
||||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<user>)")
|
||||||
|
|
||||||
@@ -293,7 +282,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
MAX_TEST_LENGTH = 254
|
MAX_TEST_LENGTH = 254
|
||||||
|
|
||||||
@@ -310,7 +299,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
|
|
||||||
# create and register the seed udt type
|
# create and register the seed udt type
|
||||||
udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in range(MAX_TEST_LENGTH)]))
|
udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in range(MAX_TEST_LENGTH)]))
|
||||||
c.register_user_type("udttests", "lengthy_udt", udt)
|
c.register_user_type(self.keyspace_name, "lengthy_udt", udt)
|
||||||
|
|
||||||
# verify inserts and reads
|
# verify inserts and reads
|
||||||
for i in (0, 1, 2, 3, MAX_TEST_LENGTH):
|
for i in (0, 1, 2, 3, MAX_TEST_LENGTH):
|
||||||
@@ -377,7 +366,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
s.row_factory = dict_factory
|
s.row_factory = dict_factory
|
||||||
|
|
||||||
MAX_NESTING_DEPTH = 16
|
MAX_NESTING_DEPTH = 16
|
||||||
@@ -389,13 +378,13 @@ class UDTTests(unittest.TestCase):
|
|||||||
udts = []
|
udts = []
|
||||||
udt = namedtuple('depth_0', ('age', 'name'))
|
udt = namedtuple('depth_0', ('age', 'name'))
|
||||||
udts.append(udt)
|
udts.append(udt)
|
||||||
c.register_user_type("udttests", "depth_0", udts[0])
|
c.register_user_type(self.keyspace_name, "depth_0", udts[0])
|
||||||
|
|
||||||
# create and register the nested udt types
|
# create and register the nested udt types
|
||||||
for i in range(MAX_NESTING_DEPTH):
|
for i in range(MAX_NESTING_DEPTH):
|
||||||
udt = namedtuple('depth_{0}'.format(i + 1), ('value'))
|
udt = namedtuple('depth_{0}'.format(i + 1), ('value'))
|
||||||
udts.append(udt)
|
udts.append(udt)
|
||||||
c.register_user_type("udttests", "depth_{0}".format(i + 1), udts[i + 1])
|
c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1])
|
||||||
|
|
||||||
# insert udts and verify inserts with reads
|
# insert udts and verify inserts with reads
|
||||||
self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts)
|
self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts)
|
||||||
@@ -408,7 +397,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
s.row_factory = dict_factory
|
s.row_factory = dict_factory
|
||||||
|
|
||||||
MAX_NESTING_DEPTH = 16
|
MAX_NESTING_DEPTH = 16
|
||||||
@@ -448,7 +437,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
s.row_factory = dict_factory
|
s.row_factory = dict_factory
|
||||||
|
|
||||||
MAX_NESTING_DEPTH = 16
|
MAX_NESTING_DEPTH = 16
|
||||||
@@ -460,13 +449,13 @@ class UDTTests(unittest.TestCase):
|
|||||||
udts = []
|
udts = []
|
||||||
udt = namedtuple('level_0', ('age', 'name'))
|
udt = namedtuple('level_0', ('age', 'name'))
|
||||||
udts.append(udt)
|
udts.append(udt)
|
||||||
c.register_user_type("udttests", "depth_0", udts[0])
|
c.register_user_type(self.keyspace_name, "depth_0", udts[0])
|
||||||
|
|
||||||
# create and register the nested udt types
|
# create and register the nested udt types
|
||||||
for i in range(MAX_NESTING_DEPTH):
|
for i in range(MAX_NESTING_DEPTH):
|
||||||
udt = namedtuple('level_{0}'.format(i + 1), ('value'))
|
udt = namedtuple('level_{0}'.format(i + 1), ('value'))
|
||||||
udts.append(udt)
|
udts.append(udt)
|
||||||
c.register_user_type("udttests", "depth_{0}".format(i + 1), udts[i + 1])
|
c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1])
|
||||||
|
|
||||||
# insert udts and verify inserts with reads
|
# insert udts and verify inserts with reads
|
||||||
self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts)
|
self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts)
|
||||||
@@ -479,7 +468,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
User = namedtuple('user', ('age', 'name'))
|
User = namedtuple('user', ('age', 'name'))
|
||||||
|
|
||||||
with self.assertRaises(UserTypeDoesNotExist):
|
with self.assertRaises(UserTypeDoesNotExist):
|
||||||
@@ -499,7 +488,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
# create UDT
|
# create UDT
|
||||||
alpha_type_list = []
|
alpha_type_list = []
|
||||||
@@ -510,7 +499,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
s.execute("""
|
s.execute("""
|
||||||
CREATE TYPE alldatatypes ({0})
|
CREATE TYPE alldatatypes ({0})
|
||||||
""".format(', '.join(alpha_type_list))
|
""".format(', '.join(alpha_type_list))
|
||||||
)
|
)
|
||||||
|
|
||||||
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<alldatatypes>)")
|
s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen<alldatatypes>)")
|
||||||
|
|
||||||
@@ -519,7 +508,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
for i in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES)):
|
for i in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES)):
|
||||||
alphabet_list.append('{0}'.format(chr(i)))
|
alphabet_list.append('{0}'.format(chr(i)))
|
||||||
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
|
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
|
||||||
c.register_user_type("udttests", "alldatatypes", Alldatatypes)
|
c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes)
|
||||||
|
|
||||||
# insert UDT data
|
# insert UDT data
|
||||||
params = []
|
params = []
|
||||||
@@ -544,7 +533,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
|
|
||||||
# create UDT
|
# create UDT
|
||||||
alpha_type_list = []
|
alpha_type_list = []
|
||||||
@@ -576,7 +565,7 @@ class UDTTests(unittest.TestCase):
|
|||||||
alphabet_list.append('{0}_{1}'.format(chr(i), chr(j)))
|
alphabet_list.append('{0}_{1}'.format(chr(i), chr(j)))
|
||||||
|
|
||||||
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
|
Alldatatypes = namedtuple("alldatatypes", alphabet_list)
|
||||||
c.register_user_type("udttests", "alldatatypes", Alldatatypes)
|
c.register_user_type(self.keyspace_name, "alldatatypes", Alldatatypes)
|
||||||
|
|
||||||
# insert UDT data
|
# insert UDT data
|
||||||
params = []
|
params = []
|
||||||
@@ -607,11 +596,11 @@ class UDTTests(unittest.TestCase):
|
|||||||
Test for inserting various types of nested COLLECTION_TYPES into tables and UDTs
|
Test for inserting various types of nested COLLECTION_TYPES into tables and UDTs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._cass_version < (2, 1, 3):
|
if self.cass_version < (2, 1, 3):
|
||||||
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3")
|
raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3")
|
||||||
|
|
||||||
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
c = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
s = c.connect("udttests")
|
s = c.connect(self.keyspace_name)
|
||||||
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
|
||||||
|
|
||||||
name = self._testMethodName
|
name = self._testMethodName
|
||||||
@@ -710,3 +699,36 @@ class UDTTests(unittest.TestCase):
|
|||||||
val = s.execute('SELECT v FROM %s' % self.table_name)[0][0]
|
val = s.execute('SELECT v FROM %s' % self.table_name)[0][0]
|
||||||
self.assertEqual(val['v0'], 3)
|
self.assertEqual(val['v0'], 3)
|
||||||
self.assertEqual(val['v1'], six.b('\xde\xad\xbe\xef'))
|
self.assertEqual(val['v1'], six.b('\xde\xad\xbe\xef'))
|
||||||
|
|
||||||
|
def test_alter_udt(self):
|
||||||
|
"""
|
||||||
|
Test to ensure that altered UDT's are properly surfaced without needing to restart the underlying session.
|
||||||
|
|
||||||
|
@since 3.0.0
|
||||||
|
@jira_ticket PYTHON-226
|
||||||
|
@expected_result UDT's will reflect added columns without a session restart.
|
||||||
|
|
||||||
|
@test_category data_types, udt
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create udt ensure it has the proper column names.
|
||||||
|
self.session.set_keyspace(self.keyspace_name)
|
||||||
|
self.session.execute("CREATE TYPE typetoalter (a int)")
|
||||||
|
typetoalter = namedtuple('typetoalter', ('a'))
|
||||||
|
self.session.execute("CREATE TABLE {0} (pk int primary key, typetoalter frozen<typetoalter>)".format(self.function_table_name))
|
||||||
|
insert_statement = self.session.prepare("INSERT INTO {0} (pk, typetoalter) VALUES (?, ?)".format(self.function_table_name))
|
||||||
|
self.session.execute(insert_statement, [1, typetoalter(1)])
|
||||||
|
results = self.session.execute("SELECT * from {0}".format(self.function_table_name))
|
||||||
|
for result in results:
|
||||||
|
self.assertTrue(hasattr(result.typetoalter, 'a'))
|
||||||
|
self.assertFalse(hasattr(result.typetoalter, 'b'))
|
||||||
|
|
||||||
|
# Alter UDT and ensure the alter is honored in results
|
||||||
|
self.session.execute("ALTER TYPE typetoalter add b int")
|
||||||
|
typetoalter = namedtuple('typetoalter', ('a', 'b'))
|
||||||
|
self.session.execute(insert_statement, [2, typetoalter(2, 2)])
|
||||||
|
results = self.session.execute("SELECT * from {0}".format(self.function_table_name))
|
||||||
|
for result in results:
|
||||||
|
self.assertTrue(hasattr(result.typetoalter, 'a'))
|
||||||
|
self.assertTrue(hasattr(result.typetoalter, 'b'))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user