speculative execution policies

PYTHON-218
This commit is contained in:
Adam Holmberg
2016-08-29 16:25:13 -05:00
parent 3788378906
commit df93b8403a
5 changed files with 205 additions and 93 deletions

View File

@@ -62,7 +62,8 @@ from cassandra.protocol import (QueryMessage, ResultMessage,
from cassandra.metadata import Metadata, protect_name, murmur3
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance,
RetryPolicy, IdentityTranslator)
RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
NoSpeculativeExecutionPolicy)
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
HostConnectionPool, HostConnection,
NoConnectionsAvailable)
@@ -240,15 +241,23 @@ class ExecutionProfile(object):
- :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict
"""
speculative_execution_policy = None
"""
An instance of :class:`.policies.SpeculativeExecutionPolicy`
Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified
"""
def __init__(self, load_balancing_policy=None, retry_policy=None,
consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None,
request_timeout=10.0, row_factory=named_tuple_factory):
request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None):
self.load_balancing_policy = load_balancing_policy or default_lbp_factory()
self.retry_policy = retry_policy or RetryPolicy()
self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level
self.request_timeout = request_timeout
self.row_factory = row_factory
self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy()
class ProfileManager(object):
@@ -2058,6 +2067,7 @@ class Session(object):
retry_policy = query.retry_policy or self.cluster.default_retry_policy
row_factory = self.row_factory
load_balancing_policy = self.cluster.load_balancing_policy
spec_exec_policy = None
else:
execution_profile = self._get_execution_profile(execution_profile)
@@ -2070,6 +2080,8 @@ class Session(object):
retry_policy = query.retry_policy or execution_profile.retry_policy
row_factory = execution_profile.row_factory
load_balancing_policy = execution_profile.load_balancing_policy
spec_exec_policy = execution_profile.speculative_execution_policy
fetch_size = query.fetch_size
if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2:
@@ -2077,8 +2089,9 @@ class Session(object):
elif self._protocol_version == 1:
fetch_size = None
start_time = time.time()
if self._protocol_version >= 3 and self.use_client_timestamp:
timestamp = int(time.time() * 1e6)
timestamp = int(start_time * 1e6)
else:
timestamp = None
@@ -2112,9 +2125,11 @@ class Session(object):
message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version
message.paging_state = paging_state
spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None
return ResponseFuture(
self, message, query, timeout, metrics=self._metrics,
prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, load_balancer=load_balancing_policy)
prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory,
load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan)
def _get_execution_profile(self, ep):
profiles = self.cluster.profile_manager.profiles
@@ -3172,11 +3187,11 @@ class _Scheduler(Thread):
exc_info=exc)
def refresh_schema_and_set_result(control_conn, response_future, **kwargs):
def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs):
try:
log.debug("Refreshing schema in response to schema change. "
"%s", kwargs)
response_future.is_schema_agreed = control_conn._refresh_schema(response_future._connection, **kwargs)
response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs)
except Exception:
log.exception("Exception refreshing schema in response to schema change:")
response_future.session.submit(control_conn.refresh_schema, **kwargs)
@@ -3214,6 +3229,14 @@ class ResponseFuture(object):
Size of the request message sent
"""
coordinator_host = None
"""
The host an actual result from (
"""
"""
"""
session = None
row_factory = None
message = None
@@ -3230,7 +3253,6 @@ class ResponseFuture(object):
_callbacks = None
_errbacks = None
_current_host = None
_current_pool = None
_connection = None
_query_retries = 0
_start_time = None
@@ -3240,11 +3262,12 @@ class ResponseFuture(object):
_warnings = None
_timer = None
_protocol_handler = ProtocolHandler
_spec_execution_plan = NoSpeculativeExecutionPlan()
_warned_timeout = False
def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None,
retry_policy=RetryPolicy(), row_factory=None, load_balancer=None):
retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, speculative_execution_plan=None):
self.session = session
# TODO: normalize handling of retry policy and row factory
self.row_factory = row_factory or session.row_factory
@@ -3252,21 +3275,29 @@ class ResponseFuture(object):
self.message = message
self.query = query
self.timeout = timeout
self._time_remaining = timeout
self._retry_policy = retry_policy
self._metrics = metrics
self.prepared_statement = prepared_statement
self._callback_lock = Lock()
if metrics is not None:
self._start_time = time.time()
self._start_time = start_time or time.time()
self._make_query_plan()
self._event = Event()
self._errors = {}
self._callbacks = []
self._errbacks = []
self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan
self._queried_hosts = []
def _start_timer(self):
if self.timeout is not None:
self._timer = self.session.cluster.connection_class.create_timer(self.timeout, self._on_timeout)
if self._timer is None:
spec_delay = self._spec_execution_plan.next_execution(self._current_host)
if spec_delay >= 0:
if self._time_remaining is None or self._time_remaining > spec_delay:
self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute)
return
if self._time_remaining is not None:
self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout)
def _cancel_timer(self):
if self._timer:
@@ -3284,17 +3315,29 @@ class ResponseFuture(object):
self._set_final_exception(OperationTimedOut(errors, self._current_host))
def _on_speculative_execute(self):
self._timer = None
if not self._event.is_set():
if self._time_remaining is not None:
elapsed = time.time() - self._start_time
self._time_remaining -= elapsed
if self._time_remaining <= 0:
self._on_timeout()
return
if not self.send_request(error_no_hosts=False):
self._start_timer()
def _make_query_plan(self):
# convert the list/generator/etc to an iterator so that subsequent
# calls to send_request (which retries may do) will resume where
# they last left off
self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query))
def send_request(self):
def send_request(self, error_no_hosts=True):
""" Internal """
# query_plan is an iterator, so this will resume where we last left
# off if send_request() is called multiple times
start = time.time()
for host in self.query_plan:
req_id = self._query(host)
if req_id is not None:
@@ -3303,23 +3346,21 @@ class ResponseFuture(object):
# timer is only started here, after we have at least one message queued
# this is done to avoid overrun of timers with unfettered client requests
# in the case of full disconnect, where no hosts will be available
if self._timer is None:
self._start_timer()
return
if self.timeout is not None and time.time() - start > self.timeout:
self._start_timer()
return True
if self.timeout is not None and time.time() - self._start_time > self.timeout:
self._on_timeout()
return
return True
self._set_final_exception(NoHostAvailable(
"Unable to complete the operation against any hosts", self._errors))
if error_no_hosts:
self._set_final_exception(NoHostAvailable(
"Unable to complete the operation against any hosts", self._errors))
return False
def _query(self, host, message=None, cb=None):
if message is None:
message = self.message
if cb is None:
cb = self._set_result
pool = self.session._pools.get(host)
if not pool:
self._errors[host] = ConnectionException("Host has been marked down or removed")
@@ -3329,7 +3370,6 @@ class ResponseFuture(object):
return None
self._current_host = host
self._current_pool = pool
connection = None
try:
@@ -3337,6 +3377,10 @@ class ResponseFuture(object):
connection, request_id = pool.borrow_connection(timeout=2.0)
self._connection = connection
result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []
if cb is None:
cb = partial(self._set_result, host, connection, pool)
self.request_encoded_size = connection.send_msg(message, request_id, cb=cb,
encoder=self._protocol_handler.encode_message,
decoder=self._protocol_handler.decode_message,
@@ -3423,17 +3467,18 @@ class ResponseFuture(object):
self._timer = None # clear cancelled timer; new one will be set when request is queued
self.send_request()
def _reprepare(self, prepare_message):
cb = partial(self.session.submit, self._execute_after_prepare)
request_id = self._query(self._current_host, prepare_message, cb=cb)
def _reprepare(self, prepare_message, host, connection, pool):
cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool)
request_id = self._query(host, prepare_message, cb=cb)
if request_id is None:
# try to submit the original prepared statement on some other host
self.send_request()
def _set_result(self, response):
def _set_result(self, host, connection, pool, response):
try:
if self._current_pool and self._connection:
self._current_pool.return_connection(self._connection)
self.coordinator_host = host
if pool:
pool.return_connection(connection)
trace_id = getattr(response, 'trace_id', None)
if trace_id:
@@ -3464,7 +3509,7 @@ class ResponseFuture(object):
self.session.submit(
refresh_schema_and_set_result,
self.session.cluster.control_connection,
self, **response.results)
self, connection, **response.results)
else:
results = getattr(response, 'results', None)
if results is not None and response.kind == RESULT_KIND_ROWS:
@@ -3495,14 +3540,14 @@ class ResponseFuture(object):
self._metrics.on_other_error()
# need to retry against a different host here
log.warning("Host %s is overloaded, retrying against a different "
"host", self._current_host)
self._retry(reuse_connection=False, consistency_level=None)
"host", host)
self._retry(reuse_connection=False, consistency_level=None, host=host)
return
elif isinstance(response, IsBootstrappingErrorMessage):
if self._metrics is not None:
self._metrics.on_other_error()
# need to retry against a different host here
self._retry(reuse_connection=False, consistency_level=None)
self._retry(reuse_connection=False, consistency_level=None, host=host)
return
elif isinstance(response, PreparedQueryNotFound):
if self.prepared_statement:
@@ -3536,11 +3581,11 @@ class ResponseFuture(object):
return
log.debug("Re-preparing unrecognized prepared statement against host %s: %s",
self._current_host, prepared_statement.query_string)
host, prepared_statement.query_string)
prepare_message = PrepareMessage(query=prepared_statement.query_string)
# since this might block, run on the executor to avoid hanging
# the event loop thread
self.session.submit(self._reprepare, prepare_message)
self.session.submit(self._reprepare, prepare_message, host, connection, pool)
return
else:
if hasattr(response, 'to_exception'):
@@ -3553,20 +3598,20 @@ class ResponseFuture(object):
if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST):
self._query_retries += 1
reuse = retry_type == RetryPolicy.RETRY
self._retry(reuse_connection=reuse, consistency_level=consistency)
self._retry(reuse, consistency, host)
elif retry_type is RetryPolicy.RETHROW:
self._set_final_exception(response.to_exception())
else: # IGNORE
if self._metrics is not None:
self._metrics.on_ignore()
self._set_final_result(None)
self._errors[self._current_host] = response.to_exception()
self._errors[host] = response.to_exception()
elif isinstance(response, ConnectionException):
if self._metrics is not None:
self._metrics.on_connection_error()
if not isinstance(response, ConnectionShutdown):
self._connection.defunct(response)
self._retry(reuse_connection=False, consistency_level=None)
self._retry(reuse_connection=False, consistency_level=None, host=host)
elif isinstance(response, Exception):
if hasattr(response, 'to_exception'):
self._set_final_exception(response.to_exception())
@@ -3575,7 +3620,7 @@ class ResponseFuture(object):
else:
# we got some other kind of response message
msg = "Got unexpected message: %r" % (response,)
exc = ConnectionException(msg, self._current_host)
exc = ConnectionException(msg, host)
self._connection.defunct(exc)
self._set_final_exception(exc)
except Exception as exc:
@@ -3590,13 +3635,14 @@ class ResponseFuture(object):
self._set_final_exception(ConnectionException(
"Failed to set keyspace on all hosts: %s" % (errors,)))
def _execute_after_prepare(self, response):
def _execute_after_prepare(self, host, connection, pool, response):
"""
Handle the response to our attempt to prepare a statement.
If it succeeded, run the original query again against the same host.
"""
if self._current_pool and self._connection:
self._current_pool.return_connection(self._connection)
"AFTER PREPARE"
if pool:
pool.return_connection(connection)
if self._final_exception:
return
@@ -3609,14 +3655,14 @@ class ResponseFuture(object):
# use self._query to re-use the same host and
# at the same time properly borrow the connection
request_id = self._query(self._current_host)
request_id = self._query(host)
if request_id is None:
# this host errored out, move on to the next
self.send_request()
else:
self._set_final_exception(ConnectionException(
"Got unexpected response when preparing statement "
"on host %s: %s" % (self._current_host, response)))
"on host %s: %s" % (host, response)))
elif isinstance(response, ErrorMessage):
if hasattr(response, 'to_exception'):
self._set_final_exception(response.to_exception())
@@ -3624,14 +3670,14 @@ class ResponseFuture(object):
self._set_final_exception(response)
elif isinstance(response, ConnectionException):
log.debug("Connection error when preparing statement on host %s: %s",
self._current_host, response)
host, response)
# try again on a different host, preparing again if necessary
self._errors[self._current_host] = response
self._errors[host] = response
self.send_request()
else:
self._set_final_exception(ConnectionException(
"Got unexpected response type when preparing "
"statement on host %s: %s" % (self._current_host, response)))
"statement on host %s: %s" % (host, response)))
def _set_final_result(self, response):
self._cancel_timer()
@@ -3661,7 +3707,7 @@ class ResponseFuture(object):
fn, args, kwargs = errback
fn(response, *args, **kwargs)
def _retry(self, reuse_connection, consistency_level):
def _retry(self, reuse_connection, consistency_level, host):
if self._final_exception:
# the connection probably broke while we were waiting
# to retry the operation
@@ -3673,15 +3719,15 @@ class ResponseFuture(object):
self.message.consistency_level = consistency_level
# don't retry on the event loop thread
self.session.submit(self._retry_task, reuse_connection)
self.session.submit(self._retry_task, reuse_connection, host)
def _retry_task(self, reuse_connection):
def _retry_task(self, reuse_connection, host):
if self._final_exception:
# the connection probably broke while we were waiting
# to retry the operation
return
if reuse_connection and self._query(self._current_host) is not None:
if reuse_connection and self._query(host) is not None:
return
# otherwise, move onto another host
@@ -3852,8 +3898,8 @@ class ResponseFuture(object):
def __str__(self):
result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result
return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s host=%s>" \
% (self.query, self._req_id, result, self._final_exception, self._current_host)
return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s coordinator_host=%s>" \
% (self.query, self._req_id, result, self._final_exception, self.coordinator_host)
__repr__ = __str__

View File

@@ -875,7 +875,7 @@ class AddressTranslator(object):
"""
Accepts the node ip address, and returns a translated address to be used connecting to this node.
"""
raise NotImplementedError
raise NotImplementedError()
class IdentityTranslator(AddressTranslator):
@@ -904,3 +904,57 @@ class EC2MultiRegionTranslator(AddressTranslator):
except Exception:
pass
return addr
class SpeculativeExecutionPolicy(object):
"""
Interface for specifying speculative execution plans
"""
def new_plan(self, keyspace, statement):
"""
Returns
:param keyspace:
:param statement:
:return:
"""
raise NotImplementedError()
class SpeculativeExecutionPlan(object):
def next_execution(self, host):
raise NotImplementedError()
class NoSpeculativeExecutionPlan(SpeculativeExecutionPlan):
def next_execution(self, host):
return -1
class NoSpeculativeExecutionPolicy(SpeculativeExecutionPolicy):
def new_plan(self, keyspace, statement):
return self.NoSpeculativeExecutionPlan()
class ConstantSpeculativeExecutionPolicy(SpeculativeExecutionPolicy):
def __init__(self, delay, max_attempts):
self.delay = delay
self.max_attempts = max_attempts
class ConstantSpeculativeExecutionPlan(SpeculativeExecutionPlan):
def __init__(self, delay, max_attempts):
self.delay = delay
self.remaining = max_attempts
def next_execution(self, host):
if self.remaining > 0:
self.remaining -= 1
return self.delay
else:
return -1
def new_plan(self, keyspace, statement):
return self.ConstantSpeculativeExecutionPlan(self.delay, self.max_attempts)

View File

@@ -215,11 +215,17 @@ class Statement(object):
.. versionadded:: 2.6.0
"""
is_idempotent = False
"""
Flag indicating whether this statement is safe to run multiple times in speculative execution.
"""
_serial_consistency_level = None
_routing_key = None
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None):
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
is_idempotent=False):
if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors
raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
self.retry_policy = retry_policy
@@ -234,6 +240,7 @@ class Statement(object):
self.keyspace = keyspace
if custom_payload is not None:
self.custom_payload = custom_payload
self.is_idempotent = is_idempotent
def _key_parts_packed(self, parts):
for p in parts:
@@ -328,7 +335,7 @@ class SimpleStatement(Statement):
def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
custom_payload=None):
custom_payload=None, is_idempotent=False):
"""
`query_string` should be a literal CQL statement with the exception
of parameter placeholders that will be filled through the
@@ -337,7 +344,7 @@ class SimpleStatement(Statement):
See :class:`Statement` attributes for a description of the other parameters.
"""
Statement.__init__(self, retry_policy, consistency_level, routing_key,
serial_consistency_level, fetch_size, keyspace, custom_payload)
serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent)
self._query_string = query_string
@property
@@ -383,6 +390,7 @@ class PreparedStatement(object):
self.keyspace = keyspace
self.protocol_version = protocol_version
self.result_metadata = result_metadata
self.is_idempotent = False
@classmethod
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
@@ -465,6 +473,7 @@ class BoundStatement(Statement):
self.serial_consistency_level = prepared_statement.serial_consistency_level
self.fetch_size = prepared_statement.fetch_size
self.custom_payload = prepared_statement.custom_payload
self.is_idempotent = prepared_statement.is_idempotent
self.values = []
meta = prepared_statement.column_metadata

View File

@@ -69,7 +69,7 @@ class ResponseFutureTests(unittest.TestCase):
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}]))
result = rf.result()
self.assertEqual(result, [{'col': 'val'}])
@@ -81,7 +81,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session)
rf.send_request()
rf._set_result(object())
rf._set_result(None, None, None, object())
self.assertRaises(ConnectionException, rf.result)
def test_set_keyspace_result(self):
@@ -92,7 +92,7 @@ class ResponseFutureTests(unittest.TestCase):
result = Mock(spec=ResultMessage,
kind=RESULT_KIND_SET_KEYSPACE,
results="keyspace1")
rf._set_result(result)
rf._set_result(None, None, None, result)
rf._set_keyspace_completed({})
self.assertFalse(rf.result())
@@ -106,15 +106,16 @@ class ResponseFutureTests(unittest.TestCase):
result = Mock(spec=ResultMessage,
kind=RESULT_KIND_SCHEMA_CHANGE,
results=event_results)
rf._set_result(result)
session.submit.assert_called_once_with(ANY, ANY, rf, **event_results)
connection = Mock()
rf._set_result(None, connection, None, result)
session.submit.assert_called_once_with(ANY, ANY, rf, connection, **event_results)
def test_other_result_message_kind(self):
session = self.make_session()
rf = self.make_response_future(session)
rf.send_request()
result = [1, 2, 3]
rf._set_result(Mock(spec=ResultMessage, kind=999, results=result))
rf._set_result(None, None, None, Mock(spec=ResultMessage, kind=999, results=result))
self.assertListEqual(list(rf.result()), result)
def test_read_timeout_error_message(self):
@@ -128,7 +129,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request()
result = Mock(spec=ReadTimeoutErrorMessage, info={})
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertRaises(Exception, rf.result)
@@ -143,7 +144,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request()
result = Mock(spec=WriteTimeoutErrorMessage, info={})
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertRaises(Exception, rf.result)
def test_unavailable_error_message(self):
@@ -157,7 +158,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request()
result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertRaises(Exception, rf.result)
def test_retry_policy_says_ignore(self):
@@ -171,7 +172,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.send_request()
result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertFalse(rf.result())
def test_retry_policy_says_retry(self):
@@ -195,20 +196,21 @@ class ResponseFutureTests(unittest.TestCase):
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result)
host = Mock()
rf._set_result(host, None, None, result)
session.submit.assert_called_once_with(rf._retry_task, True)
session.submit.assert_called_once_with(rf._retry_task, True, host)
self.assertEqual(1, rf._query_retries)
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 2)
# simulate the executor running this
rf._retry_task(True)
rf._retry_task(True, host)
# it should try again with the same host since this was
# an UnavailableException
rf.session._pools.get.assert_called_with('ip1')
rf.session._pools.get.assert_called_with(host)
pool.borrow_connection.assert_called_with(timeout=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
@@ -229,16 +231,17 @@ class ResponseFutureTests(unittest.TestCase):
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
result = Mock(spec=OverloadedErrorMessage, info={})
rf._set_result(result)
host = Mock()
rf._set_result(host, None, None, result)
session.submit.assert_called_once_with(rf._retry_task, False)
session.submit.assert_called_once_with(rf._retry_task, False, host)
# query_retries does not get incremented for Overloaded/Bootstrapping errors
self.assertEqual(0, rf._query_retries)
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 2)
# simulate the executor running this
rf._retry_task(False)
rf._retry_task(False, host)
# it should try with a different host
rf.session._pools.get.assert_called_with('ip2')
@@ -259,21 +262,22 @@ class ResponseFutureTests(unittest.TestCase):
rf.session._pools.get.assert_called_once_with('ip1')
result = Mock(spec=IsBootstrappingErrorMessage, info={})
rf._set_result(result)
host = Mock()
rf._set_result(host, None, None, result)
# simulate the executor running this
session.submit.assert_called_once_with(rf._retry_task, False)
rf._retry_task(False)
session.submit.assert_called_once_with(rf._retry_task, False, host)
rf._retry_task(False, host)
# it should try with a different host
rf.session._pools.get.assert_called_with('ip2')
result = Mock(spec=IsBootstrappingErrorMessage, info={})
rf._set_result(result)
rf._set_result(host, None, None, result)
# simulate the executor running this
session.submit.assert_called_with(rf._retry_task, False)
rf._retry_task(False)
session.submit.assert_called_with(rf._retry_task, False, host)
rf._retry_task(False, host)
self.assertRaises(NoHostAvailable, rf.result)
@@ -295,7 +299,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session)
rf.send_request()
rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}]))
result = rf.result()
self.assertEqual(result, [{'col': 'val'}])
@@ -319,7 +323,7 @@ class ResponseFutureTests(unittest.TestCase):
rf = self.make_response_future(session)
rf.send_request()
rf._set_result(self.make_mock_response([{'col': 'val'}]))
rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}]))
self.assertEqual(rf.result(), [{'col': 'val'}])
# make sure the exception is recorded correctly
@@ -336,7 +340,7 @@ class ResponseFutureTests(unittest.TestCase):
kwargs = {'one': 1, 'two': 2}
rf.add_callback(callback, arg, **kwargs)
rf._set_result(self.make_mock_response(expected_result))
rf._set_result(None, None, None, self.make_mock_response(expected_result))
result = rf.result()
self.assertEqual(result, expected_result)
@@ -363,7 +367,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.add_errback(self.assertIsInstance, Exception)
result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertRaises(Exception, rf.result)
# this should get called immediately now that the error is set
@@ -385,7 +389,7 @@ class ResponseFutureTests(unittest.TestCase):
kwargs2 = {'three': 3, 'four': 4}
rf.add_callback(callback2, arg2, **kwargs2)
rf._set_result(self.make_mock_response(expected_result))
rf._set_result(None, None, None, self.make_mock_response(expected_result))
result = rf.result()
self.assertEqual(result, expected_result)
@@ -420,7 +424,7 @@ class ResponseFutureTests(unittest.TestCase):
expected_exception = Unavailable("message", 1, 2, 3)
result = Mock(spec=UnavailableErrorMessage, info={'something': 'here'})
result.to_exception.return_value = expected_exception
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertRaises(Exception, rf.result)
callback.assert_called_once_with(expected_exception, arg, **kwargs)
@@ -442,7 +446,7 @@ class ResponseFutureTests(unittest.TestCase):
errback=self.assertIsInstance, errback_args=(Exception,))
result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertRaises(Exception, rf.result)
# test callback
@@ -457,7 +461,7 @@ class ResponseFutureTests(unittest.TestCase):
callback=callback, callback_args=(arg,), callback_kwargs=kwargs,
errback=self.assertIsInstance, errback_args=(Exception,))
rf._set_result(self.make_mock_response(expected_result))
rf._set_result(None, None, None, self.make_mock_response(expected_result))
self.assertEqual(rf.result(), expected_result)
callback.assert_called_once_with(expected_result, arg, **kwargs)
@@ -478,7 +482,7 @@ class ResponseFutureTests(unittest.TestCase):
rf._connection.keyspace = "FooKeyspace"
result = Mock(spec=PreparedQueryNotFound, info='a' * 16)
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertTrue(session.submit.call_args)
args, kwargs = session.submit.call_args
@@ -502,5 +506,5 @@ class ResponseFutureTests(unittest.TestCase):
rf._connection.keyspace = "BarKeyspace"
result = Mock(spec=PreparedQueryNotFound, info='a' * 16)
rf._set_result(result)
rf._set_result(None, None, None, result)
self.assertRaises(ValueError, rf.result)

View File

@@ -19,7 +19,6 @@ except ImportError:
import unittest # noqa
from mock import Mock, PropertyMock
import warnings
from cassandra.cluster import ResultSet