diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index f5f189a7..b55bf17c 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -203,6 +203,7 @@ def setup_test_keyspace(): k int PRIMARY KEY, v int )''' session.execute(ddl) + except Exception: traceback.print_exc() raise diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index e493a700..cacda2a2 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + +from cassandra.concurrent import execute_concurrent + try: import unittest2 as unittest except ImportError: - import unittest # noqa + import unittest # noqa from cassandra import ConsistencyLevel from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement, @@ -352,3 +356,70 @@ class SerialConsistencyTests(unittest.TestCase): statement = SimpleStatement("foo") self.assertRaises(ValueError, setattr, statement, 'serial_consistency_level', ConsistencyLevel.ONE) self.assertRaises(ValueError, SimpleStatement, 'foo', serial_consistency_level=ConsistencyLevel.ONE) + + +class LightweightTransactionsTests(unittest.TestCase): + + def setUp(self): + """ + Test is skipped if run with cql version < 2 + + """ + if PROTOCOL_VERSION < 2: + raise unittest.SkipTest( + "Protocol 2.0+ is required for Lightweight transactions, currently testing against %r" + % (PROTOCOL_VERSION,)) + + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect() + + ddl = ''' + CREATE TABLE test3rf.lwt ( + k int PRIMARY KEY, + v int )''' + self.session.execute(ddl) + + def tearDown(self): + """ + Shutdown cluster + """ + self.session.execute("DROP TABLE test3rf.lwt") + self.cluster.shutdown() + + def test_no_connection_refused_on_timeout(self): + """ + Test for PYTHON-91 "Connection closed after LWT timeout" + Verifies that connection to the cluster is not shut down when timeout occurs. + Number of iterations can be specified with LWT_ITERATIONS environment variable. + Default value is 1000 + """ + insert_statement = self.session.prepare("INSERT INTO test3rf.lwt (k, v) VALUES (0, 0) IF NOT EXISTS") + delete_statement = self.session.prepare("DELETE FROM test3rf.lwt WHERE k = 0 IF EXISTS") + + iterations = int(os.getenv("LWT_ITERATIONS", 1000)) + print("Started test for %d iterations" % iterations) + + # Prepare series of parallel statements + statements_and_params = [] + for i in range(iterations): + statements_and_params.append((insert_statement, ())) + statements_and_params.append((delete_statement, ())) + + received_timeout = False + results = execute_concurrent(self.session, statements_and_params, raise_on_first_error=False) + for (success, result) in results: + if success: + continue + # In this case result is an exception + if type(result).__name__ == "NoHostAvailable": + self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) + break + if type(result).__name__ == "WriteTimeout": + print("Timeout: %s" % result.message) + received_timeout = True + continue + self.fail("Unexpected exception %s: %s" % (type(result).__name__, result.message)) + break + + # Make sure test passed + self.assertTrue(received_timeout)