Fix bad request ID management from 2.1 support changes

This commit is contained in:
Tyler Hobbs
2014-07-02 16:02:03 -05:00
parent 39efc92fa2
commit 8f847245fa
4 changed files with 58 additions and 32 deletions

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections import defaultdict, deque
import errno
from functools import wraps, partial
import logging
@@ -150,8 +150,20 @@ class Connection(object):
ssl_options = None
last_error = None
# The current number of operations that are in flight. More precisely,
# the number of request IDs that are currently in use.
in_flight = 0
current_request_id = 0
# A set of available request IDs. When using the v3 protocol or higher,
# this will no initially include all request IDs in order to save memory,
# but the set will grow if it is exhausted.
request_ids = None
# Tracks the highest used request ID in order to help with growing the
# request_ids set
highest_request_id = 0
is_defunct = False
is_closed = False
lock = None
@@ -178,10 +190,16 @@ class Connection(object):
self._header_unpack = v3_header_unpack
self._header_length = 5
self.max_request_id = (2 ** 15) - 1
# Don't fill the deque with 2**15 items right away. Start with 300 and add
# more if needed.
self.request_ids = deque(range(300))
self.highest_request_id = 299
else:
self._header_unpack = header_unpack
self._header_length = 4
self.max_request_id = (2 ** 7) - 1
self.request_ids = deque(range(self.max_request_id + 1))
self.highest_request_id = self.max_request_id + 1
# 0 8 16 24 32 40
# +---------+---------+---------+---------+---------+
@@ -246,9 +264,16 @@ class Connection(object):
id(self), self.host, exc_info=True)
def get_request_id(self):
current = self.current_request_id
self.current_request_id = (current + 1) % self.max_request_id
return current
"""
This must be called while self.lock is held.
"""
try:
return self.request_ids.popleft()
except IndexError:
self.highest_request_id += 1
# in_flight checks should guarantee this
assert self.highest_request_id <= self.max_request_id
return self.highest_request_id
def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response)
@@ -283,13 +308,12 @@ class Connection(object):
needed = len(msgs) - messages_sent
with self.lock:
available = min(needed, self.max_request_id - self.in_flight)
start_request_id = self.current_request_id
self.current_request_id = (self.current_request_id + available) % self.max_request_id
request_ids = [self.get_request_id() for _ in range(available)]
self.in_flight += available
for i in range(available):
for i, request_id in enumerate(request_ids):
self.send_msg(msgs[messages_sent + i],
(start_request_id + i) % self.max_request_id,
request_id,
partial(waiter.got_response, index=messages_sent + i))
messages_sent += available
@@ -327,6 +351,8 @@ class Connection(object):
callback = None
else:
callback = self._callbacks.pop(stream_id, None)
with self.lock:
self.request_ids.append(stream_id)
body = None
try:
@@ -383,7 +409,7 @@ class Connection(object):
self._send_startup_message()
else:
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
self.send_msg(OptionsMessage(), 0, self._handle_options_response)
self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response)
@defunct_on_error
def _handle_options_response(self, options_response):
@@ -455,7 +481,7 @@ class Connection(object):
if compression:
opts['COMPRESSION'] = compression
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
self.send_msg(sm, 0, cb=self._handle_startup_response)
self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response)
log.debug("Sent StartupMessage on %s", self)
@defunct_on_error
@@ -480,7 +506,7 @@ class Connection(object):
log.debug("Sending credentials-based auth response on %s", self)
cm = CredentialsMessage(creds=self.authenticator)
callback = partial(self._handle_startup_response, did_authenticate=True)
self.send_msg(cm, 0, cb=callback)
self.send_msg(cm, self.get_request_id(), cb=callback)
else:
log.debug("Sending SASL-based auth response on %s", self)
initial_response = self.authenticator.initial_response()
@@ -520,7 +546,7 @@ class Connection(object):
response = self.authenticator.evaluate_challenge(auth_response.challenge)
msg = AuthResponseMessage("" if response is None else response)
log.debug("Responding to auth challenge on %s", self)
self.send_msg(msg, 0, self._handle_auth_response)
self.send_msg(msg, self.get_request_id(), self._handle_auth_response)
elif isinstance(auth_response, ErrorMessage):
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
id(self), self.host, auth_response.summary_msg())

View File

@@ -328,11 +328,12 @@ class HostConnection(object):
raise NoConnectionsAvailable()
with conn.lock:
if conn.in_flight > conn.max_request_id:
raise NoConnectionsAvailable("All request IDs are currently in use")
if conn.in_flight < conn.max_request_id:
conn.in_flight += 1
return conn, conn.get_request_id()
raise NoConnectionsAvailable("All request IDs are currently in use")
def return_connection(self, connection):
with connection.lock:
connection.in_flight -= 1
@@ -463,13 +464,12 @@ class HostConnectionPool(object):
# its in_flight count
need_to_wait = False
with least_busy.lock:
if least_busy.in_flight >= least_busy.max_request_id:
# once we release the lock, wait for another connection
need_to_wait = True
else:
if least_busy.in_flight < least_busy.max_request_id:
least_busy.in_flight += 1
request_id = least_busy.get_request_id()
else:
# once we release the lock, wait for another connection
need_to_wait = True
if need_to_wait:
# wait_for_conn will increment in_flight on the conn

View File

@@ -101,7 +101,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
c.socket.recv.return_value = self.make_msg(header)
c.handle_read()
@@ -173,7 +173,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ServerError, stream_id=0)
header = self.make_header_prefix(ServerError, stream_id=1)
body = self.make_error_body(ServerError.error_code, ServerError.summary)
c.socket.recv.return_value = self.make_msg(header, body)
c.handle_read()
@@ -255,7 +255,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
c.socket.recv.return_value = self.make_msg(header)
c.handle_read()
@@ -282,7 +282,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
c.socket.recv.return_value = self.make_msg(header)
c.handle_read()

View File

@@ -98,7 +98,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0)
@@ -170,7 +170,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ServerError, stream_id=0)
header = self.make_header_prefix(ServerError, stream_id=1)
body = self.make_error_body(ServerError.error_code, ServerError.summary)
c._socket.recv.return_value = self.make_msg(header, body)
c.handle_read(None, 0)
@@ -253,7 +253,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0)
@@ -280,7 +280,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0)