Fix bad request ID management from 2.1 support changes
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
import errno
|
||||
from functools import wraps, partial
|
||||
import logging
|
||||
@@ -150,8 +150,20 @@ class Connection(object):
|
||||
|
||||
ssl_options = None
|
||||
last_error = None
|
||||
|
||||
# The current number of operations that are in flight. More precisely,
|
||||
# the number of request IDs that are currently in use.
|
||||
in_flight = 0
|
||||
current_request_id = 0
|
||||
|
||||
# A set of available request IDs. When using the v3 protocol or higher,
|
||||
# this will no initially include all request IDs in order to save memory,
|
||||
# but the set will grow if it is exhausted.
|
||||
request_ids = None
|
||||
|
||||
# Tracks the highest used request ID in order to help with growing the
|
||||
# request_ids set
|
||||
highest_request_id = 0
|
||||
|
||||
is_defunct = False
|
||||
is_closed = False
|
||||
lock = None
|
||||
@@ -178,10 +190,16 @@ class Connection(object):
|
||||
self._header_unpack = v3_header_unpack
|
||||
self._header_length = 5
|
||||
self.max_request_id = (2 ** 15) - 1
|
||||
# Don't fill the deque with 2**15 items right away. Start with 300 and add
|
||||
# more if needed.
|
||||
self.request_ids = deque(range(300))
|
||||
self.highest_request_id = 299
|
||||
else:
|
||||
self._header_unpack = header_unpack
|
||||
self._header_length = 4
|
||||
self.max_request_id = (2 ** 7) - 1
|
||||
self.request_ids = deque(range(self.max_request_id + 1))
|
||||
self.highest_request_id = self.max_request_id + 1
|
||||
|
||||
# 0 8 16 24 32 40
|
||||
# +---------+---------+---------+---------+---------+
|
||||
@@ -246,9 +264,16 @@ class Connection(object):
|
||||
id(self), self.host, exc_info=True)
|
||||
|
||||
def get_request_id(self):
|
||||
current = self.current_request_id
|
||||
self.current_request_id = (current + 1) % self.max_request_id
|
||||
return current
|
||||
"""
|
||||
This must be called while self.lock is held.
|
||||
"""
|
||||
try:
|
||||
return self.request_ids.popleft()
|
||||
except IndexError:
|
||||
self.highest_request_id += 1
|
||||
# in_flight checks should guarantee this
|
||||
assert self.highest_request_id <= self.max_request_id
|
||||
return self.highest_request_id
|
||||
|
||||
def handle_pushed(self, response):
|
||||
log.debug("Message pushed from server: %r", response)
|
||||
@@ -283,13 +308,12 @@ class Connection(object):
|
||||
needed = len(msgs) - messages_sent
|
||||
with self.lock:
|
||||
available = min(needed, self.max_request_id - self.in_flight)
|
||||
start_request_id = self.current_request_id
|
||||
self.current_request_id = (self.current_request_id + available) % self.max_request_id
|
||||
request_ids = [self.get_request_id() for _ in range(available)]
|
||||
self.in_flight += available
|
||||
|
||||
for i in range(available):
|
||||
for i, request_id in enumerate(request_ids):
|
||||
self.send_msg(msgs[messages_sent + i],
|
||||
(start_request_id + i) % self.max_request_id,
|
||||
request_id,
|
||||
partial(waiter.got_response, index=messages_sent + i))
|
||||
messages_sent += available
|
||||
|
||||
@@ -327,6 +351,8 @@ class Connection(object):
|
||||
callback = None
|
||||
else:
|
||||
callback = self._callbacks.pop(stream_id, None)
|
||||
with self.lock:
|
||||
self.request_ids.append(stream_id)
|
||||
|
||||
body = None
|
||||
try:
|
||||
@@ -383,7 +409,7 @@ class Connection(object):
|
||||
self._send_startup_message()
|
||||
else:
|
||||
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
|
||||
self.send_msg(OptionsMessage(), 0, self._handle_options_response)
|
||||
self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response)
|
||||
|
||||
@defunct_on_error
|
||||
def _handle_options_response(self, options_response):
|
||||
@@ -455,7 +481,7 @@ class Connection(object):
|
||||
if compression:
|
||||
opts['COMPRESSION'] = compression
|
||||
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
|
||||
self.send_msg(sm, 0, cb=self._handle_startup_response)
|
||||
self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response)
|
||||
log.debug("Sent StartupMessage on %s", self)
|
||||
|
||||
@defunct_on_error
|
||||
@@ -480,7 +506,7 @@ class Connection(object):
|
||||
log.debug("Sending credentials-based auth response on %s", self)
|
||||
cm = CredentialsMessage(creds=self.authenticator)
|
||||
callback = partial(self._handle_startup_response, did_authenticate=True)
|
||||
self.send_msg(cm, 0, cb=callback)
|
||||
self.send_msg(cm, self.get_request_id(), cb=callback)
|
||||
else:
|
||||
log.debug("Sending SASL-based auth response on %s", self)
|
||||
initial_response = self.authenticator.initial_response()
|
||||
@@ -520,7 +546,7 @@ class Connection(object):
|
||||
response = self.authenticator.evaluate_challenge(auth_response.challenge)
|
||||
msg = AuthResponseMessage("" if response is None else response)
|
||||
log.debug("Responding to auth challenge on %s", self)
|
||||
self.send_msg(msg, 0, self._handle_auth_response)
|
||||
self.send_msg(msg, self.get_request_id(), self._handle_auth_response)
|
||||
elif isinstance(auth_response, ErrorMessage):
|
||||
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
|
||||
id(self), self.host, auth_response.summary_msg())
|
||||
|
@@ -328,10 +328,11 @@ class HostConnection(object):
|
||||
raise NoConnectionsAvailable()
|
||||
|
||||
with conn.lock:
|
||||
if conn.in_flight > conn.max_request_id:
|
||||
raise NoConnectionsAvailable("All request IDs are currently in use")
|
||||
conn.in_flight += 1
|
||||
return conn, conn.get_request_id()
|
||||
if conn.in_flight < conn.max_request_id:
|
||||
conn.in_flight += 1
|
||||
return conn, conn.get_request_id()
|
||||
|
||||
raise NoConnectionsAvailable("All request IDs are currently in use")
|
||||
|
||||
def return_connection(self, connection):
|
||||
with connection.lock:
|
||||
@@ -341,7 +342,7 @@ class HostConnection(object):
|
||||
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
|
||||
"marking host %s as down", id(connection), self.host)
|
||||
is_down = self._session.cluster.signal_connection_failure(
|
||||
self.host, connection.last_error, is_host_addition=False)
|
||||
self.host, connection.last_error, is_host_addition=False)
|
||||
if is_down:
|
||||
self.shutdown()
|
||||
else:
|
||||
@@ -463,13 +464,12 @@ class HostConnectionPool(object):
|
||||
# its in_flight count
|
||||
need_to_wait = False
|
||||
with least_busy.lock:
|
||||
|
||||
if least_busy.in_flight >= least_busy.max_request_id:
|
||||
# once we release the lock, wait for another connection
|
||||
need_to_wait = True
|
||||
else:
|
||||
if least_busy.in_flight < least_busy.max_request_id:
|
||||
least_busy.in_flight += 1
|
||||
request_id = least_busy.get_request_id()
|
||||
else:
|
||||
# once we release the lock, wait for another connection
|
||||
need_to_wait = True
|
||||
|
||||
if need_to_wait:
|
||||
# wait_for_conn will increment in_flight on the conn
|
||||
@@ -587,7 +587,7 @@ class HostConnectionPool(object):
|
||||
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
|
||||
"marking host %s as down", id(connection), self.host)
|
||||
is_down = self._session.cluster.signal_connection_failure(
|
||||
self.host, connection.last_error, is_host_addition=False)
|
||||
self.host, connection.last_error, is_host_addition=False)
|
||||
if is_down:
|
||||
self.shutdown()
|
||||
else:
|
||||
|
@@ -101,7 +101,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
c.socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read()
|
||||
|
||||
@@ -173,7 +173,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ServerError, stream_id=0)
|
||||
header = self.make_header_prefix(ServerError, stream_id=1)
|
||||
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
||||
c.socket.recv.return_value = self.make_msg(header, body)
|
||||
c.handle_read()
|
||||
@@ -255,7 +255,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
c.socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read()
|
||||
|
||||
@@ -282,7 +282,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
c.socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read()
|
||||
|
||||
|
@@ -98,7 +98,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
c._socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read(None, 0)
|
||||
|
||||
@@ -170,7 +170,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ServerError, stream_id=0)
|
||||
header = self.make_header_prefix(ServerError, stream_id=1)
|
||||
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
||||
c._socket.recv.return_value = self.make_msg(header, body)
|
||||
c.handle_read(None, 0)
|
||||
@@ -253,7 +253,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
c._socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read(None, 0)
|
||||
|
||||
@@ -280,7 +280,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
c._socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read(None, 0)
|
||||
|
||||
|
Reference in New Issue
Block a user