diff --git a/cassandra/connection.py b/cassandra/connection.py index 9796bf44..ead99430 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict +from collections import defaultdict, deque import errno from functools import wraps, partial import logging @@ -150,8 +150,20 @@ class Connection(object): ssl_options = None last_error = None + + # The current number of operations that are in flight. More precisely, + # the number of request IDs that are currently in use. in_flight = 0 - current_request_id = 0 + + # A set of available request IDs. When using the v3 protocol or higher, + # this will no initially include all request IDs in order to save memory, + # but the set will grow if it is exhausted. + request_ids = None + + # Tracks the highest used request ID in order to help with growing the + # request_ids set + highest_request_id = 0 + is_defunct = False is_closed = False lock = None @@ -178,10 +190,16 @@ class Connection(object): self._header_unpack = v3_header_unpack self._header_length = 5 self.max_request_id = (2 ** 15) - 1 + # Don't fill the deque with 2**15 items right away. Start with 300 and add + # more if needed. + self.request_ids = deque(range(300)) + self.highest_request_id = 299 else: self._header_unpack = header_unpack self._header_length = 4 self.max_request_id = (2 ** 7) - 1 + self.request_ids = deque(range(self.max_request_id + 1)) + self.highest_request_id = self.max_request_id + 1 # 0 8 16 24 32 40 # +---------+---------+---------+---------+---------+ @@ -246,9 +264,16 @@ class Connection(object): id(self), self.host, exc_info=True) def get_request_id(self): - current = self.current_request_id - self.current_request_id = (current + 1) % self.max_request_id - return current + """ + This must be called while self.lock is held. + """ + try: + return self.request_ids.popleft() + except IndexError: + self.highest_request_id += 1 + # in_flight checks should guarantee this + assert self.highest_request_id <= self.max_request_id + return self.highest_request_id def handle_pushed(self, response): log.debug("Message pushed from server: %r", response) @@ -283,13 +308,12 @@ class Connection(object): needed = len(msgs) - messages_sent with self.lock: available = min(needed, self.max_request_id - self.in_flight) - start_request_id = self.current_request_id - self.current_request_id = (self.current_request_id + available) % self.max_request_id + request_ids = [self.get_request_id() for _ in range(available)] self.in_flight += available - for i in range(available): + for i, request_id in enumerate(request_ids): self.send_msg(msgs[messages_sent + i], - (start_request_id + i) % self.max_request_id, + request_id, partial(waiter.got_response, index=messages_sent + i)) messages_sent += available @@ -327,6 +351,8 @@ class Connection(object): callback = None else: callback = self._callbacks.pop(stream_id, None) + with self.lock: + self.request_ids.append(stream_id) body = None try: @@ -383,7 +409,7 @@ class Connection(object): self._send_startup_message() else: log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host) - self.send_msg(OptionsMessage(), 0, self._handle_options_response) + self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) @defunct_on_error def _handle_options_response(self, options_response): @@ -455,7 +481,7 @@ class Connection(object): if compression: opts['COMPRESSION'] = compression sm = StartupMessage(cqlversion=self.cql_version, options=opts) - self.send_msg(sm, 0, cb=self._handle_startup_response) + self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response) log.debug("Sent StartupMessage on %s", self) @defunct_on_error @@ -480,7 +506,7 @@ class Connection(object): log.debug("Sending credentials-based auth response on %s", self) cm = CredentialsMessage(creds=self.authenticator) callback = partial(self._handle_startup_response, did_authenticate=True) - self.send_msg(cm, 0, cb=callback) + self.send_msg(cm, self.get_request_id(), cb=callback) else: log.debug("Sending SASL-based auth response on %s", self) initial_response = self.authenticator.initial_response() @@ -520,7 +546,7 @@ class Connection(object): response = self.authenticator.evaluate_challenge(auth_response.challenge) msg = AuthResponseMessage("" if response is None else response) log.debug("Responding to auth challenge on %s", self) - self.send_msg(msg, 0, self._handle_auth_response) + self.send_msg(msg, self.get_request_id(), self._handle_auth_response) elif isinstance(auth_response, ErrorMessage): log.debug("Received ErrorMessage on new connection (%s) from %s: %s", id(self), self.host, auth_response.summary_msg()) diff --git a/cassandra/pool.py b/cassandra/pool.py index 0f29c77f..dd462998 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -328,10 +328,11 @@ class HostConnection(object): raise NoConnectionsAvailable() with conn.lock: - if conn.in_flight > conn.max_request_id: - raise NoConnectionsAvailable("All request IDs are currently in use") - conn.in_flight += 1 - return conn, conn.get_request_id() + if conn.in_flight < conn.max_request_id: + conn.in_flight += 1 + return conn, conn.get_request_id() + + raise NoConnectionsAvailable("All request IDs are currently in use") def return_connection(self, connection): with connection.lock: @@ -341,7 +342,7 @@ class HostConnection(object): log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) is_down = self._session.cluster.signal_connection_failure( - self.host, connection.last_error, is_host_addition=False) + self.host, connection.last_error, is_host_addition=False) if is_down: self.shutdown() else: @@ -463,13 +464,12 @@ class HostConnectionPool(object): # its in_flight count need_to_wait = False with least_busy.lock: - - if least_busy.in_flight >= least_busy.max_request_id: - # once we release the lock, wait for another connection - need_to_wait = True - else: + if least_busy.in_flight < least_busy.max_request_id: least_busy.in_flight += 1 request_id = least_busy.get_request_id() + else: + # once we release the lock, wait for another connection + need_to_wait = True if need_to_wait: # wait_for_conn will increment in_flight on the conn @@ -587,7 +587,7 @@ class HostConnectionPool(object): log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) is_down = self._session.cluster.signal_connection_failure( - self.host, connection.last_error, is_host_addition=False) + self.host, connection.last_error, is_host_addition=False) if is_down: self.shutdown() else: diff --git a/tests/unit/io/test_asyncorereactor.py b/tests/unit/io/test_asyncorereactor.py index ecb59bd0..5ab59670 100644 --- a/tests/unit/io/test_asyncorereactor.py +++ b/tests/unit/io/test_asyncorereactor.py @@ -101,7 +101,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ReadyMessage, stream_id=0) + header = self.make_header_prefix(ReadyMessage, stream_id=1) c.socket.recv.return_value = self.make_msg(header) c.handle_read() @@ -173,7 +173,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ServerError, stream_id=0) + header = self.make_header_prefix(ServerError, stream_id=1) body = self.make_error_body(ServerError.error_code, ServerError.summary) c.socket.recv.return_value = self.make_msg(header, body) c.handle_read() @@ -255,7 +255,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ReadyMessage, stream_id=0) + header = self.make_header_prefix(ReadyMessage, stream_id=1) c.socket.recv.return_value = self.make_msg(header) c.handle_read() @@ -282,7 +282,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ReadyMessage, stream_id=0) + header = self.make_header_prefix(ReadyMessage, stream_id=1) c.socket.recv.return_value = self.make_msg(header) c.handle_read() diff --git a/tests/unit/io/test_libevreactor.py b/tests/unit/io/test_libevreactor.py index bef4ecbf..3b677ab5 100644 --- a/tests/unit/io/test_libevreactor.py +++ b/tests/unit/io/test_libevreactor.py @@ -98,7 +98,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ReadyMessage, stream_id=0) + header = self.make_header_prefix(ReadyMessage, stream_id=1) c._socket.recv.return_value = self.make_msg(header) c.handle_read(None, 0) @@ -170,7 +170,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ServerError, stream_id=0) + header = self.make_header_prefix(ServerError, stream_id=1) body = self.make_error_body(ServerError.error_code, ServerError.summary) c._socket.recv.return_value = self.make_msg(header, body) c.handle_read(None, 0) @@ -253,7 +253,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ReadyMessage, stream_id=0) + header = self.make_header_prefix(ReadyMessage, stream_id=1) c._socket.recv.return_value = self.make_msg(header) c.handle_read(None, 0) @@ -280,7 +280,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ReadyMessage, stream_id=0) + header = self.make_header_prefix(ReadyMessage, stream_id=1) c._socket.recv.return_value = self.make_msg(header) c.handle_read(None, 0)