Fix bad request ID management from 2.1 support changes
This commit is contained in:
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict, deque
|
||||||
import errno
|
import errno
|
||||||
from functools import wraps, partial
|
from functools import wraps, partial
|
||||||
import logging
|
import logging
|
||||||
@@ -150,8 +150,20 @@ class Connection(object):
|
|||||||
|
|
||||||
ssl_options = None
|
ssl_options = None
|
||||||
last_error = None
|
last_error = None
|
||||||
|
|
||||||
|
# The current number of operations that are in flight. More precisely,
|
||||||
|
# the number of request IDs that are currently in use.
|
||||||
in_flight = 0
|
in_flight = 0
|
||||||
current_request_id = 0
|
|
||||||
|
# A set of available request IDs. When using the v3 protocol or higher,
|
||||||
|
# this will no initially include all request IDs in order to save memory,
|
||||||
|
# but the set will grow if it is exhausted.
|
||||||
|
request_ids = None
|
||||||
|
|
||||||
|
# Tracks the highest used request ID in order to help with growing the
|
||||||
|
# request_ids set
|
||||||
|
highest_request_id = 0
|
||||||
|
|
||||||
is_defunct = False
|
is_defunct = False
|
||||||
is_closed = False
|
is_closed = False
|
||||||
lock = None
|
lock = None
|
||||||
@@ -178,10 +190,16 @@ class Connection(object):
|
|||||||
self._header_unpack = v3_header_unpack
|
self._header_unpack = v3_header_unpack
|
||||||
self._header_length = 5
|
self._header_length = 5
|
||||||
self.max_request_id = (2 ** 15) - 1
|
self.max_request_id = (2 ** 15) - 1
|
||||||
|
# Don't fill the deque with 2**15 items right away. Start with 300 and add
|
||||||
|
# more if needed.
|
||||||
|
self.request_ids = deque(range(300))
|
||||||
|
self.highest_request_id = 299
|
||||||
else:
|
else:
|
||||||
self._header_unpack = header_unpack
|
self._header_unpack = header_unpack
|
||||||
self._header_length = 4
|
self._header_length = 4
|
||||||
self.max_request_id = (2 ** 7) - 1
|
self.max_request_id = (2 ** 7) - 1
|
||||||
|
self.request_ids = deque(range(self.max_request_id + 1))
|
||||||
|
self.highest_request_id = self.max_request_id + 1
|
||||||
|
|
||||||
# 0 8 16 24 32 40
|
# 0 8 16 24 32 40
|
||||||
# +---------+---------+---------+---------+---------+
|
# +---------+---------+---------+---------+---------+
|
||||||
@@ -246,9 +264,16 @@ class Connection(object):
|
|||||||
id(self), self.host, exc_info=True)
|
id(self), self.host, exc_info=True)
|
||||||
|
|
||||||
def get_request_id(self):
|
def get_request_id(self):
|
||||||
current = self.current_request_id
|
"""
|
||||||
self.current_request_id = (current + 1) % self.max_request_id
|
This must be called while self.lock is held.
|
||||||
return current
|
"""
|
||||||
|
try:
|
||||||
|
return self.request_ids.popleft()
|
||||||
|
except IndexError:
|
||||||
|
self.highest_request_id += 1
|
||||||
|
# in_flight checks should guarantee this
|
||||||
|
assert self.highest_request_id <= self.max_request_id
|
||||||
|
return self.highest_request_id
|
||||||
|
|
||||||
def handle_pushed(self, response):
|
def handle_pushed(self, response):
|
||||||
log.debug("Message pushed from server: %r", response)
|
log.debug("Message pushed from server: %r", response)
|
||||||
@@ -283,13 +308,12 @@ class Connection(object):
|
|||||||
needed = len(msgs) - messages_sent
|
needed = len(msgs) - messages_sent
|
||||||
with self.lock:
|
with self.lock:
|
||||||
available = min(needed, self.max_request_id - self.in_flight)
|
available = min(needed, self.max_request_id - self.in_flight)
|
||||||
start_request_id = self.current_request_id
|
request_ids = [self.get_request_id() for _ in range(available)]
|
||||||
self.current_request_id = (self.current_request_id + available) % self.max_request_id
|
|
||||||
self.in_flight += available
|
self.in_flight += available
|
||||||
|
|
||||||
for i in range(available):
|
for i, request_id in enumerate(request_ids):
|
||||||
self.send_msg(msgs[messages_sent + i],
|
self.send_msg(msgs[messages_sent + i],
|
||||||
(start_request_id + i) % self.max_request_id,
|
request_id,
|
||||||
partial(waiter.got_response, index=messages_sent + i))
|
partial(waiter.got_response, index=messages_sent + i))
|
||||||
messages_sent += available
|
messages_sent += available
|
||||||
|
|
||||||
@@ -327,6 +351,8 @@ class Connection(object):
|
|||||||
callback = None
|
callback = None
|
||||||
else:
|
else:
|
||||||
callback = self._callbacks.pop(stream_id, None)
|
callback = self._callbacks.pop(stream_id, None)
|
||||||
|
with self.lock:
|
||||||
|
self.request_ids.append(stream_id)
|
||||||
|
|
||||||
body = None
|
body = None
|
||||||
try:
|
try:
|
||||||
@@ -383,7 +409,7 @@ class Connection(object):
|
|||||||
self._send_startup_message()
|
self._send_startup_message()
|
||||||
else:
|
else:
|
||||||
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
|
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
|
||||||
self.send_msg(OptionsMessage(), 0, self._handle_options_response)
|
self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response)
|
||||||
|
|
||||||
@defunct_on_error
|
@defunct_on_error
|
||||||
def _handle_options_response(self, options_response):
|
def _handle_options_response(self, options_response):
|
||||||
@@ -455,7 +481,7 @@ class Connection(object):
|
|||||||
if compression:
|
if compression:
|
||||||
opts['COMPRESSION'] = compression
|
opts['COMPRESSION'] = compression
|
||||||
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
|
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
|
||||||
self.send_msg(sm, 0, cb=self._handle_startup_response)
|
self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response)
|
||||||
log.debug("Sent StartupMessage on %s", self)
|
log.debug("Sent StartupMessage on %s", self)
|
||||||
|
|
||||||
@defunct_on_error
|
@defunct_on_error
|
||||||
@@ -480,7 +506,7 @@ class Connection(object):
|
|||||||
log.debug("Sending credentials-based auth response on %s", self)
|
log.debug("Sending credentials-based auth response on %s", self)
|
||||||
cm = CredentialsMessage(creds=self.authenticator)
|
cm = CredentialsMessage(creds=self.authenticator)
|
||||||
callback = partial(self._handle_startup_response, did_authenticate=True)
|
callback = partial(self._handle_startup_response, did_authenticate=True)
|
||||||
self.send_msg(cm, 0, cb=callback)
|
self.send_msg(cm, self.get_request_id(), cb=callback)
|
||||||
else:
|
else:
|
||||||
log.debug("Sending SASL-based auth response on %s", self)
|
log.debug("Sending SASL-based auth response on %s", self)
|
||||||
initial_response = self.authenticator.initial_response()
|
initial_response = self.authenticator.initial_response()
|
||||||
@@ -520,7 +546,7 @@ class Connection(object):
|
|||||||
response = self.authenticator.evaluate_challenge(auth_response.challenge)
|
response = self.authenticator.evaluate_challenge(auth_response.challenge)
|
||||||
msg = AuthResponseMessage("" if response is None else response)
|
msg = AuthResponseMessage("" if response is None else response)
|
||||||
log.debug("Responding to auth challenge on %s", self)
|
log.debug("Responding to auth challenge on %s", self)
|
||||||
self.send_msg(msg, 0, self._handle_auth_response)
|
self.send_msg(msg, self.get_request_id(), self._handle_auth_response)
|
||||||
elif isinstance(auth_response, ErrorMessage):
|
elif isinstance(auth_response, ErrorMessage):
|
||||||
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
|
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
|
||||||
id(self), self.host, auth_response.summary_msg())
|
id(self), self.host, auth_response.summary_msg())
|
||||||
|
@@ -328,10 +328,11 @@ class HostConnection(object):
|
|||||||
raise NoConnectionsAvailable()
|
raise NoConnectionsAvailable()
|
||||||
|
|
||||||
with conn.lock:
|
with conn.lock:
|
||||||
if conn.in_flight > conn.max_request_id:
|
if conn.in_flight < conn.max_request_id:
|
||||||
raise NoConnectionsAvailable("All request IDs are currently in use")
|
conn.in_flight += 1
|
||||||
conn.in_flight += 1
|
return conn, conn.get_request_id()
|
||||||
return conn, conn.get_request_id()
|
|
||||||
|
raise NoConnectionsAvailable("All request IDs are currently in use")
|
||||||
|
|
||||||
def return_connection(self, connection):
|
def return_connection(self, connection):
|
||||||
with connection.lock:
|
with connection.lock:
|
||||||
@@ -341,7 +342,7 @@ class HostConnection(object):
|
|||||||
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
|
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
|
||||||
"marking host %s as down", id(connection), self.host)
|
"marking host %s as down", id(connection), self.host)
|
||||||
is_down = self._session.cluster.signal_connection_failure(
|
is_down = self._session.cluster.signal_connection_failure(
|
||||||
self.host, connection.last_error, is_host_addition=False)
|
self.host, connection.last_error, is_host_addition=False)
|
||||||
if is_down:
|
if is_down:
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
else:
|
else:
|
||||||
@@ -463,13 +464,12 @@ class HostConnectionPool(object):
|
|||||||
# its in_flight count
|
# its in_flight count
|
||||||
need_to_wait = False
|
need_to_wait = False
|
||||||
with least_busy.lock:
|
with least_busy.lock:
|
||||||
|
if least_busy.in_flight < least_busy.max_request_id:
|
||||||
if least_busy.in_flight >= least_busy.max_request_id:
|
|
||||||
# once we release the lock, wait for another connection
|
|
||||||
need_to_wait = True
|
|
||||||
else:
|
|
||||||
least_busy.in_flight += 1
|
least_busy.in_flight += 1
|
||||||
request_id = least_busy.get_request_id()
|
request_id = least_busy.get_request_id()
|
||||||
|
else:
|
||||||
|
# once we release the lock, wait for another connection
|
||||||
|
need_to_wait = True
|
||||||
|
|
||||||
if need_to_wait:
|
if need_to_wait:
|
||||||
# wait_for_conn will increment in_flight on the conn
|
# wait_for_conn will increment in_flight on the conn
|
||||||
@@ -587,7 +587,7 @@ class HostConnectionPool(object):
|
|||||||
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
|
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
|
||||||
"marking host %s as down", id(connection), self.host)
|
"marking host %s as down", id(connection), self.host)
|
||||||
is_down = self._session.cluster.signal_connection_failure(
|
is_down = self._session.cluster.signal_connection_failure(
|
||||||
self.host, connection.last_error, is_host_addition=False)
|
self.host, connection.last_error, is_host_addition=False)
|
||||||
if is_down:
|
if is_down:
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
else:
|
else:
|
||||||
|
@@ -101,7 +101,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write()
|
c.handle_write()
|
||||||
|
|
||||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||||
c.socket.recv.return_value = self.make_msg(header)
|
c.socket.recv.return_value = self.make_msg(header)
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
|
|
||||||
@@ -173,7 +173,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write()
|
c.handle_write()
|
||||||
|
|
||||||
header = self.make_header_prefix(ServerError, stream_id=0)
|
header = self.make_header_prefix(ServerError, stream_id=1)
|
||||||
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
||||||
c.socket.recv.return_value = self.make_msg(header, body)
|
c.socket.recv.return_value = self.make_msg(header, body)
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
@@ -255,7 +255,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write()
|
c.handle_write()
|
||||||
|
|
||||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||||
c.socket.recv.return_value = self.make_msg(header)
|
c.socket.recv.return_value = self.make_msg(header)
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
|
|
||||||
@@ -282,7 +282,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write()
|
c.handle_write()
|
||||||
|
|
||||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||||
c.socket.recv.return_value = self.make_msg(header)
|
c.socket.recv.return_value = self.make_msg(header)
|
||||||
c.handle_read()
|
c.handle_read()
|
||||||
|
|
||||||
|
@@ -98,7 +98,7 @@ class LibevConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write(None, 0)
|
c.handle_write(None, 0)
|
||||||
|
|
||||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||||
c._socket.recv.return_value = self.make_msg(header)
|
c._socket.recv.return_value = self.make_msg(header)
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ class LibevConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write(None, 0)
|
c.handle_write(None, 0)
|
||||||
|
|
||||||
header = self.make_header_prefix(ServerError, stream_id=0)
|
header = self.make_header_prefix(ServerError, stream_id=1)
|
||||||
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
||||||
c._socket.recv.return_value = self.make_msg(header, body)
|
c._socket.recv.return_value = self.make_msg(header, body)
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
@@ -253,7 +253,7 @@ class LibevConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write(None, 0)
|
c.handle_write(None, 0)
|
||||||
|
|
||||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||||
c._socket.recv.return_value = self.make_msg(header)
|
c._socket.recv.return_value = self.make_msg(header)
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
|
|
||||||
@@ -280,7 +280,7 @@ class LibevConnectionTest(unittest.TestCase):
|
|||||||
# let it write out a StartupMessage
|
# let it write out a StartupMessage
|
||||||
c.handle_write(None, 0)
|
c.handle_write(None, 0)
|
||||||
|
|
||||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||||
c._socket.recv.return_value = self.make_msg(header)
|
c._socket.recv.return_value = self.make_msg(header)
|
||||||
c.handle_read(None, 0)
|
c.handle_read(None, 0)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user