From 6055df24d010c243e222a515039f5777aff26b1c Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Wed, 25 Jun 2014 13:57:13 -0500 Subject: [PATCH] Fix cassandra.concurrent behavior with automatic paging Fixes PYTHON-81 --- CHANGELOG.rst | 2 ++ cassandra/cluster.py | 9 ++++++- cassandra/concurrent.py | 26 ++++++++++++++----- .../integration/standard/test_query_paging.py | 17 +++++++++++- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1a9aa1de..0280c043 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -32,6 +32,8 @@ Bug Fixes * Don't share prepared statement lock across Cluster instances * Format CompositeType and DynamicCompositeType columns correctly in CREATE TABLE statements. +* Fix cassandra.concurrent behavior when dealing with automatic paging + (PYTHON-81) 2.0.2 ===== diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 839aeded..847b58ca 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2704,6 +2704,11 @@ class ResponseFuture(object): self.add_callback(callback, *callback_args, **(callback_kwargs or {})) self.add_errback(errback, *errback_args, **(errback_kwargs or {})) + def clear_callbacks(self): + with self._callback_lock: + self._callback = None + self._errback = None + def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result return "" \ @@ -2744,6 +2749,8 @@ class PagedResult(object): .. versionadded: 2.0.0 """ + response_future = None + def __init__(self, response_future, initial_response): self.response_future = response_future self.current_response = iter(initial_response) @@ -2755,7 +2762,7 @@ class PagedResult(object): try: return next(self.current_response) except StopIteration: - if self.response_future._paging_state is None: + if not self.response_future.has_more_pages: raise self.response_future.start_fetching_next_page() diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py index 3c3a5208..0f646da8 100644 --- a/cassandra/concurrent.py +++ b/cassandra/concurrent.py @@ -16,9 +16,14 @@ import six import sys from itertools import count, cycle +import logging from six.moves import xrange from threading import Event +from cassandra.cluster import PagedResult + +log = logging.getLogger(__name__) + def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True): """ @@ -81,7 +86,7 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais num_finished = count(start=1) statements = enumerate(iter(statements_and_parameters)) for i in xrange(min(concurrency, len(statements_and_parameters))): - _execute_next(_sentinel, i, event, session, statements, results, num_finished, to_execute, first_error) + _execute_next(_sentinel, i, event, session, statements, results, None, num_finished, to_execute, first_error) event.wait() if first_error: @@ -113,7 +118,8 @@ def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs _sentinel = object() -def _handle_error(error, result_index, event, session, statements, results, num_finished, to_execute, first_error): +def _handle_error(error, result_index, event, session, statements, results, + future, num_finished, to_execute, first_error): if first_error is not None: first_error.append(error) event.set() @@ -129,9 +135,10 @@ def _handle_error(error, result_index, event, session, statements, results, num_ except StopIteration: return - args = (next_index, event, session, statements, results, num_finished, to_execute, first_error) try: - session.execute_async(statement, params).add_callbacks( + future = session.execute_async(statement, params) + args = (next_index, event, session, statements, results, future, num_finished, to_execute, first_error) + future.add_callbacks( callback=_execute_next, callback_args=args, errback=_handle_error, errback_args=args) except Exception as exc: @@ -149,8 +156,12 @@ def _handle_error(error, result_index, event, session, statements, results, num_ return -def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error): +def _execute_next(result, result_index, event, session, statements, results, + future, num_finished, to_execute, first_error): if result is not _sentinel: + if future.has_more_pages: + result = PagedResult(future, result) + future.clear_callbacks() results[result_index] = (True, result) finished = next(num_finished) if finished >= to_execute: @@ -162,9 +173,10 @@ def _execute_next(result, result_index, event, session, statements, results, num except StopIteration: return - args = (next_index, event, session, statements, results, num_finished, to_execute, first_error) try: - session.execute_async(statement, params).add_callbacks( + future = session.execute_async(statement, params) + args = (next_index, event, session, statements, results, future, num_finished, to_execute, first_error) + future.add_callbacks( callback=_execute_next, callback_args=args, errback=_handle_error, errback_args=args) except Exception as exc: diff --git a/tests/integration/standard/test_query_paging.py b/tests/integration/standard/test_query_paging.py index 3dce069c..09438abc 100644 --- a/tests/integration/standard/test_query_paging.py +++ b/tests/integration/standard/test_query_paging.py @@ -27,7 +27,7 @@ from six.moves import range from threading import Event from cassandra.cluster import Cluster -from cassandra.concurrent import execute_concurrent +from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args from cassandra.policies import HostDistance from cassandra.query import SimpleStatement @@ -266,3 +266,18 @@ class QueryPagingTests(unittest.TestCase): future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) event.wait() self.assertEquals(next(counter), 100) + + def test_concurrent_with_paging(self): + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + prepared = self.session.prepare("SELECT * FROM test3rf.test") + + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session.default_fetch_size = fetch_size + results = execute_concurrent_with_args(self.session, prepared, [None] * 10) + self.assertEquals(10, len(results)) + for (success, result) in results: + self.assertTrue(success) + self.assertEquals(100, len(list(result)))