2 byte request IDs + queue-less req ID management

This commit is contained in:
Tyler Hobbs
2014-05-30 15:41:06 -05:00
parent 47e79ac668
commit 98ef13169a
12 changed files with 163 additions and 102 deletions

View File

@@ -2143,8 +2143,8 @@ class ResponseFuture(object):
connection = None
try:
# TODO get connectTimeout from cluster settings
connection = pool.borrow_connection(timeout=2.0)
request_id = connection.send_msg(message, cb=cb)
connection, request_id = pool.borrow_connection(timeout=2.0)
connection.send_msg(message, request_id, cb=cb)
except NoConnectionsAvailable as exc:
log.debug("All connections for host %s are at capacity, moving to the next host", host)
self._errors[host] = exc

View File

@@ -29,7 +29,7 @@ import six
from six.moves import range
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int32_pack, header_unpack
from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, decode_response,
@@ -79,8 +79,6 @@ else:
locally_supported_compressions['snappy'] = (snappy.compress, decompress)
MAX_STREAM_PER_CONNECTION = 127
PROTOCOL_VERSION_MASK = 0x7f
HEADER_DIRECTION_FROM_CLIENT = 0x00
@@ -153,6 +151,7 @@ class Connection(object):
ssl_options = None
last_error = None
in_flight = 0
current_request_id = 0
is_defunct = False
is_closed = False
lock = None
@@ -172,10 +171,27 @@ class Connection(object):
self.protocol_version = protocol_version
self.is_control_connection = is_control_connection
self._push_watchers = defaultdict(set)
if protocol_version >= 3:
self._header_unpack = v3_header_unpack
self._header_length = 5
self.max_request_id = (2 ** 15) - 1
else:
self._header_unpack = header_unpack
self._header_length = 4
self.max_request_id = (2 ** 7) - 1
self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
for i in range(MAX_STREAM_PER_CONNECTION):
self._id_queue.put_nowait(i)
# 0 8 16 24 32 40
# +---------+---------+---------+---------+---------+
# | version | flags | stream | opcode |
# +---------+---------+---------+---------+---------+
# | length |
# +---------+---------+---------+---------+
# | |
# . ... body ... .
# . .
# . .
# +----------------------------------------
self._full_header_length = self._header_length + 4
self.lock = RLock()
@@ -210,6 +226,11 @@ class Connection(object):
"failed connection (%s) to host %s:",
id(self), self.host, exc_info=True)
def get_request_id(self):
current = self.current_request_id
self.current_request_id = (current + 1) % self.max_request_id
return current
def handle_pushed(self, response):
log.debug("Message pushed from server: %r", response)
for cb in self._push_watchers.get(response.event_type, []):
@@ -218,21 +239,12 @@ class Connection(object):
except Exception:
log.exception("Pushed event handler errored, ignoring:")
def send_msg(self, msg, cb, wait_for_id=False):
def send_msg(self, msg, request_id, cb):
if self.is_defunct:
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.host)
if not wait_for_id:
try:
request_id = self._id_queue.get_nowait()
except Empty:
raise ConnectionBusy(
"Connection to %s is at the max number of requests" % self.host)
else:
request_id = self._id_queue.get()
self._callbacks[request_id] = cb
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
return request_id
@@ -251,11 +263,15 @@ class Connection(object):
while True:
needed = len(msgs) - messages_sent
with self.lock:
available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight)
available = min(needed, self.max_request_id - self.in_flight)
start_request_id = self.current_request_id
self.current_request_id = (self.current_request_id + available) % self.max_request_id
self.in_flight += available
for i in range(messages_sent, messages_sent + available):
self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True)
for i in range(available):
self.send_msg(msgs[messages_sent + i],
(start_request_id + i) % self.max_request_id,
partial(waiter.got_response, index=messages_sent + i))
messages_sent += available
if messages_sent == len(msgs):
@@ -287,12 +303,11 @@ class Connection(object):
@defunct_on_error
def process_msg(self, msg, body_len):
version, flags, stream_id, opcode = header_unpack(msg[:4])
version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length])
if stream_id < 0:
callback = None
else:
callback = self._callbacks.pop(stream_id, None)
self._id_queue.put_nowait(stream_id)
body = None
try:
@@ -344,7 +359,7 @@ class Connection(object):
self._send_startup_message()
else:
log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
self.send_msg(OptionsMessage(), self._handle_options_response)
self.send_msg(OptionsMessage(), 0, self._handle_options_response)
@defunct_on_error
def _handle_options_response(self, options_response):
@@ -411,11 +426,13 @@ class Connection(object):
@defunct_on_error
def _send_startup_message(self, compression=None):
log.debug("Sending StartupMessage on %s", self)
opts = {}
if compression:
opts['COMPRESSION'] = compression
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
self.send_msg(sm, cb=self._handle_startup_response)
self.send_msg(sm, 0, cb=self._handle_startup_response)
log.debug("Sent StartupMessage on %s", self)
@defunct_on_error
def _handle_startup_response(self, startup_response, did_authenticate=False):
@@ -439,12 +456,12 @@ class Connection(object):
log.debug("Sending credentials-based auth response on %s", self)
cm = CredentialsMessage(creds=self.authenticator)
callback = partial(self._handle_startup_response, did_authenticate=True)
self.send_msg(cm, cb=callback)
self.send_msg(cm, 0, cb=callback)
else:
log.debug("Sending SASL-based auth response on %s", self)
initial_response = self.authenticator.initial_response()
initial_response = "" if initial_response is None else initial_response.encode('utf-8')
self.send_msg(AuthResponseMessage(initial_response), self._handle_auth_response)
self.send_msg(AuthResponseMessage(initial_response), 0, self._handle_auth_response)
elif isinstance(startup_response, ErrorMessage):
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
id(self), self.host, startup_response.summary_msg())
@@ -479,7 +496,7 @@ class Connection(object):
response = self.authenticator.evaluate_challenge(auth_response.challenge)
msg = AuthResponseMessage("" if response is None else response)
log.debug("Responding to auth challenge on %s", self)
self.send_msg(msg, self._handle_auth_response)
self.send_msg(msg, 0, self._handle_auth_response)
elif isinstance(auth_response, ErrorMessage):
log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
id(self), self.host, auth_response.summary_msg())
@@ -543,7 +560,21 @@ class Connection(object):
callback(self, self.defunct(ConnectionException(
"Problem while setting keyspace: %r" % (result,), self.host)))
self.send_msg(query, process_result, wait_for_id=True)
request_id = None
# we use a busy wait on the lock here because:
# - we'll only spin if the connection is at max capacity, which is very
# unlikely for a set_keyspace call
# - it allows us to avoid signaling a condition every time a request completes
while True:
with self.lock:
if self.in_flight < self.max_request_id:
request_id = self.get_request_id()
self.in_flight += 1
break
time.sleep(0.001)
self.send_msg(query, request_id, process_result)
def __str__(self):
status = ""

View File

@@ -229,7 +229,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
self.defunct(sys.exc_info()[1])
def handle_close(self):
log.debug("connection (%s) to %s closed by server", id(self), self.host)
log.debug("Connection %s closed by server", self)
self.close()
def handle_write(self):
@@ -277,24 +277,24 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
if self._iobuf.tell():
while True:
pos = self._iobuf.tell()
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
break
else:
# have enough for header, read body len from header
self._iobuf.seek(4)
self._iobuf.seek(self._header_length)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + 8:
if pos >= body_len + self._full_header_length:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(8 + body_len)
msg = self._iobuf.read(self._full_header_length + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
@@ -304,7 +304,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + 8
self._total_reqd_bytes = body_len + self._full_header_length
break
if not self._callbacks and not self.is_control_connection:

View File

@@ -305,24 +305,24 @@ class LibevConnection(Connection):
if self._iobuf.tell():
while True:
pos = self._iobuf.tell()
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
break
else:
# have enough for header, read body len from header
self._iobuf.seek(4)
self._iobuf.seek(self._header_length)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + 8:
if pos >= body_len + self._full_header_length:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(8 + body_len)
msg = self._iobuf.read(self._full_header_length + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
@@ -332,7 +332,7 @@ class LibevConnection(Connection):
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + 8
self._total_reqd_bytes = body_len + self._full_header_length
break
else:
log.debug("Connection %s closed by server", self)

View File

@@ -17,14 +17,9 @@ import struct
def _make_packer(format_string):
try:
packer = struct.Struct(format_string) # new in Python 2.5
except AttributeError:
pack = lambda x: struct.pack(format_string, x)
unpack = lambda s: struct.unpack(format_string, s)
else:
pack = packer.pack
unpack = lambda s: packer.unpack(s)[0]
packer = struct.Struct(format_string)
pack = packer.pack
unpack = lambda s: packer.unpack(s)[0]
return pack, unpack
int64_pack, int64_unpack = _make_packer('>q')
@@ -43,6 +38,11 @@ header_struct = struct.Struct('>BBbB')
header_pack = header_struct.pack
header_unpack = header_struct.unpack
# in protocol version 3 and higher, the stream ID is two bytes
v3_header_struct = struct.Struct('>BBhB')
v3_header_pack = v3_header_struct.pack
v3_header_unpack = v3_header_struct.unpack
if six.PY3:
def varint_unpack(term):

View File

@@ -28,7 +28,7 @@ except ImportError:
from cassandra.util import WeakSet # NOQA
from cassandra import AuthenticationFailed
from cassandra.connection import MAX_STREAM_PER_CONNECTION, ConnectionException
from cassandra.connection import ConnectionException
log = logging.getLogger(__name__)
@@ -349,6 +349,7 @@ class HostConnectionPool(object):
max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance)
least_busy = min(conns, key=lambda c: c.in_flight)
request_id = None
# to avoid another thread closing this connection while
# trashing it (through the return_connection process), hold
# the connection lock from this point until we've incremented
@@ -356,15 +357,16 @@ class HostConnectionPool(object):
need_to_wait = False
with least_busy.lock:
if least_busy.in_flight >= MAX_STREAM_PER_CONNECTION:
if least_busy.in_flight >= least_busy.max_request_id:
# once we release the lock, wait for another connection
need_to_wait = True
else:
least_busy.in_flight += 1
request_id = least_busy.get_request_id()
if need_to_wait:
# wait_for_conn will increment in_flight on the conn
least_busy = self._wait_for_conn(timeout)
least_busy, request_id = self._wait_for_conn(timeout)
# if we have too many requests on this connection but we still
# have space to open a new connection against this host, go ahead
@@ -372,7 +374,7 @@ class HostConnectionPool(object):
if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns:
self._maybe_spawn_new_connection()
return least_busy
return least_busy, request_id
def _maybe_spawn_new_connection(self):
with self._lock:
@@ -461,9 +463,9 @@ class HostConnectionPool(object):
if conns:
least_busy = min(conns, key=lambda c: c.in_flight)
with least_busy.lock:
if least_busy.in_flight < MAX_STREAM_PER_CONNECTION:
if least_busy.in_flight < least_busy.max_request_id:
least_busy.in_flight += 1
return least_busy
return least_busy, least_busy.get_request_id()
remaining = timeout - (time.time() - start)

View File

@@ -23,7 +23,8 @@ from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
AlreadyExists, InvalidRequest, Unauthorized,
UnsupportedOperation)
from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack,
int8_pack, int8_unpack, uint64_pack, header_pack)
int8_pack, int8_unpack, uint64_pack, header_pack,
v3_header_pack)
from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
CounterColumnType, DateType, DecimalType,
DoubleType, FloatType, Int32Type,
@@ -80,11 +81,7 @@ class _MessageType(object):
flags |= TRACING_FLAG
msg = six.BytesIO()
write_header(
msg,
protocol_version | HEADER_DIRECTION_FROM_CLIENT,
flags, stream_id, self.opcode, len(body)
)
write_header(msg, protocol_version, flags, stream_id, self.opcode, len(body))
msg.write(body)
return msg.getvalue()
@@ -824,7 +821,8 @@ def write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
f.write(header_pack(version, flags, stream_id, opcode))
pack = v3_header_pack if version >= 3 else header_pack
f.write(pack(version | HEADER_DIRECTION_FROM_CLIENT, flags, stream_id, opcode))
write_int(f, length)

View File

@@ -46,6 +46,9 @@ class AsyncoreConnectionTest(unittest.TestCase):
cls.mock_socket = cls.socket_patcher.start()
cls.mock_socket().connect_ex.return_value = 0
cls.mock_socket().getsockopt.return_value = 0
cls.mock_socket().fileno.return_value = 100
AsyncoreConnection.add_channel = lambda *args, **kwargs: None
@classmethod
def tearDownClass(cls):
@@ -97,7 +100,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=1)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
c.socket.recv.return_value = self.make_msg(header)
c.handle_read()
@@ -169,7 +172,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ServerError, stream_id=1)
header = self.make_header_prefix(ServerError, stream_id=0)
body = self.make_error_body(ServerError.error_code, ServerError.summary)
c.socket.recv.return_value = self.make_msg(header, body)
c.handle_read()
@@ -251,7 +254,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=1)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
c.socket.recv.return_value = self.make_msg(header)
c.handle_read()
@@ -278,7 +281,7 @@ class AsyncoreConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write()
header = self.make_header_prefix(ReadyMessage, stream_id=1)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
c.socket.recv.return_value = self.make_msg(header)
c.handle_read()

View File

@@ -97,7 +97,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0)
@@ -169,7 +169,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ServerError, stream_id=1)
header = self.make_header_prefix(ServerError, stream_id=0)
body = self.make_error_body(ServerError.error_code, ServerError.summary)
c._socket.recv.return_value = self.make_msg(header, body)
c.handle_read(None, 0)
@@ -252,7 +252,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0)
@@ -279,7 +279,7 @@ class LibevConnectionTest(unittest.TestCase):
# let it write out a StartupMessage
c.handle_write(None, 0)
header = self.make_header_prefix(ReadyMessage, stream_id=1)
header = self.make_header_prefix(ReadyMessage, stream_id=0)
c._socket.recv.return_value = self.make_msg(header)
c.handle_read(None, 0)

View File

@@ -68,7 +68,6 @@ class ConnectionTest(unittest.TestCase):
def test_bad_protocol_version(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = Mock()
c.defunct = Mock()
@@ -85,7 +84,6 @@ class ConnectionTest(unittest.TestCase):
def test_bad_header_direction(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = Mock()
c.defunct = Mock()
@@ -107,7 +105,6 @@ class ConnectionTest(unittest.TestCase):
def test_negative_body_length(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = Mock()
c.defunct = Mock()
@@ -124,7 +121,6 @@ class ConnectionTest(unittest.TestCase):
def test_unsupported_cql_version(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
c.cql_version = "3.0.3"
@@ -149,7 +145,6 @@ class ConnectionTest(unittest.TestCase):
def test_prefer_lz4_compression(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
c.cql_version = "3.0.3"
@@ -176,7 +171,6 @@ class ConnectionTest(unittest.TestCase):
def test_requested_compression_not_available(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
# request lz4 compression
@@ -208,7 +202,6 @@ class ConnectionTest(unittest.TestCase):
def test_use_requested_compression(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
# request snappy compression
@@ -237,7 +230,6 @@ class ConnectionTest(unittest.TestCase):
def test_disable_compression(self, *args):
c = self.make_connection()
c._id_queue.get_nowait()
c._callbacks = {0: c._handle_options_response}
c.defunct = Mock()
# disable compression

View File

@@ -21,7 +21,7 @@ from mock import Mock, NonCallableMagicMock
from threading import Thread, Event
from cassandra.cluster import Session
from cassandra.connection import Connection, MAX_STREAM_PER_CONNECTION
from cassandra.connection import Connection
from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable
from cassandra.policies import HostDistance, SimpleConvictionPolicy
@@ -38,13 +38,13 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_borrow_and_return(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
session.cluster.connection_factory.return_value = conn
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.address)
c = pool.borrow_connection(timeout=0.01)
c, request_id = pool.borrow_connection(timeout=0.01)
self.assertIs(c, conn)
self.assertEqual(1, conn.in_flight)
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
@@ -56,7 +56,7 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_failed_wait_for_connection(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
session.cluster.connection_factory.return_value = conn
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
@@ -65,7 +65,7 @@ class HostConnectionPoolTests(unittest.TestCase):
pool.borrow_connection(timeout=0.01)
self.assertEqual(1, conn.in_flight)
conn.in_flight = MAX_STREAM_PER_CONNECTION
conn.in_flight = conn.max_request_id
# we're already at the max number of requests for this connection,
# so we this should fail
@@ -74,7 +74,7 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_successful_wait_for_connection(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
session.cluster.connection_factory.return_value = conn
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
@@ -84,7 +84,7 @@ class HostConnectionPoolTests(unittest.TestCase):
self.assertEqual(1, conn.in_flight)
def get_second_conn():
c = pool.borrow_connection(1.0)
c, request_id = pool.borrow_connection(1.0)
self.assertIs(conn, c)
pool.return_connection(c)
@@ -98,7 +98,7 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_all_connections_trashed(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
session.cluster.connection_factory.return_value = conn
session.cluster.get_core_connections_per_host.return_value = 1
@@ -118,7 +118,7 @@ class HostConnectionPoolTests(unittest.TestCase):
def get_conn():
conn.reset_mock()
c = pool.borrow_connection(1.0)
c, request_id = pool.borrow_connection(1.0)
self.assertIs(conn, c)
self.assertEqual(1, conn.in_flight)
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
@@ -140,7 +140,8 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_spawn_when_at_max(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
conn.max_request_id = 100
session.cluster.connection_factory.return_value = conn
# core conns = 1, max conns = 2
@@ -153,7 +154,7 @@ class HostConnectionPoolTests(unittest.TestCase):
self.assertEqual(1, conn.in_flight)
# make this conn full
conn.in_flight = MAX_STREAM_PER_CONNECTION
conn.in_flight = conn.max_request_id
# we don't care about making this borrow_connection call succeed for the
# purposes of this test, as long as it results in a new connection
@@ -164,7 +165,7 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_return_defunct_connection(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
session.cluster.connection_factory.return_value = conn
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
@@ -183,7 +184,7 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_return_defunct_connection_on_down_host(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100)
session.cluster.connection_factory.return_value = conn
pool = HostConnectionPool(host, HostDistance.LOCAL, session)
@@ -203,7 +204,7 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_return_closed_connection(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True)
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100)
session.cluster.connection_factory.return_value = conn
pool = HostConnectionPool(host, HostDistance.LOCAL, session)

View File

@@ -21,7 +21,7 @@ from mock import Mock, MagicMock, ANY
from cassandra import ConsistencyLevel
from cassandra.cluster import Session, ResponseFuture, NoHostAvailable
from cassandra.connection import ConnectionException
from cassandra.connection import Connection, ConnectionException
from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage,
UnavailableErrorMessage, ResultMessage, QueryMessage,
OverloadedErrorMessage, IsBootstrappingErrorMessage,
@@ -58,13 +58,16 @@ class ResponseFutureTests(unittest.TestCase):
pool = session._pools.get.return_value
pool.is_shutdown = False
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = self.make_response_future(session)
rf.send_request()
rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
rf._set_result(self.make_mock_response([{'col': 'val'}]))
result = rf.result()
@@ -72,6 +75,10 @@ class ResponseFutureTests(unittest.TestCase):
def test_unknown_result_class(self):
session = self.make_session()
pool = session._pools.get.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = self.make_response_future(session)
rf.send_request()
rf._set_result(object())
@@ -168,18 +175,21 @@ class ResponseFutureTests(unittest.TestCase):
def test_retry_policy_says_retry(self):
session = self.make_session()
pool = session._pools.get.return_value
query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)")
query.retry_policy = Mock()
query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETRY, ConsistencyLevel.ONE)
message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM)
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = ResponseFuture(session, message, query)
rf.send_request()
rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result)
@@ -187,6 +197,9 @@ class ResponseFutureTests(unittest.TestCase):
session.submit.assert_called_once_with(rf._retry_task, True)
self.assertEqual(1, rf._query_retries)
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 2)
# simulate the executor running this
rf._retry_task(True)
@@ -194,21 +207,22 @@ class ResponseFutureTests(unittest.TestCase):
# an UnavailableException
rf.session._pools.get.assert_called_with('ip1')
pool.borrow_connection.assert_called_with(timeout=ANY)
connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_with(rf.message, cb=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
def test_retry_with_different_host(self):
session = self.make_session()
pool = session._pools.get.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = self.make_response_future(session)
rf.message.consistency_level = ConsistencyLevel.QUORUM
rf.send_request()
rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_once_with(rf.message, cb=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
result = Mock(spec=OverloadedErrorMessage, info={})
@@ -218,20 +232,24 @@ class ResponseFutureTests(unittest.TestCase):
# query_retries does not get incremented for Overloaded/Bootstrapping errors
self.assertEqual(0, rf._query_retries)
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 2)
# simulate the executor running this
rf._retry_task(False)
# it should try with a different host
rf.session._pools.get.assert_called_with('ip2')
pool.borrow_connection.assert_called_with(timeout=ANY)
connection = pool.borrow_connection.return_value
connection.send_msg.assert_called_with(rf.message, cb=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
# the consistency level should be the same
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
def test_all_retries_fail(self):
session = self.make_session()
pool = session._pools.get.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = self.make_response_future(session)
rf.send_request()
@@ -287,7 +305,11 @@ class ResponseFutureTests(unittest.TestCase):
exc = NoConnectionsAvailable()
first_pool = Mock(is_shutdown=False)
first_pool.borrow_connection.side_effect = exc
# the second pool will return a connection
second_pool = Mock(is_shutdown=False)
connection = Mock(spec=Connection)
second_pool.borrow_connection.return_value = (connection, 1)
session._pools.get.side_effect = [first_pool, second_pool]
@@ -317,6 +339,10 @@ class ResponseFutureTests(unittest.TestCase):
def test_errback(self):
session = self.make_session()
pool = session._pools.get.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)")
query.retry_policy = Mock()
query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None)
@@ -366,6 +392,10 @@ class ResponseFutureTests(unittest.TestCase):
def test_prepared_query_not_found(self):
session = self.make_session()
pool = session._pools.get.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = self.make_response_future(session)
rf.send_request()
@@ -386,6 +416,10 @@ class ResponseFutureTests(unittest.TestCase):
def test_prepared_query_not_found_bad_keyspace(self):
session = self.make_session()
pool = session._pools.get.return_value
connection = Mock(spec=Connection)
pool.borrow_connection.return_value = (connection, 1)
rf = self.make_response_future(session)
rf.send_request()