2 byte request IDs + queue-less req ID management
This commit is contained in:
@@ -2143,8 +2143,8 @@ class ResponseFuture(object):
|
||||
connection = None
|
||||
try:
|
||||
# TODO get connectTimeout from cluster settings
|
||||
connection = pool.borrow_connection(timeout=2.0)
|
||||
request_id = connection.send_msg(message, cb=cb)
|
||||
connection, request_id = pool.borrow_connection(timeout=2.0)
|
||||
connection.send_msg(message, request_id, cb=cb)
|
||||
except NoConnectionsAvailable as exc:
|
||||
log.debug("All connections for host %s are at capacity, moving to the next host", host)
|
||||
self._errors[host] = exc
|
||||
|
||||
@@ -29,7 +29,7 @@ import six
|
||||
from six.moves import range
|
||||
|
||||
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
||||
from cassandra.marshal import int32_pack, header_unpack
|
||||
from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack
|
||||
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
||||
QueryMessage, ResultMessage, decode_response,
|
||||
@@ -79,8 +79,6 @@ else:
|
||||
locally_supported_compressions['snappy'] = (snappy.compress, decompress)
|
||||
|
||||
|
||||
MAX_STREAM_PER_CONNECTION = 127
|
||||
|
||||
PROTOCOL_VERSION_MASK = 0x7f
|
||||
|
||||
HEADER_DIRECTION_FROM_CLIENT = 0x00
|
||||
@@ -153,6 +151,7 @@ class Connection(object):
|
||||
ssl_options = None
|
||||
last_error = None
|
||||
in_flight = 0
|
||||
current_request_id = 0
|
||||
is_defunct = False
|
||||
is_closed = False
|
||||
lock = None
|
||||
@@ -172,10 +171,27 @@ class Connection(object):
|
||||
self.protocol_version = protocol_version
|
||||
self.is_control_connection = is_control_connection
|
||||
self._push_watchers = defaultdict(set)
|
||||
if protocol_version >= 3:
|
||||
self._header_unpack = v3_header_unpack
|
||||
self._header_length = 5
|
||||
self.max_request_id = (2 ** 15) - 1
|
||||
else:
|
||||
self._header_unpack = header_unpack
|
||||
self._header_length = 4
|
||||
self.max_request_id = (2 ** 7) - 1
|
||||
|
||||
self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
|
||||
for i in range(MAX_STREAM_PER_CONNECTION):
|
||||
self._id_queue.put_nowait(i)
|
||||
# 0 8 16 24 32 40
|
||||
# +---------+---------+---------+---------+---------+
|
||||
# | version | flags | stream | opcode |
|
||||
# +---------+---------+---------+---------+---------+
|
||||
# | length |
|
||||
# +---------+---------+---------+---------+
|
||||
# | |
|
||||
# . ... body ... .
|
||||
# . .
|
||||
# . .
|
||||
# +----------------------------------------
|
||||
self._full_header_length = self._header_length + 4
|
||||
|
||||
self.lock = RLock()
|
||||
|
||||
@@ -210,6 +226,11 @@ class Connection(object):
|
||||
"failed connection (%s) to host %s:",
|
||||
id(self), self.host, exc_info=True)
|
||||
|
||||
def get_request_id(self):
|
||||
current = self.current_request_id
|
||||
self.current_request_id = (current + 1) % self.max_request_id
|
||||
return current
|
||||
|
||||
def handle_pushed(self, response):
|
||||
log.debug("Message pushed from server: %r", response)
|
||||
for cb in self._push_watchers.get(response.event_type, []):
|
||||
@@ -218,21 +239,12 @@ class Connection(object):
|
||||
except Exception:
|
||||
log.exception("Pushed event handler errored, ignoring:")
|
||||
|
||||
def send_msg(self, msg, cb, wait_for_id=False):
|
||||
def send_msg(self, msg, request_id, cb):
|
||||
if self.is_defunct:
|
||||
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
|
||||
elif self.is_closed:
|
||||
raise ConnectionShutdown("Connection to %s is closed" % self.host)
|
||||
|
||||
if not wait_for_id:
|
||||
try:
|
||||
request_id = self._id_queue.get_nowait()
|
||||
except Empty:
|
||||
raise ConnectionBusy(
|
||||
"Connection to %s is at the max number of requests" % self.host)
|
||||
else:
|
||||
request_id = self._id_queue.get()
|
||||
|
||||
self._callbacks[request_id] = cb
|
||||
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
|
||||
return request_id
|
||||
@@ -251,11 +263,15 @@ class Connection(object):
|
||||
while True:
|
||||
needed = len(msgs) - messages_sent
|
||||
with self.lock:
|
||||
available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
|
||||
available = min(needed, self.max_request_id - self.in_flight)
|
||||
start_request_id = self.current_request_id
|
||||
self.current_request_id = (self.current_request_id + available) % self.max_request_id
|
||||
self.in_flight += available
|
||||
|
||||
for i in range(messages_sent, messages_sent + available):
|
||||
self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
|
||||
for i in range(available):
|
||||
self.send_msg(msgs[messages_sent + i],
|
||||
(start_request_id + i) % self.max_request_id,
|
||||
partial(waiter.got_response, index=messages_sent + i))
|
||||
messages_sent += available
|
||||
|
||||
if messages_sent == len(msgs):
|
||||
@@ -287,12 +303,11 @@ class Connection(object):
|
||||
|
||||
@defunct_on_error
|
||||
def process_msg(self, msg, body_len):
|
||||
version, flags, stream_id, opcode = header_unpack(msg[:4])
|
||||
version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length])
|
||||
if stream_id < 0:
|
||||
callback = None
|
||||
else:
|
||||
callback = self._callbacks.pop(stream_id, None)
|
||||
self._id_queue.put_nowait(stream_id)
|
||||
|
||||
body = None
|
||||
try:
|
||||
@@ -344,7 +359,7 @@ class Connection(object):
|
||||
self._send_startup_message()
|
||||
else:
|
||||
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
|
||||
self.send_msg(OptionsMessage(), self._handle_options_response)
|
||||
self.send_msg(OptionsMessage(), 0, self._handle_options_response)
|
||||
|
||||
@defunct_on_error
|
||||
def _handle_options_response(self, options_response):
|
||||
@@ -411,11 +426,13 @@ class Connection(object):
|
||||
|
||||
@defunct_on_error
|
||||
def _send_startup_message(self, compression=None):
|
||||
log.debug("Sending StartupMessage on %s", self)
|
||||
opts = {}
|
||||
if compression:
|
||||
opts['COMPRESSION'] = compression
|
||||
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
|
||||
self.send_msg(sm, cb=self._handle_startup_response)
|
||||
self.send_msg(sm, 0, cb=self._handle_startup_response)
|
||||
log.debug("Sent StartupMessage on %s", self)
|
||||
|
||||
@defunct_on_error
|
||||
def _handle_startup_response(self, startup_response, did_authenticate=False):
|
||||
@@ -439,12 +456,12 @@ class Connection(object):
|
||||
log.debug("Sending credentials-based auth response on %s", self)
|
||||
cm = CredentialsMessage(creds=self.authenticator)
|
||||
callback = partial(self._handle_startup_response, did_authenticate=True)
|
||||
self.send_msg(cm, cb=callback)
|
||||
self.send_msg(cm, 0, cb=callback)
|
||||
else:
|
||||
log.debug("Sending SASL-based auth response on %s", self)
|
||||
initial_response = self.authenticator.initial_response()
|
||||
initial_response = "" if initial_response is None else initial_response.encode('utf-8')
|
||||
self.send_msg(AuthResponseMessage(initial_response), self._handle_auth_response)
|
||||
self.send_msg(AuthResponseMessage(initial_response), 0, self._handle_auth_response)
|
||||
elif isinstance(startup_response, ErrorMessage):
|
||||
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
|
||||
id(self), self.host, startup_response.summary_msg())
|
||||
@@ -479,7 +496,7 @@ class Connection(object):
|
||||
response = self.authenticator.evaluate_challenge(auth_response.challenge)
|
||||
msg = AuthResponseMessage("" if response is None else response)
|
||||
log.debug("Responding to auth challenge on %s", self)
|
||||
self.send_msg(msg, self._handle_auth_response)
|
||||
self.send_msg(msg, 0, self._handle_auth_response)
|
||||
elif isinstance(auth_response, ErrorMessage):
|
||||
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
|
||||
id(self), self.host, auth_response.summary_msg())
|
||||
@@ -543,7 +560,21 @@ class Connection(object):
|
||||
callback(self, self.defunct(ConnectionException(
|
||||
"Problem while setting keyspace: %r" % (result,), self.host)))
|
||||
|
||||
self.send_msg(query, process_result, wait_for_id=True)
|
||||
request_id = None
|
||||
# we use a busy wait on the lock here because:
|
||||
# - we'll only spin if the connection is at max capacity, which is very
|
||||
# unlikely for a set_keyspace call
|
||||
# - it allows us to avoid signaling a condition every time a request completes
|
||||
while True:
|
||||
with self.lock:
|
||||
if self.in_flight < self.max_request_id:
|
||||
request_id = self.get_request_id()
|
||||
self.in_flight += 1
|
||||
break
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
self.send_msg(query, request_id, process_result)
|
||||
|
||||
def __str__(self):
|
||||
status = ""
|
||||
|
||||
@@ -229,7 +229,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
self.defunct(sys.exc_info()[1])
|
||||
|
||||
def handle_close(self):
|
||||
log.debug("connection (%s) to %s closed by server", id(self), self.host)
|
||||
log.debug("Connection %s closed by server", self)
|
||||
self.close()
|
||||
|
||||
def handle_write(self):
|
||||
@@ -277,24 +277,24 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
if self._iobuf.tell():
|
||||
while True:
|
||||
pos = self._iobuf.tell()
|
||||
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
# we don't have a complete header yet or we
|
||||
# already saw a header, but we don't have a
|
||||
# complete message yet
|
||||
break
|
||||
else:
|
||||
# have enough for header, read body len from header
|
||||
self._iobuf.seek(4)
|
||||
self._iobuf.seek(self._header_length)
|
||||
body_len = int32_unpack(self._iobuf.read(4))
|
||||
|
||||
# seek to end to get length of current buffer
|
||||
self._iobuf.seek(0, os.SEEK_END)
|
||||
pos = self._iobuf.tell()
|
||||
|
||||
if pos >= body_len + 8:
|
||||
if pos >= body_len + self._full_header_length:
|
||||
# read message header and body
|
||||
self._iobuf.seek(0)
|
||||
msg = self._iobuf.read(8 + body_len)
|
||||
msg = self._iobuf.read(self._full_header_length + body_len)
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
@@ -304,7 +304,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
self._total_reqd_bytes = 0
|
||||
self.process_msg(msg, body_len)
|
||||
else:
|
||||
self._total_reqd_bytes = body_len + 8
|
||||
self._total_reqd_bytes = body_len + self._full_header_length
|
||||
break
|
||||
|
||||
if not self._callbacks and not self.is_control_connection:
|
||||
|
||||
@@ -305,24 +305,24 @@ class LibevConnection(Connection):
|
||||
if self._iobuf.tell():
|
||||
while True:
|
||||
pos = self._iobuf.tell()
|
||||
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
# we don't have a complete header yet or we
|
||||
# already saw a header, but we don't have a
|
||||
# complete message yet
|
||||
break
|
||||
else:
|
||||
# have enough for header, read body len from header
|
||||
self._iobuf.seek(4)
|
||||
self._iobuf.seek(self._header_length)
|
||||
body_len = int32_unpack(self._iobuf.read(4))
|
||||
|
||||
# seek to end to get length of current buffer
|
||||
self._iobuf.seek(0, os.SEEK_END)
|
||||
pos = self._iobuf.tell()
|
||||
|
||||
if pos >= body_len + 8:
|
||||
if pos >= body_len + self._full_header_length:
|
||||
# read message header and body
|
||||
self._iobuf.seek(0)
|
||||
msg = self._iobuf.read(8 + body_len)
|
||||
msg = self._iobuf.read(self._full_header_length + body_len)
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
@@ -332,7 +332,7 @@ class LibevConnection(Connection):
|
||||
self._total_reqd_bytes = 0
|
||||
self.process_msg(msg, body_len)
|
||||
else:
|
||||
self._total_reqd_bytes = body_len + 8
|
||||
self._total_reqd_bytes = body_len + self._full_header_length
|
||||
break
|
||||
else:
|
||||
log.debug("Connection %s closed by server", self)
|
||||
|
||||
@@ -17,14 +17,9 @@ import struct
|
||||
|
||||
|
||||
def _make_packer(format_string):
|
||||
try:
|
||||
packer = struct.Struct(format_string) # new in Python 2.5
|
||||
except AttributeError:
|
||||
pack = lambda x: struct.pack(format_string, x)
|
||||
unpack = lambda s: struct.unpack(format_string, s)
|
||||
else:
|
||||
pack = packer.pack
|
||||
unpack = lambda s: packer.unpack(s)[0]
|
||||
packer = struct.Struct(format_string)
|
||||
pack = packer.pack
|
||||
unpack = lambda s: packer.unpack(s)[0]
|
||||
return pack, unpack
|
||||
|
||||
int64_pack, int64_unpack = _make_packer('>q')
|
||||
@@ -43,6 +38,11 @@ header_struct = struct.Struct('>BBbB')
|
||||
header_pack = header_struct.pack
|
||||
header_unpack = header_struct.unpack
|
||||
|
||||
# in protocol version 3 and higher, the stream ID is two bytes
|
||||
v3_header_struct = struct.Struct('>BBhB')
|
||||
v3_header_pack = v3_header_struct.pack
|
||||
v3_header_unpack = v3_header_struct.unpack
|
||||
|
||||
|
||||
if six.PY3:
|
||||
def varint_unpack(term):
|
||||
|
||||
@@ -28,7 +28,7 @@ except ImportError:
|
||||
from cassandra.util import WeakSet # NOQA
|
||||
|
||||
from cassandra import AuthenticationFailed
|
||||
from cassandra.connection import MAX_STREAM_PER_CONNECTION, ConnectionException
|
||||
from cassandra.connection import ConnectionException
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -349,6 +349,7 @@ class HostConnectionPool(object):
|
||||
max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance)
|
||||
|
||||
least_busy = min(conns, key=lambda c: c.in_flight)
|
||||
request_id = None
|
||||
# to avoid another thread closing this connection while
|
||||
# trashing it (through the return_connection process), hold
|
||||
# the connection lock from this point until we've incremented
|
||||
@@ -356,15 +357,16 @@ class HostConnectionPool(object):
|
||||
need_to_wait = False
|
||||
with least_busy.lock:
|
||||
|
||||
if least_busy.in_flight >= MAX_STREAM_PER_CONNECTION:
|
||||
if least_busy.in_flight >= least_busy.max_request_id:
|
||||
# once we release the lock, wait for another connection
|
||||
need_to_wait = True
|
||||
else:
|
||||
least_busy.in_flight += 1
|
||||
request_id = least_busy.get_request_id()
|
||||
|
||||
if need_to_wait:
|
||||
# wait_for_conn will increment in_flight on the conn
|
||||
least_busy = self._wait_for_conn(timeout)
|
||||
least_busy, request_id = self._wait_for_conn(timeout)
|
||||
|
||||
# if we have too many requests on this connection but we still
|
||||
# have space to open a new connection against this host, go ahead
|
||||
@@ -372,7 +374,7 @@ class HostConnectionPool(object):
|
||||
if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns:
|
||||
self._maybe_spawn_new_connection()
|
||||
|
||||
return least_busy
|
||||
return least_busy, request_id
|
||||
|
||||
def _maybe_spawn_new_connection(self):
|
||||
with self._lock:
|
||||
@@ -461,9 +463,9 @@ class HostConnectionPool(object):
|
||||
if conns:
|
||||
least_busy = min(conns, key=lambda c: c.in_flight)
|
||||
with least_busy.lock:
|
||||
if least_busy.in_flight < MAX_STREAM_PER_CONNECTION:
|
||||
if least_busy.in_flight < least_busy.max_request_id:
|
||||
least_busy.in_flight += 1
|
||||
return least_busy
|
||||
return least_busy, least_busy.get_request_id()
|
||||
|
||||
remaining = timeout - (time.time() - start)
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
|
||||
AlreadyExists, InvalidRequest, Unauthorized,
|
||||
UnsupportedOperation)
|
||||
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
|
||||
int8_pack, int8_unpack, uint64_pack, header_pack)
|
||||
int8_pack, int8_unpack, uint64_pack, header_pack,
|
||||
v3_header_pack)
|
||||
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
|
||||
CounterColumnType, DateType, DecimalType,
|
||||
DoubleType, FloatType, Int32Type,
|
||||
@@ -80,11 +81,7 @@ class _MessageType(object):
|
||||
flags |= TRACING_FLAG
|
||||
|
||||
msg = six.BytesIO()
|
||||
write_header(
|
||||
msg,
|
||||
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
|
||||
flags, stream_id, self.opcode, len(body)
|
||||
)
|
||||
write_header(msg, protocol_version, flags, stream_id, self.opcode, len(body))
|
||||
msg.write(body)
|
||||
|
||||
return msg.getvalue()
|
||||
@@ -824,7 +821,8 @@ def write_header(f, version, flags, stream_id, opcode, length):
|
||||
"""
|
||||
Write a CQL protocol frame header.
|
||||
"""
|
||||
f.write(header_pack(version, flags, stream_id, opcode))
|
||||
pack = v3_header_pack if version >= 3 else header_pack
|
||||
f.write(pack(version | HEADER_DIRECTION_FROM_CLIENT, flags, stream_id, opcode))
|
||||
write_int(f, length)
|
||||
|
||||
|
||||
|
||||
@@ -46,6 +46,9 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
cls.mock_socket = cls.socket_patcher.start()
|
||||
cls.mock_socket().connect_ex.return_value = 0
|
||||
cls.mock_socket().getsockopt.return_value = 0
|
||||
cls.mock_socket().fileno.return_value = 100
|
||||
|
||||
AsyncoreConnection.add_channel = lambda *args, **kwargs: None
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@@ -97,7 +100,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
c.socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read()
|
||||
|
||||
@@ -169,7 +172,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ServerError, stream_id=1)
|
||||
header = self.make_header_prefix(ServerError, stream_id=0)
|
||||
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
||||
c.socket.recv.return_value = self.make_msg(header, body)
|
||||
c.handle_read()
|
||||
@@ -251,7 +254,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
c.socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read()
|
||||
|
||||
@@ -278,7 +281,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write()
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
c.socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read()
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
c._socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read(None, 0)
|
||||
|
||||
@@ -169,7 +169,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ServerError, stream_id=1)
|
||||
header = self.make_header_prefix(ServerError, stream_id=0)
|
||||
body = self.make_error_body(ServerError.error_code, ServerError.summary)
|
||||
c._socket.recv.return_value = self.make_msg(header, body)
|
||||
c.handle_read(None, 0)
|
||||
@@ -252,7 +252,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
c._socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read(None, 0)
|
||||
|
||||
@@ -279,7 +279,7 @@ class LibevConnectionTest(unittest.TestCase):
|
||||
# let it write out a StartupMessage
|
||||
c.handle_write(None, 0)
|
||||
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=1)
|
||||
header = self.make_header_prefix(ReadyMessage, stream_id=0)
|
||||
c._socket.recv.return_value = self.make_msg(header)
|
||||
c.handle_read(None, 0)
|
||||
|
||||
|
||||
@@ -68,7 +68,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_bad_protocol_version(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = Mock()
|
||||
c.defunct = Mock()
|
||||
|
||||
@@ -85,7 +84,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_bad_header_direction(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = Mock()
|
||||
c.defunct = Mock()
|
||||
|
||||
@@ -107,7 +105,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_negative_body_length(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = Mock()
|
||||
c.defunct = Mock()
|
||||
|
||||
@@ -124,7 +121,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_unsupported_cql_version(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c.defunct = Mock()
|
||||
c.cql_version = "3.0.3"
|
||||
@@ -149,7 +145,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_prefer_lz4_compression(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c.defunct = Mock()
|
||||
c.cql_version = "3.0.3"
|
||||
@@ -176,7 +171,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_requested_compression_not_available(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c.defunct = Mock()
|
||||
# request lz4 compression
|
||||
@@ -208,7 +202,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_use_requested_compression(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c.defunct = Mock()
|
||||
# request snappy compression
|
||||
@@ -237,7 +230,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_disable_compression(self, *args):
|
||||
c = self.make_connection()
|
||||
c._id_queue.get_nowait()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c.defunct = Mock()
|
||||
# disable compression
|
||||
|
||||
@@ -21,7 +21,7 @@ from mock import Mock, NonCallableMagicMock
|
||||
from threading import Thread, Event
|
||||
|
||||
from cassandra.cluster import Session
|
||||
from cassandra.connection import Connection, MAX_STREAM_PER_CONNECTION
|
||||
from cassandra.connection import Connection
|
||||
from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable
|
||||
from cassandra.policies import HostDistance, SimpleConvictionPolicy
|
||||
|
||||
@@ -38,13 +38,13 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_borrow_and_return(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
|
||||
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
|
||||
session.cluster.connection_factory.assert_called_once_with(host.address)
|
||||
|
||||
c = pool.borrow_connection(timeout=0.01)
|
||||
c, request_id = pool.borrow_connection(timeout=0.01)
|
||||
self.assertIs(c, conn)
|
||||
self.assertEqual(1, conn.in_flight)
|
||||
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
|
||||
@@ -56,7 +56,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_failed_wait_for_connection(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
|
||||
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
|
||||
@@ -65,7 +65,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
pool.borrow_connection(timeout=0.01)
|
||||
self.assertEqual(1, conn.in_flight)
|
||||
|
||||
conn.in_flight = MAX_STREAM_PER_CONNECTION
|
||||
conn.in_flight = conn.max_request_id
|
||||
|
||||
# we're already at the max number of requests for this connection,
|
||||
# so we this should fail
|
||||
@@ -74,7 +74,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_successful_wait_for_connection(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
|
||||
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
|
||||
@@ -84,7 +84,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
self.assertEqual(1, conn.in_flight)
|
||||
|
||||
def get_second_conn():
|
||||
c = pool.borrow_connection(1.0)
|
||||
c, request_id = pool.borrow_connection(1.0)
|
||||
self.assertIs(conn, c)
|
||||
pool.return_connection(c)
|
||||
|
||||
@@ -98,7 +98,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_all_connections_trashed(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
session.cluster.get_core_connections_per_host.return_value = 1
|
||||
|
||||
@@ -118,7 +118,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
|
||||
def get_conn():
|
||||
conn.reset_mock()
|
||||
c = pool.borrow_connection(1.0)
|
||||
c, request_id = pool.borrow_connection(1.0)
|
||||
self.assertIs(conn, c)
|
||||
self.assertEqual(1, conn.in_flight)
|
||||
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
|
||||
@@ -140,7 +140,8 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_spawn_when_at_max(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
|
||||
conn.max_request_id = 100
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
|
||||
# core conns = 1, max conns = 2
|
||||
@@ -153,7 +154,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
self.assertEqual(1, conn.in_flight)
|
||||
|
||||
# make this conn full
|
||||
conn.in_flight = MAX_STREAM_PER_CONNECTION
|
||||
conn.in_flight = conn.max_request_id
|
||||
|
||||
# we don't care about making this borrow_connection call succeed for the
|
||||
# purposes of this test, as long as it results in a new connection
|
||||
@@ -164,7 +165,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_return_defunct_connection(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
|
||||
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
|
||||
@@ -183,7 +184,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_return_defunct_connection_on_down_host(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
|
||||
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
|
||||
@@ -203,7 +204,7 @@ class HostConnectionPoolTests(unittest.TestCase):
|
||||
def test_return_closed_connection(self):
|
||||
host = Mock(spec=Host, address='ip1')
|
||||
session = self.make_session()
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True)
|
||||
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100)
|
||||
session.cluster.connection_factory.return_value = conn
|
||||
|
||||
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
|
||||
|
||||
@@ -21,7 +21,7 @@ from mock import Mock, MagicMock, ANY
|
||||
|
||||
from cassandra import ConsistencyLevel
|
||||
from cassandra.cluster import Session, ResponseFuture, NoHostAvailable
|
||||
from cassandra.connection import ConnectionException
|
||||
from cassandra.connection import Connection, ConnectionException
|
||||
from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage,
|
||||
UnavailableErrorMessage, ResultMessage, QueryMessage,
|
||||
OverloadedErrorMessage, IsBootstrappingErrorMessage,
|
||||
@@ -58,13 +58,16 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
pool = session._pools.get.return_value
|
||||
pool.is_shutdown = False
|
||||
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
rf = self.make_response_future(session)
|
||||
rf.send_request()
|
||||
|
||||
rf.session._pools.get.assert_called_once_with('ip1')
|
||||
pool.borrow_connection.assert_called_once_with(timeout=ANY)
|
||||
connection = pool.borrow_connection.return_value
|
||||
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
|
||||
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
|
||||
|
||||
rf._set_result(self.make_mock_response([{'col': 'val'}]))
|
||||
result = rf.result()
|
||||
@@ -72,6 +75,10 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
def test_unknown_result_class(self):
|
||||
session = self.make_session()
|
||||
pool = session._pools.get.return_value
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
rf = self.make_response_future(session)
|
||||
rf.send_request()
|
||||
rf._set_result(object())
|
||||
@@ -168,18 +175,21 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
def test_retry_policy_says_retry(self):
|
||||
session = self.make_session()
|
||||
pool = session._pools.get.return_value
|
||||
|
||||
query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)")
|
||||
query.retry_policy = Mock()
|
||||
query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETRY, ConsistencyLevel.ONE)
|
||||
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM)
|
||||
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
rf = ResponseFuture(session, message, query)
|
||||
rf.send_request()
|
||||
|
||||
rf.session._pools.get.assert_called_once_with('ip1')
|
||||
pool.borrow_connection.assert_called_once_with(timeout=ANY)
|
||||
connection = pool.borrow_connection.return_value
|
||||
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
|
||||
|
||||
result = Mock(spec=UnavailableErrorMessage, info={})
|
||||
rf._set_result(result)
|
||||
@@ -187,6 +197,9 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
session.submit.assert_called_once_with(rf._retry_task, True)
|
||||
self.assertEqual(1, rf._query_retries)
|
||||
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 2)
|
||||
|
||||
# simulate the executor running this
|
||||
rf._retry_task(True)
|
||||
|
||||
@@ -194,21 +207,22 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
# an UnavailableException
|
||||
rf.session._pools.get.assert_called_with('ip1')
|
||||
pool.borrow_connection.assert_called_with(timeout=ANY)
|
||||
connection = pool.borrow_connection.return_value
|
||||
connection.send_msg.assert_called_with(rf.message, cb=ANY)
|
||||
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
|
||||
|
||||
def test_retry_with_different_host(self):
|
||||
session = self.make_session()
|
||||
pool = session._pools.get.return_value
|
||||
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
rf = self.make_response_future(session)
|
||||
rf.message.consistency_level = ConsistencyLevel.QUORUM
|
||||
rf.send_request()
|
||||
|
||||
rf.session._pools.get.assert_called_once_with('ip1')
|
||||
pool.borrow_connection.assert_called_once_with(timeout=ANY)
|
||||
connection = pool.borrow_connection.return_value
|
||||
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
|
||||
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
|
||||
|
||||
result = Mock(spec=OverloadedErrorMessage, info={})
|
||||
@@ -218,20 +232,24 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
# query_retries does not get incremented for Overloaded/Bootstrapping errors
|
||||
self.assertEqual(0, rf._query_retries)
|
||||
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 2)
|
||||
# simulate the executor running this
|
||||
rf._retry_task(False)
|
||||
|
||||
# it should try with a different host
|
||||
rf.session._pools.get.assert_called_with('ip2')
|
||||
pool.borrow_connection.assert_called_with(timeout=ANY)
|
||||
connection = pool.borrow_connection.return_value
|
||||
connection.send_msg.assert_called_with(rf.message, cb=ANY)
|
||||
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
|
||||
|
||||
# the consistency level should be the same
|
||||
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
|
||||
|
||||
def test_all_retries_fail(self):
|
||||
session = self.make_session()
|
||||
pool = session._pools.get.return_value
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
rf = self.make_response_future(session)
|
||||
rf.send_request()
|
||||
@@ -287,7 +305,11 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
exc = NoConnectionsAvailable()
|
||||
first_pool = Mock(is_shutdown=False)
|
||||
first_pool.borrow_connection.side_effect = exc
|
||||
|
||||
# the second pool will return a connection
|
||||
second_pool = Mock(is_shutdown=False)
|
||||
connection = Mock(spec=Connection)
|
||||
second_pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
session._pools.get.side_effect = [first_pool, second_pool]
|
||||
|
||||
@@ -317,6 +339,10 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
def test_errback(self):
|
||||
session = self.make_session()
|
||||
pool = session._pools.get.return_value
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)")
|
||||
query.retry_policy = Mock()
|
||||
query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None)
|
||||
@@ -366,6 +392,10 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
def test_prepared_query_not_found(self):
|
||||
session = self.make_session()
|
||||
pool = session._pools.get.return_value
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
rf = self.make_response_future(session)
|
||||
rf.send_request()
|
||||
|
||||
@@ -386,6 +416,10 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
def test_prepared_query_not_found_bad_keyspace(self):
|
||||
session = self.make_session()
|
||||
pool = session._pools.get.return_value
|
||||
connection = Mock(spec=Connection)
|
||||
pool.borrow_connection.return_value = (connection, 1)
|
||||
|
||||
rf = self.make_response_future(session)
|
||||
rf.send_request()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user