Files
deb-python-cassandra-driver/tests/integration/standard/test_query.py
Tyler Hobbs 8bee26d6d6 Handle custom encoders in nested data types
The root of the problem was that nested data types would use the default
encoders for subitems.  When the encoders were customized, they would
not be used for those nested items.

This fix moves the encoder functions into a class so that collections,
tuples, and UDTs will use the customized mapping when encoding subitems.

Fixes PYTHON-100.
2014-07-22 14:21:04 -05:00

355 lines
12 KiB
Python

# Copyright 2013-2014 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from cassandra import ConsistencyLevel
from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement,
BatchStatement, BatchType, dict_factory)
from cassandra.cluster import Cluster
from cassandra.policies import HostDistance
from tests.integration import PROTOCOL_VERSION
class QueryTest(unittest.TestCase):
    """
    Integration tests for executing prepared queries and for rendering
    query traces against a live cluster.
    """

    def _connect(self):
        """
        Build a Cluster for PROTOCOL_VERSION and return a connected Session.

        The cluster's shutdown is registered with addCleanup *before*
        connecting, so its connections are released even when connect()
        or the test body raises (the original tests never shut down).
        """
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.addCleanup(cluster.shutdown)
        return cluster.connect()

    def test_query(self):
        """
        Prepare, bind and execute an INSERT, then check the routing key
        computed from the bound partition key value (k=1).
        """
        session = self._connect()
        prepared = session.prepare(
            """
            INSERT INTO test3rf.test (k, v) VALUES (?, ?)
            """)
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind((1, None))
        self.assertIsInstance(bound, BoundStatement)
        self.assertEqual(2, len(bound.values))

        session.execute(bound)
        self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')

    def test_trace_prints_okay(self):
        """
        Code coverage to ensure trace prints to string without error
        """
        session = self._connect()
        query = "SELECT * FROM system.local"
        statement = SimpleStatement(query)
        session.execute(statement, trace=True)

        # Ensure this does not throw an exception
        str(statement.trace)
        for event in statement.trace.events:
            str(event)

    def test_trace_ignores_row_factory(self):
        """
        Same as test_trace_prints_okay, but with the session's row_factory
        switched to dict_factory, to check that fetching and printing the
        trace still works with a non-default row factory.
        """
        session = self._connect()
        session.row_factory = dict_factory
        query = "SELECT * FROM system.local"
        statement = SimpleStatement(query)
        session.execute(statement, trace=True)

        # Ensure this does not throw an exception
        str(statement.trace)
        for event in statement.trace.events:
            str(event)
class PreparedStatementTests(unittest.TestCase):
    """
    Integration tests for the routing-key and keyspace properties of
    prepared and bound statements.
    """

    def _prepare_insert(self):
        """
        Connect to the cluster and prepare the two-placeholder INSERT into
        test3rf.test shared by every test in this class.

        The cluster's shutdown is registered with addCleanup before
        connecting, so connections are released even if the test fails.
        Each call returns a freshly prepared statement, so tests may mutate
        it without affecting one another.
        """
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.addCleanup(cluster.shutdown)
        session = cluster.connect()
        return session.prepare(
            """
            INSERT INTO test3rf.test (k, v) VALUES (?, ?)
            """)

    def test_routing_key(self):
        """
        Simple code coverage to ensure routing_keys can be accessed
        """
        prepared = self._prepare_insert()
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind((1, None))
        self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')

    def test_empty_routing_key_indexes(self):
        """
        Ensure when routing_key_indexes are blank,
        the routing key should be None
        """
        prepared = self._prepare_insert()
        prepared.routing_key_indexes = None
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind((1, None))
        self.assertIsNone(bound.routing_key)

    def test_predefined_routing_key(self):
        """
        Basic test that ensures _set_routing_key()
        overrides the current routing key
        """
        prepared = self._prepare_insert()
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind((1, None))
        bound._set_routing_key('fake_key')
        self.assertEqual(bound.routing_key, 'fake_key')

    def test_multiple_routing_key_indexes(self):
        """
        Basic test that uses a fake routing_key_index
        """
        prepared = self._prepare_insert()
        # Fake two-component index so the composite routing-key path is used.
        prepared.routing_key_indexes = {0: {0: 0}, 1: {1: 1}}
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind((1, 2))
        # Pins the driver's current encoding of a two-component routing key.
        self.assertEqual(bound.routing_key, b'\x04\x00\x00\x00\x04\x00\x00\x00')

    def test_bound_keyspace(self):
        """
        Ensure that bound.keyspace works as expected
        """
        prepared = self._prepare_insert()
        self.assertIsInstance(prepared, PreparedStatement)

        bound = prepared.bind((1, 2))
        self.assertEqual(bound.keyspace, 'test3rf')

        # Without column metadata the keyspace is expected to be unknown.
        bound.prepared_statement.column_metadata = None
        self.assertIsNone(bound.keyspace)
class PrintStatementTests(unittest.TestCase):
    """
    Test that shows the format used when printing Statements
    """

    def test_simple_statement(self):
        """
        Highlight the format of printing SimpleStatements
        """
        ss = SimpleStatement('SELECT * FROM test3rf.test', consistency_level=ConsistencyLevel.ONE)
        self.assertEqual(str(ss),
                         '<SimpleStatement query="SELECT * FROM test3rf.test", consistency=ONE>')

    def test_prepared_statement(self):
        """
        Highlight the difference between Prepared and Bound statements
        """
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # Release the cluster's connections even if an assertion fails;
        # the original test never shut the cluster down.
        self.addCleanup(cluster.shutdown)
        session = cluster.connect()

        prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
        prepared.consistency_level = ConsistencyLevel.ONE
        self.assertEqual(str(prepared),
                         '<PreparedStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", consistency=ONE>')

        # A bound statement additionally shows its bound values.
        bound = prepared.bind((1, 2))
        self.assertEqual(str(bound),
                         '<BoundStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", values=(1, 2), consistency=ONE>')
class BatchStatementTests(unittest.TestCase):
    """
    Integration tests for BatchStatement with each kind of statement that
    BatchStatement.add() accepts. Every test inserts k = v = 0..9 into
    test3rf.test and verifies the rows with confirm_results().
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # Registered immediately so the cluster is shut down even if
        # connect() or the TRUNCATE below fails (tearDown would not run
        # in that case).
        self.addCleanup(self.cluster.shutdown)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

        # Start each test from an empty table so confirm_results() is exact.
        self.session.execute("TRUNCATE test3rf.test")

    def confirm_results(self):
        """Assert the table holds exactly the keys and values 0..9."""
        keys = set()
        values = set()
        results = self.session.execute("SELECT * FROM test3rf.test")
        for result in results:
            keys.add(result.k)
            values.add(result.v)

        self.assertEqual(set(range(10)), keys)
        self.assertEqual(set(range(10)), values)

    def test_string_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", (i, i))

        # Execute both synchronously and asynchronously; the inserts are
        # idempotent, so the final table contents are the same.
        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_simple_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_prepared_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_bound_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_no_parameters(self):
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"), ())

        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        # Passing non-empty parameters together with a BoundStatement must
        # be rejected. Note (1,) -- a one-element tuple; the original (1)
        # was just the int 1.
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()
class SerialConsistencyTests(unittest.TestCase):
    """
    Integration tests for conditional (IF ...) updates executed with a
    serial consistency level, and for rejecting non-serial levels.
    """

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Serial Consistency and conditional updates, "
                "currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # This class previously had no tearDown, so every test leaked a
        # cluster; register shutdown as soon as the cluster exists.
        self.addCleanup(self.cluster.shutdown)
        if PROTOCOL_VERSION < 3:
            self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def test_conditional_update(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        # v is 0, so the condition v=1 fails and nothing is applied.
        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        # The condition v=0 holds, so the update is applied.
        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_conditional_update_with_prepared_statements(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        # Serial consistency set on the PreparedStatement itself.
        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2")
        statement.serial_consistency_level = ConsistencyLevel.SERIAL
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        # Serial consistency set on the BoundStatement. Execute the bound
        # statement (the original executed `statement`, so the level set
        # on `bound` was never exercised).
        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0")
        bound = statement.bind(())
        bound.serial_consistency_level = ConsistencyLevel.SERIAL
        result = self.session.execute(bound)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_bad_consistency_level(self):
        # Non-serial levels must be rejected both via attribute assignment
        # and via the constructor keyword.
        statement = SimpleStatement("foo")
        self.assertRaises(ValueError, setattr, statement, 'serial_consistency_level', ConsistencyLevel.ONE)
        self.assertRaises(ValueError, SimpleStatement, 'foo', serial_consistency_level=ConsistencyLevel.ONE)