Fix bad request ID management from 2.1 support changes

This commit is contained in:
Tyler Hobbs
2014-07-02 16:02:03 -05:00
parent 39efc92fa2
commit 8f847245fa
4 changed files with 58 additions and 32 deletions

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict, deque
import errno import errno
from functools import wraps, partial from functools import wraps, partial
import logging import logging
@@ -150,8 +150,20 @@ class Connection(object):
ssl_options = None ssl_options = None
last_error = None last_error = None
# The current number of operations that are in flight. More precisely,
# the number of request IDs that are currently in use.
in_flight = 0 in_flight = 0
current_request_id = 0
# A set of available request IDs. When using the v3 protocol or higher,
# this will no initially include all request IDs in order to save memory,
# but the set will grow if it is exhausted.
request_ids = None
# Tracks the highest used request ID in order to help with growing the
# request_ids set
highest_request_id = 0
is_defunct = False is_defunct = False
is_closed = False is_closed = False
lock = None lock = None
@@ -178,10 +190,16 @@ class Connection(object):
self._header_unpack = v3_header_unpack self._header_unpack = v3_header_unpack
self._header_length = 5 self._header_length = 5
self.max_request_id = (2 ** 15) - 1 self.max_request_id = (2 ** 15) - 1
# Don't fill the deque with 2**15 items right away. Start with 300 and add
# more if needed.
self.request_ids = deque(range(300))
self.highest_request_id = 299
else: else:
self._header_unpack = header_unpack self._header_unpack = header_unpack
self._header_length = 4 self._header_length = 4
self.max_request_id = (2 ** 7) - 1 self.max_request_id = (2 ** 7) - 1
self.request_ids = deque(range(self.max_request_id + 1))
self.highest_request_id = self.max_request_id + 1
# 0 8 16 24 32 40 # 0 8 16 24 32 40
# +---------+---------+---------+---------+---------+ # +---------+---------+---------+---------+---------+
@@ -246,9 +264,16 @@ class Connection(object):
id(self), self.host, exc_info=True) id(self), self.host, exc_info=True)
def get_request_id(self): def get_request_id(self):
current = self.current_request_id """
self.current_request_id = (current + 1) % self.max_request_id This must be called while self.lock is held.
return current """
try:
return self.request_ids.popleft()
except IndexError:
self.highest_request_id += 1
# in_flight checks should guarantee this
assert self.highest_request_id <= self.max_request_id
return self.highest_request_id
def handle_pushed(self, response): def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response) log.debug("Message pushed from server: %r", response)
@@ -283,13 +308,12 @@ class Connection(object):
needed = len(msgs) - messages_sent needed = len(msgs) - messages_sent
with self.lock: with self.lock:
available = min(needed, self.max_request_id - self.in_flight) available = min(needed, self.max_request_id - self.in_flight)
start_request_id = self.current_request_id request_ids = [self.get_request_id() for _ in range(available)]
self.current_request_id = (self.current_request_id + available) % self.max_request_id
self.in_flight += available self.in_flight += available
for i in range(available): for i, request_id in enumerate(request_ids):
self.send_msg(msgs[messages_sent + i], self.send_msg(msgs[messages_sent + i],
(start_request_id + i) % self.max_request_id, request_id,
partial(waiter.got_response, index=messages_sent + i)) partial(waiter.got_response, index=messages_sent + i))
messages_sent += available messages_sent += available
@@ -327,6 +351,8 @@ class Connection(object):
callback = None callback = None
else: else:
callback = self._callbacks.pop(stream_id, None) callback = self._callbacks.pop(stream_id, None)
with self.lock:
self.request_ids.append(stream_id)
body = None body = None
try: try:
@@ -383,7 +409,7 @@ class Connection(object):
self._send_startup_message() self._send_startup_message()
else: else:
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host) log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
self.send_msg(OptionsMessage(), 0, self._handle_options_response) self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response)
@defunct_on_error @defunct_on_error
def _handle_options_response(self, options_response): def _handle_options_response(self, options_response):
@@ -455,7 +481,7 @@ class Connection(object):
if compression: if compression:
opts['COMPRESSION'] = compression opts['COMPRESSION'] = compression
sm = StartupMessage(cqlversion=self.cql_version, options=opts) sm = StartupMessage(cqlversion=self.cql_version, options=opts)
self.send_msg(sm, 0, cb=self._handle_startup_response) self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response)
log.debug("Sent StartupMessage on %s", self) log.debug("Sent StartupMessage on %s", self)
@defunct_on_error @defunct_on_error
@@ -480,7 +506,7 @@ class Connection(object):
log.debug("Sending credentials-based auth response on %s", self) log.debug("Sending credentials-based auth response on %s", self)
cm = CredentialsMessage(creds=self.authenticator) cm = CredentialsMessage(creds=self.authenticator)
callback = partial(self._handle_startup_response, did_authenticate=True) callback = partial(self._handle_startup_response, did_authenticate=True)
self.send_msg(cm, 0, cb=callback) self.send_msg(cm, self.get_request_id(), cb=callback)
else: else:
log.debug("Sending SASL-based auth response on %s", self) log.debug("Sending SASL-based auth response on %s", self)
initial_response = self.authenticator.initial_response() initial_response = self.authenticator.initial_response()
@@ -520,7 +546,7 @@ class Connection(object):
response = self.authenticator.evaluate_challenge(auth_response.challenge) response = self.authenticator.evaluate_challenge(auth_response.challenge)
msg = AuthResponseMessage("" if response is None else response) msg = AuthResponseMessage("" if response is None else response)
log.debug("Responding to auth challenge on %s", self) log.debug("Responding to auth challenge on %s", self)
self.send_msg(msg, 0, self._handle_auth_response) self.send_msg(msg, self.get_request_id(), self._handle_auth_response)
elif isinstance(auth_response, ErrorMessage): elif isinstance(auth_response, ErrorMessage):
log.debug("Received ErrorMessage on new connection (%s) from %s: %s", log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
id(self), self.host, auth_response.summary_msg()) id(self), self.host, auth_response.summary_msg())

View File

@@ -328,10 +328,11 @@ class HostConnection(object):
raise NoConnectionsAvailable() raise NoConnectionsAvailable()
with conn.lock: with conn.lock:
if conn.in_flight > conn.max_request_id: if conn.in_flight < conn.max_request_id:
raise NoConnectionsAvailable("All request IDs are currently in use") conn.in_flight += 1
conn.in_flight += 1 return conn, conn.get_request_id()
return conn, conn.get_request_id()
raise NoConnectionsAvailable("All request IDs are currently in use")
def return_connection(self, connection): def return_connection(self, connection):
with connection.lock: with connection.lock:
@@ -341,7 +342,7 @@ class HostConnection(object):
log.debug("Defunct or closed connection (%s) returned to pool, potentially " log.debug("Defunct or closed connection (%s) returned to pool, potentially "
"marking host %s as down", id(connection), self.host) "marking host %s as down", id(connection), self.host)
is_down = self._session.cluster.signal_connection_failure( is_down = self._session.cluster.signal_connection_failure(
self.host, connection.last_error, is_host_addition=False) self.host, connection.last_error, is_host_addition=False)
if is_down: if is_down:
self.shutdown() self.shutdown()
else: else:
@@ -463,13 +464,12 @@ class HostConnectionPool(object):
# its in_flight count # its in_flight count
need_to_wait = False need_to_wait = False
with least_busy.lock: with least_busy.lock:
if least_busy.in_flight < least_busy.max_request_id:
if least_busy.in_flight >= least_busy.max_request_id:
# once we release the lock, wait for another connection
need_to_wait = True
else:
least_busy.in_flight += 1 least_busy.in_flight += 1
request_id = least_busy.get_request_id() request_id = least_busy.get_request_id()
else:
# once we release the lock, wait for another connection
need_to_wait = True
if need_to_wait: if need_to_wait:
# wait_for_conn will increment in_flight on the conn # wait_for_conn will increment in_flight on the conn
@@ -587,7 +587,7 @@ class HostConnectionPool(object):
log.debug("Defunct or closed connection (%s) returned to pool, potentially " log.debug("Defunct or closed connection (%s) returned to pool, potentially "
"marking host %s as down", id(connection), self.host) "marking host %s as down", id(connection), self.host)
is_down = self._session.cluster.signal_connection_failure( is_down = self._session.cluster.signal_connection_failure(
self.host, connection.last_error, is_host_addition=False) self.host, connection.last_error, is_host_addition=False)
if is_down: if is_down:
self.shutdown() self.shutdown()
else: else:

View File

@@ -101,7 +101,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=0) header = self.make_header_prefix(ReadyMessage, stream_id=1)
c.socket.recv.return_value = self.make_msg(header) c.socket.recv.return_value = self.make_msg(header)
c.handle_read() c.handle_read()
@@ -173,7 +173,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()
header = self.make_header_prefix(ServerError, stream_id=0) header = self.make_header_prefix(ServerError, stream_id=1)
body = self.make_error_body(ServerError.error_code, ServerError.summary) body = self.make_error_body(ServerError.error_code, ServerError.summary)
c.socket.recv.return_value = self.make_msg(header, body) c.socket.recv.return_value = self.make_msg(header, body)
c.handle_read() c.handle_read()
@@ -255,7 +255,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=0) header = self.make_header_prefix(ReadyMessage, stream_id=1)
c.socket.recv.return_value = self.make_msg(header) c.socket.recv.return_value = self.make_msg(header)
c.handle_read() c.handle_read()
@@ -282,7 +282,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write() c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=0) header = self.make_header_prefix(ReadyMessage, stream_id=1)
c.socket.recv.return_value = self.make_msg(header) c.socket.recv.return_value = self.make_msg(header)
c.handle_read() c.handle_read()

View File

@@ -98,7 +98,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write(None, 0) c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=0) header = self.make_header_prefix(ReadyMessage, stream_id=1)
c._socket.recv.return_value = self.make_msg(header) c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0) c.handle_read(None, 0)
@@ -170,7 +170,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write(None, 0) c.handle_write(None, 0)
header = self.make_header_prefix(ServerError, stream_id=0) header = self.make_header_prefix(ServerError, stream_id=1)
body = self.make_error_body(ServerError.error_code, ServerError.summary) body = self.make_error_body(ServerError.error_code, ServerError.summary)
c._socket.recv.return_value = self.make_msg(header, body) c._socket.recv.return_value = self.make_msg(header, body)
c.handle_read(None, 0) c.handle_read(None, 0)
@@ -253,7 +253,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write(None, 0) c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=0) header = self.make_header_prefix(ReadyMessage, stream_id=1)
c._socket.recv.return_value = self.make_msg(header) c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0) c.handle_read(None, 0)
@@ -280,7 +280,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage # let it write out a StartupMessage
c.handle_write(None, 0) c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=0) header = self.make_header_prefix(ReadyMessage, stream_id=1)
c._socket.recv.return_value = self.make_msg(header) c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0) c.handle_read(None, 0)