Fix cassandra.concurrent behavior with automatic paging
Fixes PYTHON-81
This commit is contained in:
@@ -32,6 +32,8 @@ Bug Fixes
|
||||
* Don't share prepared statement lock across Cluster instances
|
||||
* Format CompositeType and DynamicCompositeType columns correctly in
|
||||
CREATE TABLE statements.
|
||||
* Fix cassandra.concurrent behavior when dealing with automatic paging
|
||||
(PYTHON-81)
|
||||
|
||||
2.0.2
|
||||
=====
|
||||
|
@@ -2704,6 +2704,11 @@ class ResponseFuture(object):
|
||||
self.add_callback(callback, *callback_args, **(callback_kwargs or {}))
|
||||
self.add_errback(errback, *errback_args, **(errback_kwargs or {}))
|
||||
|
||||
def clear_callbacks(self):
|
||||
with self._callback_lock:
|
||||
self._callback = None
|
||||
self._errback = None
|
||||
|
||||
def __str__(self):
|
||||
result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result
|
||||
return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s host=%s>" \
|
||||
@@ -2744,6 +2749,8 @@ class PagedResult(object):
|
||||
.. versionadded: 2.0.0
|
||||
"""
|
||||
|
||||
response_future = None
|
||||
|
||||
def __init__(self, response_future, initial_response):
|
||||
self.response_future = response_future
|
||||
self.current_response = iter(initial_response)
|
||||
@@ -2755,7 +2762,7 @@ class PagedResult(object):
|
||||
try:
|
||||
return next(self.current_response)
|
||||
except StopIteration:
|
||||
if self.response_future._paging_state is None:
|
||||
if not self.response_future.has_more_pages:
|
||||
raise
|
||||
|
||||
self.response_future.start_fetching_next_page()
|
||||
|
@@ -16,9 +16,14 @@ import six
|
||||
import sys
|
||||
|
||||
from itertools import count, cycle
|
||||
import logging
|
||||
from six.moves import xrange
|
||||
from threading import Event
|
||||
|
||||
from cassandra.cluster import PagedResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True):
|
||||
"""
|
||||
@@ -81,7 +86,7 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais
|
||||
num_finished = count(start=1)
|
||||
statements = enumerate(iter(statements_and_parameters))
|
||||
for i in xrange(min(concurrency, len(statements_and_parameters))):
|
||||
_execute_next(_sentinel, i, event, session, statements, results, num_finished, to_execute, first_error)
|
||||
_execute_next(_sentinel, i, event, session, statements, results, None, num_finished, to_execute, first_error)
|
||||
|
||||
event.wait()
|
||||
if first_error:
|
||||
@@ -113,7 +118,8 @@ def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs
|
||||
_sentinel = object()
|
||||
|
||||
|
||||
def _handle_error(error, result_index, event, session, statements, results, num_finished, to_execute, first_error):
|
||||
def _handle_error(error, result_index, event, session, statements, results,
|
||||
future, num_finished, to_execute, first_error):
|
||||
if first_error is not None:
|
||||
first_error.append(error)
|
||||
event.set()
|
||||
@@ -129,9 +135,10 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
|
||||
try:
|
||||
session.execute_async(statement, params).add_callbacks(
|
||||
future = session.execute_async(statement, params)
|
||||
args = (next_index, event, session, statements, results, future, num_finished, to_execute, first_error)
|
||||
future.add_callbacks(
|
||||
callback=_execute_next, callback_args=args,
|
||||
errback=_handle_error, errback_args=args)
|
||||
except Exception as exc:
|
||||
@@ -149,8 +156,12 @@ def _handle_error(error, result_index, event, session, statements, results, num_
|
||||
return
|
||||
|
||||
|
||||
def _execute_next(result, result_index, event, session, statements, results, num_finished, to_execute, first_error):
|
||||
def _execute_next(result, result_index, event, session, statements, results,
|
||||
future, num_finished, to_execute, first_error):
|
||||
if result is not _sentinel:
|
||||
if future.has_more_pages:
|
||||
result = PagedResult(future, result)
|
||||
future.clear_callbacks()
|
||||
results[result_index] = (True, result)
|
||||
finished = next(num_finished)
|
||||
if finished >= to_execute:
|
||||
@@ -162,9 +173,10 @@ def _execute_next(result, result_index, event, session, statements, results, num
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
args = (next_index, event, session, statements, results, num_finished, to_execute, first_error)
|
||||
try:
|
||||
session.execute_async(statement, params).add_callbacks(
|
||||
future = session.execute_async(statement, params)
|
||||
args = (next_index, event, session, statements, results, future, num_finished, to_execute, first_error)
|
||||
future.add_callbacks(
|
||||
callback=_execute_next, callback_args=args,
|
||||
errback=_handle_error, errback_args=args)
|
||||
except Exception as exc:
|
||||
|
@@ -27,7 +27,7 @@ from six.moves import range
|
||||
from threading import Event
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.concurrent import execute_concurrent
|
||||
from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args
|
||||
from cassandra.policies import HostDistance
|
||||
from cassandra.query import SimpleStatement
|
||||
|
||||
@@ -266,3 +266,18 @@ class QueryPagingTests(unittest.TestCase):
|
||||
future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error)
|
||||
event.wait()
|
||||
self.assertEquals(next(counter), 100)
|
||||
|
||||
def test_concurrent_with_paging(self):
|
||||
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
|
||||
[(i, ) for i in range(100)])
|
||||
execute_concurrent(self.session, list(statements_and_params))
|
||||
|
||||
prepared = self.session.prepare("SELECT * FROM test3rf.test")
|
||||
|
||||
for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000):
|
||||
self.session.default_fetch_size = fetch_size
|
||||
results = execute_concurrent_with_args(self.session, prepared, [None] * 10)
|
||||
self.assertEquals(10, len(results))
|
||||
for (success, result) in results:
|
||||
self.assertTrue(success)
|
||||
self.assertEquals(100, len(list(result)))
|
||||
|
Reference in New Issue
Block a user