Conflicts: cassandra/cluster.py cassandra/encoder.py cassandra/marshal.py cassandra/pool.py setup.py tests/integration/long/test_large_data.py tests/integration/long/utils.py tests/integration/standard/test_metadata.py tests/integration/standard/test_prepared_statements.py tests/unit/io/test_asyncorereactor.py tests/unit/test_connection.py tests/unit/test_types.py
		
			
				
	
	
		
			361 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			361 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright 2013-2014 DataStax, Inc.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
# http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
try:
 | 
						|
    import unittest2 as unittest
 | 
						|
except ImportError:
 | 
						|
    import unittest # noqa
 | 
						|
 | 
						|
from cassandra import ConsistencyLevel
 | 
						|
from cassandra.query import (PreparedStatement, BoundStatement, ValueSequence,
 | 
						|
                             SimpleStatement, BatchStatement, BatchType,
 | 
						|
                             dict_factory)
 | 
						|
from cassandra.cluster import Cluster
 | 
						|
from cassandra.policies import HostDistance
 | 
						|
 | 
						|
from tests.integration import PROTOCOL_VERSION
 | 
						|
 | 
						|
 | 
						|
class QueryTest(unittest.TestCase):
    """Basic prepare/execute and tracing checks against a live cluster."""

    def _connect(self):
        """
        Create a Cluster for the current test, connect, and return the session.

        ``cluster.shutdown`` is registered with ``addCleanup`` before connecting,
        so the cluster is released even if connect() or the test body fails
        (previously each test leaked its Cluster).
        """
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.addCleanup(cluster.shutdown)
        return cluster.connect()

    def test_query(self):
        """
        Prepare, bind and execute an INSERT, then check the bound routing key
        """

        session = self._connect()

        prepared = session.prepare(
            """
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)

        self.assertIsInstance(prepared, PreparedStatement)
        bound = prepared.bind((1, None))
        self.assertIsInstance(bound, BoundStatement)
        self.assertEqual(2, len(bound.values))
        session.execute(bound)
        self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')

    def test_value_sequence(self):
        """
        Test the output of ValueSequences()
        """

        my_user_ids = ('alice', 'bob', 'charles')
        self.assertEqual(str(ValueSequence(my_user_ids)), "( 'alice' , 'bob' , 'charles' )")

    def test_trace_prints_okay(self):
        """
        Code coverage to ensure trace prints to string without error
        """

        session = self._connect()

        query = "SELECT * FROM system.local"
        statement = SimpleStatement(query)
        session.execute(statement, trace=True)

        # Ensure this does not throw an exception
        str(statement.trace)
        for event in statement.trace.events:
            str(event)

    def test_trace_ignores_row_factory(self):
        """
        Ensure trace output still works when the session's row_factory is dict_factory
        """

        session = self._connect()
        session.row_factory = dict_factory

        query = "SELECT * FROM system.local"
        statement = SimpleStatement(query)
        session.execute(statement, trace=True)

        # Ensure this does not throw an exception
        str(statement.trace)
        for event in statement.trace.events:
            str(event)
 | 
						|
 | 
						|
 | 
						|
class PreparedStatementTests(unittest.TestCase):
    """Routing-key and keyspace behaviour of prepared and bound statements."""

    def setUp(self):
        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # Registered before connecting so the cluster is shut down even if
        # connect() or the test itself fails (previously every test leaked it).
        self.addCleanup(self.cluster.shutdown)
        self.session = self.cluster.connect()

    def _prepare_insert(self):
        """Prepare the two-placeholder INSERT used by every test and check its type."""
        prepared = self.session.prepare(
            """
            INSERT INTO test3rf.test (k, v) VALUES  (?, ?)
            """)
        self.assertIsInstance(prepared, PreparedStatement)
        return prepared

    def test_routing_key(self):
        """
        Simple code coverage to ensure routing_keys can be accessed
        """

        prepared = self._prepare_insert()
        bound = prepared.bind((1, None))
        self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01')

    def test_empty_routing_key_indexes(self):
        """
        Ensure when routing_key_indexes are blank,
        the routing key should be None
        """

        prepared = self._prepare_insert()
        prepared.routing_key_indexes = None

        bound = prepared.bind((1, None))
        self.assertIsNone(bound.routing_key)

    def test_predefined_routing_key(self):
        """
        Basic test that ensures _set_routing_key()
        overrides the current routing key
        """

        prepared = self._prepare_insert()
        bound = prepared.bind((1, None))
        bound._set_routing_key('fake_key')
        self.assertEqual(bound.routing_key, 'fake_key')

    def test_multiple_routing_key_indexes(self):
        """
        Basic test that uses a fake routing_key_index
        """

        prepared = self._prepare_insert()
        prepared.routing_key_indexes = {0: {0: 0}, 1: {1: 1}}

        bound = prepared.bind((1, 2))
        self.assertEqual(bound.routing_key, b'\x04\x00\x00\x00\x04\x00\x00\x00')

    def test_bound_keyspace(self):
        """
        Ensure that bound.keyspace works as expected
        """

        prepared = self._prepare_insert()
        bound = prepared.bind((1, 2))
        self.assertEqual(bound.keyspace, 'test3rf')

        # Without column metadata the keyspace can no longer be derived.
        bound.prepared_statement.column_metadata = None
        self.assertIsNone(bound.keyspace)
 | 
						|
 | 
						|
 | 
						|
class PrintStatementTests(unittest.TestCase):
    """
    Test that shows the format used when printing Statements
    """

    def test_simple_statement(self):
        """
        Highlight the format of printing SimpleStatements
        """

        ss = SimpleStatement('SELECT * FROM test3rf.test')
        self.assertEqual(str(ss),
                         '<SimpleStatement query="SELECT * FROM test3rf.test", consistency=ONE>')

    def test_prepared_statement(self):
        """
        Highlight the difference between Prepared and Bound statements
        """

        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # Release the cluster even if an assertion below fails
        # (previously it was never shut down).
        self.addCleanup(cluster.shutdown)
        session = cluster.connect()

        prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')

        self.assertEqual(str(prepared),
                         '<PreparedStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", consistency=ONE>')

        bound = prepared.bind((1, 2))
        self.assertEqual(str(bound),
                         '<BoundStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", values=(1, 2), consistency=ONE>')
 | 
						|
 | 
						|
 | 
						|
class BatchStatementTests(unittest.TestCase):
    """BATCH execution using each statement form that BatchStatement.add accepts."""

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for BATCH operations, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # addCleanup (unlike tearDown) also runs when setUp itself fails, so a
        # failing connect() or TRUNCATE no longer leaks the cluster.
        self.addCleanup(self.cluster.shutdown)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

        self.session.execute("TRUNCATE test3rf.test")

    def confirm_results(self):
        """Assert that test3rf.test contains exactly keys 0..9 and values 0..9."""
        # Materialise once so both sets are built from the same rows.
        rows = list(self.session.execute("SELECT * FROM test3rf.test"))
        keys = {row.k for row in rows}
        values = {row.v for row in rows}

        self.assertEqual(set(range(10)), keys)
        self.assertEqual(set(range(10)), values)

    def test_string_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", (i, i))

        # Run the same batch through both the sync and async paths; the
        # inserts are idempotent, so the final table contents are unchanged.
        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_simple_statements(self):
        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)"), (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_prepared_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared, (i, i))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_bound_statements(self):
        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

        batch = BatchStatement(BatchType.LOGGED)
        for i in range(10):
            batch.add(prepared.bind((i, i)))

        self.session.execute(batch)
        self.session.execute_async(batch).result()
        self.confirm_results()

    def test_no_parameters(self):
        batch = BatchStatement(BatchType.LOGGED)
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (1, 1)", ())
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)"))
        batch.add(SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (3, 3)"), ())

        prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (4, 4)")
        batch.add(prepared)
        batch.add(prepared, ())
        batch.add(prepared.bind([]))
        batch.add(prepared.bind([]), ())

        batch.add("INSERT INTO test3rf.test (k, v) VALUES (5, 5)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (6, 6)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (7, 7)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ())
        batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ())

        # Non-empty parameters alongside an already-bound statement must be
        # rejected. `(1,)` is the intended single-element tuple; the original
        # `(1)` was just the int 1.
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1,))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2))
        self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3))

        self.session.execute(batch)
        self.confirm_results()
 | 
						|
 | 
						|
 | 
						|
class SerialConsistencyTests(unittest.TestCase):
    """Conditional (IF ...) updates executed with a serial consistency level."""

    def setUp(self):
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for Serial Consistency, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        # This class previously never shut its cluster down; addCleanup also
        # covers failures later in setUp.
        self.addCleanup(self.cluster.shutdown)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()

    def test_conditional_update(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        # The condition is false (v is 0), so the update must not be applied.
        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        # The condition holds now, so the update must be applied.
        statement = SimpleStatement(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0",
            serial_consistency_level=ConsistencyLevel.SERIAL)
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_conditional_update_with_prepared_statements(self):
        self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")

        # Serial consistency set on the PreparedStatement itself.
        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=2")
        statement.serial_consistency_level = ConsistencyLevel.SERIAL
        result = self.session.execute(statement)
        self.assertEqual(1, len(result))
        self.assertFalse(result[0].applied)

        # Serial consistency set on a BoundStatement. Execute `bound`, not
        # `statement`; the original executed `statement`, so the level set on
        # `bound` was never exercised.
        statement = self.session.prepare(
            "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0")
        bound = statement.bind(())
        bound.serial_consistency_level = ConsistencyLevel.SERIAL
        result = self.session.execute(bound)
        self.assertEqual(1, len(result))
        self.assertTrue(result[0].applied)

    def test_bad_consistency_level(self):
        # A non-serial level (ONE) must be rejected both via attribute
        # assignment and via the constructor keyword.
        statement = SimpleStatement("foo")
        self.assertRaises(ValueError, setattr, statement, 'serial_consistency_level', ConsistencyLevel.ONE)
        self.assertRaises(ValueError, SimpleStatement, 'foo', serial_consistency_level=ConsistencyLevel.ONE)
 |