diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d7c9b76e..a2e72905 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1675,7 +1675,13 @@ class ResponseFuture(object): self._set_final_exception(NoHostAvailable( "Unable to complete the operation against any hosts", self._errors)) - def _query(self, host): + def _query(self, host, message=None, cb=None): + if message is None: + message = self.message + + if cb is None: + cb = self._set_result + pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") @@ -1684,21 +1690,11 @@ class ResponseFuture(object): self._errors[host] = ConnectionException("Pool is shutdown") return None - return self._borrow_conn_and_send_message(host, pool, self.message, self._set_result) - - def _borrow_conn_and_send_message(self, host, pool, message, cb): - if cb is None: - cb = self._set_result - connection = None try: # TODO get connectTimeout from cluster settings connection = pool.borrow_connection(timeout=2.0) request_id = connection.send_msg(message, cb=cb) - self._current_host = host - self._current_pool = pool - self._connection = connection - return request_id except Exception as exc: log.debug("Error querying host %s", host, exc_info=True) self._errors[host] = exc @@ -1706,6 +1702,17 @@ class ResponseFuture(object): pool.return_connection(connection) return None + self._current_host = host + self._current_pool = pool + self._connection = connection + return request_id + + def _reprepare(self, prepare_message): + request_id = self._query(self._current_host, prepare_message, cb=self._execute_after_prepare) + if request_id is None: + # try to submit the original prepared statement on some other host + self.send_request() + def _set_result(self, response): try: if self._current_pool and self._connection: @@ -1769,7 +1776,7 @@ class ResponseFuture(object): self._metrics.on_other_error() # need to retry against a different host here log.warn("Host %s is overloaded, retrying against a different " - "host" % (self._current_host)) + "host", self._current_host) self._retry(reuse_connection=False, consistency_level=None) return elif isinstance(response, IsBootstrappingErrorMessage): @@ -1800,11 +1807,7 @@ class ResponseFuture(object): prepare_message = PrepareMessage(query=prepared_statement.query_string) # since this might block, run on the executor to avoid hanging # the event loop thread - self.session.submit(self._borrow_conn_and_send_message, - self._current_host, - self._current_pool, - prepare_message, - self._execute_after_prepare) + self.session.submit(self._reprepare, prepare_message) return else: if hasattr(response, 'to_exception'): diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 928dcc30..ea05a86f 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -366,8 +366,9 @@ class ResponseFutureTests(unittest.TestCase): session.submit.assert_called_once() args, kwargs = session.submit.call_args - self.assertIsInstance(args[-2], PrepareMessage) - self.assertEquals(args[-2].query, "SELECT * FROM foobar") + self.assertEquals(rf._reprepare, args[-2]) + self.assertIsInstance(args[-1], PrepareMessage) + self.assertEquals(args[-1].query, "SELECT * FROM foobar") def test_prepared_query_not_found_bad_keyspace(self): session = self.make_session()