diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 95aff072..75304a9b 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -62,7 +62,8 @@ from cassandra.protocol import (QueryMessage, ResultMessage, from cassandra.metadata import Metadata, protect_name, murmur3 from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, - RetryPolicy, IdentityTranslator) + RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, + NoSpeculativeExecutionPolicy) from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, HostConnectionPool, HostConnection, NoConnectionsAvailable) @@ -240,15 +241,23 @@ class ExecutionProfile(object): - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ + speculative_execution_policy = None + """ + An instance of :class:`.policies.SpeculativeExecutionPolicy` + + Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified + """ + def __init__(self, load_balancing_policy=None, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, - request_timeout=10.0, row_factory=named_tuple_factory): + request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None): self.load_balancing_policy = load_balancing_policy or default_lbp_factory() self.retry_policy = retry_policy or RetryPolicy() self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level self.request_timeout = request_timeout self.row_factory = row_factory + self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy() class ProfileManager(object): @@ -2058,6 +2067,7 @@ class Session(object): retry_policy = query.retry_policy or self.cluster.default_retry_policy row_factory = self.row_factory load_balancing_policy = self.cluster.load_balancing_policy + spec_exec_policy = None else: execution_profile = self._get_execution_profile(execution_profile) @@ -2070,6 +2080,8 @@ class Session(object): retry_policy = query.retry_policy or execution_profile.retry_policy row_factory = execution_profile.row_factory load_balancing_policy = execution_profile.load_balancing_policy + spec_exec_policy = execution_profile.speculative_execution_policy + fetch_size = query.fetch_size if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2: @@ -2077,8 +2089,9 @@ class Session(object): elif self._protocol_version == 1: fetch_size = None + start_time = time.time() if self._protocol_version >= 3 and self.use_client_timestamp: - timestamp = int(time.time() * 1e6) + timestamp = int(start_time * 1e6) else: timestamp = None @@ -2112,9 +2125,11 @@ class Session(object): message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version message.paging_state = paging_state + spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None return ResponseFuture( self, message, query, timeout, metrics=self._metrics, - prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, load_balancer=load_balancing_policy) + prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, + load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan) def _get_execution_profile(self, ep): profiles = self.cluster.profile_manager.profiles @@ -3172,11 +3187,11 @@ class _Scheduler(Thread): exc_info=exc) -def refresh_schema_and_set_result(control_conn, response_future, **kwargs): +def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs): try: log.debug("Refreshing schema in response to schema change. " "%s", kwargs) - response_future.is_schema_agreed = control_conn._refresh_schema(response_future._connection, **kwargs) + response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs) except Exception: log.exception("Exception refreshing schema in response to schema change:") response_future.session.submit(control_conn.refresh_schema, **kwargs) @@ -3214,6 +3229,14 @@ class ResponseFuture(object): Size of the request message sent """ + coordinator_host = None + """ + The host an actual result from ( + """ + + """ + """ + session = None row_factory = None message = None @@ -3230,7 +3253,6 @@ class ResponseFuture(object): _callbacks = None _errbacks = None _current_host = None - _current_pool = None _connection = None _query_retries = 0 _start_time = None @@ -3240,11 +3262,12 @@ class ResponseFuture(object): _warnings = None _timer = None _protocol_handler = ProtocolHandler + _spec_execution_plan = NoSpeculativeExecutionPlan() _warned_timeout = False def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None, - retry_policy=RetryPolicy(), row_factory=None, load_balancer=None): + retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, speculative_execution_plan=None): self.session = session # TODO: normalize handling of retry policy and row factory self.row_factory = row_factory or session.row_factory @@ -3252,21 +3275,29 @@ class ResponseFuture(object): self.message = message self.query = query self.timeout = timeout + self._time_remaining = timeout self._retry_policy = retry_policy self._metrics = metrics self.prepared_statement = prepared_statement self._callback_lock = Lock() - if metrics is not None: - self._start_time = time.time() + self._start_time = start_time or time.time() self._make_query_plan() self._event = Event() self._errors = {} self._callbacks = [] self._errbacks = [] + self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan + self._queried_hosts = [] def _start_timer(self): - if self.timeout is not None: - self._timer = self.session.cluster.connection_class.create_timer(self.timeout, self._on_timeout) + if self._timer is None: + spec_delay = self._spec_execution_plan.next_execution(self._current_host) + if spec_delay >= 0: + if self._time_remaining is None or self._time_remaining > spec_delay: + self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute) + return + if self._time_remaining is not None: + self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout) def _cancel_timer(self): if self._timer: @@ -3284,17 +3315,29 @@ class ResponseFuture(object): self._set_final_exception(OperationTimedOut(errors, self._current_host)) + def _on_speculative_execute(self): + self._timer = None + if not self._event.is_set(): + if self._time_remaining is not None: + elapsed = time.time() - self._start_time + self._time_remaining -= elapsed + if self._time_remaining <= 0: + self._on_timeout() + return + if not self.send_request(error_no_hosts=False): + self._start_timer() + + def _make_query_plan(self): # convert the list/generator/etc to an iterator so that subsequent # calls to send_request (which retries may do) will resume where # they last left off self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query)) - def send_request(self): + def send_request(self, error_no_hosts=True): """ Internal """ # query_plan is an iterator, so this will resume where we last left # off if send_request() is called multiple times - start = time.time() for host in self.query_plan: req_id = self._query(host) if req_id is not None: @@ -3303,23 +3346,21 @@ class ResponseFuture(object): # timer is only started here, after we have at least one message queued # this is done to avoid overrun of timers with unfettered client requests # in the case of full disconnect, where no hosts will be available - if self._timer is None: - self._start_timer() - return - if self.timeout is not None and time.time() - start > self.timeout: + self._start_timer() + return True + if self.timeout is not None and time.time() - self._start_time > self.timeout: self._on_timeout() - return + return True - self._set_final_exception(NoHostAvailable( - "Unable to complete the operation against any hosts", self._errors)) + if error_no_hosts: + self._set_final_exception(NoHostAvailable( + "Unable to complete the operation against any hosts", self._errors)) + return False def _query(self, host, message=None, cb=None): if message is None: message = self.message - if cb is None: - cb = self._set_result - pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") @@ -3329,7 +3370,6 @@ class ResponseFuture(object): return None self._current_host = host - self._current_pool = pool connection = None try: @@ -3337,6 +3377,10 @@ class ResponseFuture(object): connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + + if cb is None: + cb = partial(self._set_result, host, connection, pool) + self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message, @@ -3423,17 +3467,18 @@ class ResponseFuture(object): self._timer = None # clear cancelled timer; new one will be set when request is queued self.send_request() - def _reprepare(self, prepare_message): - cb = partial(self.session.submit, self._execute_after_prepare) - request_id = self._query(self._current_host, prepare_message, cb=cb) + def _reprepare(self, prepare_message, host, connection, pool): + cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) + request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host self.send_request() - def _set_result(self, response): + def _set_result(self, host, connection, pool, response): try: - if self._current_pool and self._connection: - self._current_pool.return_connection(self._connection) + self.coordinator_host = host + if pool: + pool.return_connection(connection) trace_id = getattr(response, 'trace_id', None) if trace_id: @@ -3464,7 +3509,7 @@ class ResponseFuture(object): self.session.submit( refresh_schema_and_set_result, self.session.cluster.control_connection, - self, **response.results) + self, connection, **response.results) else: results = getattr(response, 'results', None) if results is not None and response.kind == RESULT_KIND_ROWS: @@ -3495,14 +3540,14 @@ class ResponseFuture(object): self._metrics.on_other_error() # need to retry against a different host here log.warning("Host %s is overloaded, retrying against a different " - "host", self._current_host) - self._retry(reuse_connection=False, consistency_level=None) + "host", host) + self._retry(reuse_connection=False, consistency_level=None, host=host) return elif isinstance(response, IsBootstrappingErrorMessage): if self._metrics is not None: self._metrics.on_other_error() # need to retry against a different host here - self._retry(reuse_connection=False, consistency_level=None) + self._retry(reuse_connection=False, consistency_level=None, host=host) return elif isinstance(response, PreparedQueryNotFound): if self.prepared_statement: @@ -3536,11 +3581,11 @@ class ResponseFuture(object): return log.debug("Re-preparing unrecognized prepared statement against host %s: %s", - self._current_host, prepared_statement.query_string) + host, prepared_statement.query_string) prepare_message = PrepareMessage(query=prepared_statement.query_string) # since this might block, run on the executor to avoid hanging # the event loop thread - self.session.submit(self._reprepare, prepare_message) + self.session.submit(self._reprepare, prepare_message, host, connection, pool) return else: if hasattr(response, 'to_exception'): @@ -3553,20 +3598,20 @@ class ResponseFuture(object): if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST): self._query_retries += 1 reuse = retry_type == RetryPolicy.RETRY - self._retry(reuse_connection=reuse, consistency_level=consistency) + self._retry(reuse, consistency, host) elif retry_type is RetryPolicy.RETHROW: self._set_final_exception(response.to_exception()) else: # IGNORE if self._metrics is not None: self._metrics.on_ignore() self._set_final_result(None) - self._errors[self._current_host] = response.to_exception() + self._errors[host] = response.to_exception() elif isinstance(response, ConnectionException): if self._metrics is not None: self._metrics.on_connection_error() if not isinstance(response, ConnectionShutdown): self._connection.defunct(response) - self._retry(reuse_connection=False, consistency_level=None) + self._retry(reuse_connection=False, consistency_level=None, host=host) elif isinstance(response, Exception): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) @@ -3575,7 +3620,7 @@ class ResponseFuture(object): else: # we got some other kind of response message msg = "Got unexpected message: %r" % (response,) - exc = ConnectionException(msg, self._current_host) + exc = ConnectionException(msg, host) self._connection.defunct(exc) self._set_final_exception(exc) except Exception as exc: @@ -3590,13 +3635,14 @@ class ResponseFuture(object): self._set_final_exception(ConnectionException( "Failed to set keyspace on all hosts: %s" % (errors,))) - def _execute_after_prepare(self, response): + def _execute_after_prepare(self, host, connection, pool, response): """ Handle the response to our attempt to prepare a statement. If it succeeded, run the original query again against the same host. """ - if self._current_pool and self._connection: - self._current_pool.return_connection(self._connection) + "AFTER PREPARE" + if pool: + pool.return_connection(connection) if self._final_exception: return @@ -3609,14 +3655,14 @@ class ResponseFuture(object): # use self._query to re-use the same host and # at the same time properly borrow the connection - request_id = self._query(self._current_host) + request_id = self._query(host) if request_id is None: # this host errored out, move on to the next self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response when preparing statement " - "on host %s: %s" % (self._current_host, response))) + "on host %s: %s" % (host, response))) elif isinstance(response, ErrorMessage): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) @@ -3624,14 +3670,14 @@ class ResponseFuture(object): self._set_final_exception(response) elif isinstance(response, ConnectionException): log.debug("Connection error when preparing statement on host %s: %s", - self._current_host, response) + host, response) # try again on a different host, preparing again if necessary - self._errors[self._current_host] = response + self._errors[host] = response self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response type when preparing " - "statement on host %s: %s" % (self._current_host, response))) + "statement on host %s: %s" % (host, response))) def _set_final_result(self, response): self._cancel_timer() @@ -3661,7 +3707,7 @@ class ResponseFuture(object): fn, args, kwargs = errback fn(response, *args, **kwargs) - def _retry(self, reuse_connection, consistency_level): + def _retry(self, reuse_connection, consistency_level, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation @@ -3673,15 +3719,15 @@ class ResponseFuture(object): self.message.consistency_level = consistency_level # don't retry on the event loop thread - self.session.submit(self._retry_task, reuse_connection) + self.session.submit(self._retry_task, reuse_connection, host) - def _retry_task(self, reuse_connection): + def _retry_task(self, reuse_connection, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation return - if reuse_connection and self._query(self._current_host) is not None: + if reuse_connection and self._query(host) is not None: return # otherwise, move onto another host @@ -3852,8 +3898,8 @@ class ResponseFuture(object): def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result - return "" \ - % (self.query, self._req_id, result, self._final_exception, self._current_host) + return "" \ + % (self.query, self._req_id, result, self._final_exception, self.coordinator_host) __repr__ = __str__ diff --git a/cassandra/policies.py b/cassandra/policies.py index 80b807c3..3736afcb 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -875,7 +875,7 @@ class AddressTranslator(object): """ Accepts the node ip address, and returns a translated address to be used connecting to this node. """ - raise NotImplementedError + raise NotImplementedError() class IdentityTranslator(AddressTranslator): @@ -904,3 +904,57 @@ class EC2MultiRegionTranslator(AddressTranslator): except Exception: pass return addr + + +class SpeculativeExecutionPolicy(object): + """ + Interface for specifying speculative execution plans + """ + + def new_plan(self, keyspace, statement): + """ + Returns + + :param keyspace: + :param statement: + :return: + """ + raise NotImplementedError() + + +class SpeculativeExecutionPlan(object): + def next_execution(self, host): + raise NotImplementedError() + + +class NoSpeculativeExecutionPlan(SpeculativeExecutionPlan): + def next_execution(self, host): + return -1 + + +class NoSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): + + def new_plan(self, keyspace, statement): + return self.NoSpeculativeExecutionPlan() + + +class ConstantSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): + + def __init__(self, delay, max_attempts): + self.delay = delay + self.max_attempts = max_attempts + + class ConstantSpeculativeExecutionPlan(SpeculativeExecutionPlan): + def __init__(self, delay, max_attempts): + self.delay = delay + self.remaining = max_attempts + + def next_execution(self, host): + if self.remaining > 0: + self.remaining -= 1 + return self.delay + else: + return -1 + + def new_plan(self, keyspace, statement): + return self.ConstantSpeculativeExecutionPlan(self.delay, self.max_attempts) diff --git a/cassandra/query.py b/cassandra/query.py index 65cb6ba9..40d522d3 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -215,11 +215,17 @@ class Statement(object): .. versionadded:: 2.6.0 """ + is_idempotent = False + """ + Flag indicating whether this statement is safe to run multiple times in speculative execution. + """ + _serial_consistency_level = None _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None): + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, + is_idempotent=False): if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') self.retry_policy = retry_policy @@ -234,6 +240,7 @@ class Statement(object): self.keyspace = keyspace if custom_payload is not None: self.custom_payload = custom_payload + self.is_idempotent = is_idempotent def _key_parts_packed(self, parts): for p in parts: @@ -328,7 +335,7 @@ class SimpleStatement(Statement): def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): + custom_payload=None, is_idempotent=False): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the @@ -337,7 +344,7 @@ class SimpleStatement(Statement): See :class:`Statement` attributes for a description of the other parameters. """ Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload) + serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) self._query_string = query_string @property @@ -383,6 +390,7 @@ class PreparedStatement(object): self.keyspace = keyspace self.protocol_version = protocol_version self.result_metadata = result_metadata + self.is_idempotent = False @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, @@ -465,6 +473,7 @@ class BoundStatement(Statement): self.serial_consistency_level = prepared_statement.serial_consistency_level self.fetch_size = prepared_statement.fetch_size self.custom_payload = prepared_statement.custom_payload + self.is_idempotent = prepared_statement.is_idempotent self.values = [] meta = prepared_statement.column_metadata diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 88b08af8..f9959091 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -69,7 +69,7 @@ class ResponseFutureTests(unittest.TestCase): connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) - rf._set_result(self.make_mock_response([{'col': 'val'}])) + rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}])) result = rf.result() self.assertEqual(result, [{'col': 'val'}]) @@ -81,7 +81,7 @@ class ResponseFutureTests(unittest.TestCase): rf = self.make_response_future(session) rf.send_request() - rf._set_result(object()) + rf._set_result(None, None, None, object()) self.assertRaises(ConnectionException, rf.result) def test_set_keyspace_result(self): @@ -92,7 +92,7 @@ class ResponseFutureTests(unittest.TestCase): result = Mock(spec=ResultMessage, kind=RESULT_KIND_SET_KEYSPACE, results="keyspace1") - rf._set_result(result) + rf._set_result(None, None, None, result) rf._set_keyspace_completed({}) self.assertFalse(rf.result()) @@ -106,15 +106,16 @@ class ResponseFutureTests(unittest.TestCase): result = Mock(spec=ResultMessage, kind=RESULT_KIND_SCHEMA_CHANGE, results=event_results) - rf._set_result(result) - session.submit.assert_called_once_with(ANY, ANY, rf, **event_results) + connection = Mock() + rf._set_result(None, connection, None, result) + session.submit.assert_called_once_with(ANY, ANY, rf, connection, **event_results) def test_other_result_message_kind(self): session = self.make_session() rf = self.make_response_future(session) rf.send_request() result = [1, 2, 3] - rf._set_result(Mock(spec=ResultMessage, kind=999, results=result)) + rf._set_result(None, None, None, Mock(spec=ResultMessage, kind=999, results=result)) self.assertListEqual(list(rf.result()), result) def test_read_timeout_error_message(self): @@ -128,7 +129,7 @@ class ResponseFutureTests(unittest.TestCase): rf.send_request() result = Mock(spec=ReadTimeoutErrorMessage, info={}) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) @@ -143,7 +144,7 @@ class ResponseFutureTests(unittest.TestCase): rf.send_request() result = Mock(spec=WriteTimeoutErrorMessage, info={}) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) def test_unavailable_error_message(self): @@ -157,7 +158,7 @@ class ResponseFutureTests(unittest.TestCase): rf.send_request() result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) def test_retry_policy_says_ignore(self): @@ -171,7 +172,7 @@ class ResponseFutureTests(unittest.TestCase): rf.send_request() result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertFalse(rf.result()) def test_retry_policy_says_retry(self): @@ -195,20 +196,21 @@ class ResponseFutureTests(unittest.TestCase): connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + host = Mock() + rf._set_result(host, None, None, result) - session.submit.assert_called_once_with(rf._retry_task, True) + session.submit.assert_called_once_with(rf._retry_task, True, host) self.assertEqual(1, rf._query_retries) connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 2) # simulate the executor running this - rf._retry_task(True) + rf._retry_task(True, host) # it should try again with the same host since this was # an UnavailableException - rf.session._pools.get.assert_called_with('ip1') + rf.session._pools.get.assert_called_with(host) pool.borrow_connection.assert_called_with(timeout=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -229,16 +231,17 @@ class ResponseFutureTests(unittest.TestCase): self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) result = Mock(spec=OverloadedErrorMessage, info={}) - rf._set_result(result) + host = Mock() + rf._set_result(host, None, None, result) - session.submit.assert_called_once_with(rf._retry_task, False) + session.submit.assert_called_once_with(rf._retry_task, False, host) # query_retries does not get incremented for Overloaded/Bootstrapping errors self.assertEqual(0, rf._query_retries) connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 2) # simulate the executor running this - rf._retry_task(False) + rf._retry_task(False, host) # it should try with a different host rf.session._pools.get.assert_called_with('ip2') @@ -259,21 +262,22 @@ class ResponseFutureTests(unittest.TestCase): rf.session._pools.get.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) - rf._set_result(result) + host = Mock() + rf._set_result(host, None, None, result) # simulate the executor running this - session.submit.assert_called_once_with(rf._retry_task, False) - rf._retry_task(False) + session.submit.assert_called_once_with(rf._retry_task, False, host) + rf._retry_task(False, host) # it should try with a different host rf.session._pools.get.assert_called_with('ip2') result = Mock(spec=IsBootstrappingErrorMessage, info={}) - rf._set_result(result) + rf._set_result(host, None, None, result) # simulate the executor running this - session.submit.assert_called_with(rf._retry_task, False) - rf._retry_task(False) + session.submit.assert_called_with(rf._retry_task, False, host) + rf._retry_task(False, host) self.assertRaises(NoHostAvailable, rf.result) @@ -295,7 +299,7 @@ class ResponseFutureTests(unittest.TestCase): rf = self.make_response_future(session) rf.send_request() - rf._set_result(self.make_mock_response([{'col': 'val'}])) + rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}])) result = rf.result() self.assertEqual(result, [{'col': 'val'}]) @@ -319,7 +323,7 @@ class ResponseFutureTests(unittest.TestCase): rf = self.make_response_future(session) rf.send_request() - rf._set_result(self.make_mock_response([{'col': 'val'}])) + rf._set_result(None, None, None, self.make_mock_response([{'col': 'val'}])) self.assertEqual(rf.result(), [{'col': 'val'}]) # make sure the exception is recorded correctly @@ -336,7 +340,7 @@ class ResponseFutureTests(unittest.TestCase): kwargs = {'one': 1, 'two': 2} rf.add_callback(callback, arg, **kwargs) - rf._set_result(self.make_mock_response(expected_result)) + rf._set_result(None, None, None, self.make_mock_response(expected_result)) result = rf.result() self.assertEqual(result, expected_result) @@ -363,7 +367,7 @@ class ResponseFutureTests(unittest.TestCase): rf.add_errback(self.assertIsInstance, Exception) result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) # this should get called immediately now that the error is set @@ -385,7 +389,7 @@ class ResponseFutureTests(unittest.TestCase): kwargs2 = {'three': 3, 'four': 4} rf.add_callback(callback2, arg2, **kwargs2) - rf._set_result(self.make_mock_response(expected_result)) + rf._set_result(None, None, None, self.make_mock_response(expected_result)) result = rf.result() self.assertEqual(result, expected_result) @@ -420,7 +424,7 @@ class ResponseFutureTests(unittest.TestCase): expected_exception = Unavailable("message", 1, 2, 3) result = Mock(spec=UnavailableErrorMessage, info={'something': 'here'}) result.to_exception.return_value = expected_exception - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) callback.assert_called_once_with(expected_exception, arg, **kwargs) @@ -442,7 +446,7 @@ class ResponseFutureTests(unittest.TestCase): errback=self.assertIsInstance, errback_args=(Exception,)) result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) # test callback @@ -457,7 +461,7 @@ class ResponseFutureTests(unittest.TestCase): callback=callback, callback_args=(arg,), callback_kwargs=kwargs, errback=self.assertIsInstance, errback_args=(Exception,)) - rf._set_result(self.make_mock_response(expected_result)) + rf._set_result(None, None, None, self.make_mock_response(expected_result)) self.assertEqual(rf.result(), expected_result) callback.assert_called_once_with(expected_result, arg, **kwargs) @@ -478,7 +482,7 @@ class ResponseFutureTests(unittest.TestCase): rf._connection.keyspace = "FooKeyspace" result = Mock(spec=PreparedQueryNotFound, info='a' * 16) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertTrue(session.submit.call_args) args, kwargs = session.submit.call_args @@ -502,5 +506,5 @@ class ResponseFutureTests(unittest.TestCase): rf._connection.keyspace = "BarKeyspace" result = Mock(spec=PreparedQueryNotFound, info='a' * 16) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(ValueError, rf.result) diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py index 2deeb30f..c0fbad4a 100644 --- a/tests/unit/test_resultset.py +++ b/tests/unit/test_resultset.py @@ -19,7 +19,6 @@ except ImportError: import unittest # noqa from mock import Mock, PropertyMock -import warnings from cassandra.cluster import ResultSet