diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 7473c588..f9a60142 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2143,8 +2143,8 @@ class ResponseFuture(object): connection = None try: # TODO get connectTimeout from cluster settings - connection = pool.borrow_connection(timeout=2.0) - request_id = connection.send_msg(message, cb=cb) + connection, request_id = pool.borrow_connection(timeout=2.0) + connection.send_msg(message, request_id, cb=cb) except NoConnectionsAvailable as exc: log.debug("All connections for host %s are at capacity, moving to the next host", host) self._errors[host] = exc diff --git a/cassandra/connection.py b/cassandra/connection.py index 0fcf36c6..ec2c7c41 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -29,7 +29,7 @@ import six from six.moves import range from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut -from cassandra.marshal import int32_pack, header_unpack +from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, StartupMessage, ErrorMessage, CredentialsMessage, QueryMessage, ResultMessage, decode_response, @@ -79,8 +79,6 @@ else: locally_supported_compressions['snappy'] = (snappy.compress, decompress) -MAX_STREAM_PER_CONNECTION = 127 - PROTOCOL_VERSION_MASK = 0x7f HEADER_DIRECTION_FROM_CLIENT = 0x00 @@ -153,6 +151,7 @@ class Connection(object): ssl_options = None last_error = None in_flight = 0 + current_request_id = 0 is_defunct = False is_closed = False lock = None @@ -172,10 +171,27 @@ class Connection(object): self.protocol_version = protocol_version self.is_control_connection = is_control_connection self._push_watchers = defaultdict(set) + if protocol_version >= 3: + self._header_unpack = v3_header_unpack + self._header_length = 5 + self.max_request_id = (2 ** 15) - 1 + else: + self._header_unpack = header_unpack + self._header_length = 4 + self.max_request_id = (2 ** 7) - 1 - self._id_queue = Queue(MAX_STREAM_PER_CONNECTION) - for i in range(MAX_STREAM_PER_CONNECTION): - self._id_queue.put_nowait(i) + # 0 8 16 24 32 40 + # +---------+---------+---------+---------+---------+ + # | version | flags | stream | opcode | + # +---------+---------+---------+---------+---------+ + # | length | + # +---------+---------+---------+---------+ + # | | + # . ... body ... . + # . . + # . . + # +---------------------------------------- + self._full_header_length = self._header_length + 4 self.lock = RLock() @@ -210,6 +226,11 @@ class Connection(object): "failed connection (%s) to host %s:", id(self), self.host, exc_info=True) + def get_request_id(self): + current = self.current_request_id + self.current_request_id = (current + 1) % self.max_request_id + return current + def handle_pushed(self, response): log.debug("Message pushed from server: %r", response) for cb in self._push_watchers.get(response.event_type, []): @@ -218,21 +239,12 @@ class Connection(object): except Exception: log.exception("Pushed event handler errored, ignoring:") - def send_msg(self, msg, cb, wait_for_id=False): + def send_msg(self, msg, request_id, cb): if self.is_defunct: raise ConnectionShutdown("Connection to %s is defunct" % self.host) elif self.is_closed: raise ConnectionShutdown("Connection to %s is closed" % self.host) - if not wait_for_id: - try: - request_id = self._id_queue.get_nowait() - except Empty: - raise ConnectionBusy( - "Connection to %s is at the max number of requests" % self.host) - else: - request_id = self._id_queue.get() - self._callbacks[request_id] = cb self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor)) return request_id @@ -251,11 +263,15 @@ class Connection(object): while True: needed = len(msgs) - messages_sent with self.lock: - available = min(needed, MAX_STREAM_PER_CONNECTION - self.in_flight) + available = min(needed, self.max_request_id - self.in_flight) + start_request_id = self.current_request_id + self.current_request_id = (self.current_request_id + available) % self.max_request_id self.in_flight += available - for i in range(messages_sent, messages_sent + available): - self.send_msg(msgs[i], partial(waiter.got_response, index=i), wait_for_id=True) + for i in range(available): + self.send_msg(msgs[messages_sent + i], + (start_request_id + i) % self.max_request_id, + partial(waiter.got_response, index=messages_sent + i)) messages_sent += available if messages_sent == len(msgs): @@ -287,12 +303,11 @@ class Connection(object): @defunct_on_error def process_msg(self, msg, body_len): - version, flags, stream_id, opcode = header_unpack(msg[:4]) + version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length]) if stream_id < 0: callback = None else: callback = self._callbacks.pop(stream_id, None) - self._id_queue.put_nowait(stream_id) body = None try: @@ -344,7 +359,7 @@ class Connection(object): self._send_startup_message() else: log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host) - self.send_msg(OptionsMessage(), self._handle_options_response) + self.send_msg(OptionsMessage(), 0, self._handle_options_response) @defunct_on_error def _handle_options_response(self, options_response): @@ -411,11 +426,13 @@ class Connection(object): @defunct_on_error def _send_startup_message(self, compression=None): + log.debug("Sending StartupMessage on %s", self) opts = {} if compression: opts['COMPRESSION'] = compression sm = StartupMessage(cqlversion=self.cql_version, options=opts) - self.send_msg(sm, cb=self._handle_startup_response) + self.send_msg(sm, 0, cb=self._handle_startup_response) + log.debug("Sent StartupMessage on %s", self) @defunct_on_error def _handle_startup_response(self, startup_response, did_authenticate=False): @@ -439,12 +456,12 @@ class Connection(object): log.debug("Sending credentials-based auth response on %s", self) cm = CredentialsMessage(creds=self.authenticator) callback = partial(self._handle_startup_response, did_authenticate=True) - self.send_msg(cm, cb=callback) + self.send_msg(cm, 0, cb=callback) else: log.debug("Sending SASL-based auth response on %s", self) initial_response = self.authenticator.initial_response() initial_response = "" if initial_response is None else initial_response.encode('utf-8') - self.send_msg(AuthResponseMessage(initial_response), self._handle_auth_response) + self.send_msg(AuthResponseMessage(initial_response), 0, self._handle_auth_response) elif isinstance(startup_response, ErrorMessage): log.debug("Received ErrorMessage on new connection (%s) from %s: %s", id(self), self.host, startup_response.summary_msg()) @@ -479,7 +496,7 @@ class Connection(object): response = self.authenticator.evaluate_challenge(auth_response.challenge) msg = AuthResponseMessage("" if response is None else response) log.debug("Responding to auth challenge on %s", self) - self.send_msg(msg, self._handle_auth_response) + self.send_msg(msg, 0, self._handle_auth_response) elif isinstance(auth_response, ErrorMessage): log.debug("Received ErrorMessage on new connection (%s) from %s: %s", id(self), self.host, auth_response.summary_msg()) @@ -543,7 +560,21 @@ class Connection(object): callback(self, self.defunct(ConnectionException( "Problem while setting keyspace: %r" % (result,), self.host))) - self.send_msg(query, process_result, wait_for_id=True) + request_id = None + # we use a busy wait on the lock here because: + # - we'll only spin if the connection is at max capacity, which is very + # unlikely for a set_keyspace call + # - it allows us to avoid signaling a condition every time a request completes + while True: + with self.lock: + if self.in_flight < self.max_request_id: + request_id = self.get_request_id() + self.in_flight += 1 + break + + time.sleep(0.001) + + self.send_msg(query, request_id, process_result) def __str__(self): status = "" diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 758458e0..de1429c4 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -229,7 +229,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): self.defunct(sys.exc_info()[1]) def handle_close(self): - log.debug("connection (%s) to %s closed by server", id(self), self.host) + log.debug("Connection %s closed by server", self) self.close() def handle_write(self): @@ -277,24 +277,24 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): if self._iobuf.tell(): while True: pos = self._iobuf.tell() - if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): + if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): # we don't have a complete header yet or we # already saw a header, but we don't have a # complete message yet break else: # have enough for header, read body len from header - self._iobuf.seek(4) + self._iobuf.seek(self._header_length) body_len = int32_unpack(self._iobuf.read(4)) # seek to end to get length of current buffer self._iobuf.seek(0, os.SEEK_END) pos = self._iobuf.tell() - if pos >= body_len + 8: + if pos >= body_len + self._full_header_length: # read message header and body self._iobuf.seek(0) - msg = self._iobuf.read(8 + body_len) + msg = self._iobuf.read(self._full_header_length + body_len) # leave leftover in current buffer leftover = self._iobuf.read() @@ -304,7 +304,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): self._total_reqd_bytes = 0 self.process_msg(msg, body_len) else: - self._total_reqd_bytes = body_len + 8 + self._total_reqd_bytes = body_len + self._full_header_length break if not self._callbacks and not self.is_control_connection: diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index 02eff6ff..bd000e78 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -305,24 +305,24 @@ class LibevConnection(Connection): if self._iobuf.tell(): while True: pos = self._iobuf.tell() - if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): + if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): # we don't have a complete header yet or we # already saw a header, but we don't have a # complete message yet break else: # have enough for header, read body len from header - self._iobuf.seek(4) + self._iobuf.seek(self._header_length) body_len = int32_unpack(self._iobuf.read(4)) # seek to end to get length of current buffer self._iobuf.seek(0, os.SEEK_END) pos = self._iobuf.tell() - if pos >= body_len + 8: + if pos >= body_len + self._full_header_length: # read message header and body self._iobuf.seek(0) - msg = self._iobuf.read(8 + body_len) + msg = self._iobuf.read(self._full_header_length + body_len) # leave leftover in current buffer leftover = self._iobuf.read() @@ -332,7 +332,7 @@ class LibevConnection(Connection): self._total_reqd_bytes = 0 self.process_msg(msg, body_len) else: - self._total_reqd_bytes = body_len + 8 + self._total_reqd_bytes = body_len + self._full_header_length break else: log.debug("Connection %s closed by server", self) diff --git a/cassandra/marshal.py b/cassandra/marshal.py index 447e4ef1..1a78a97a 100644 --- a/cassandra/marshal.py +++ b/cassandra/marshal.py @@ -17,14 +17,9 @@ import struct def _make_packer(format_string): - try: - packer = struct.Struct(format_string) # new in Python 2.5 - except AttributeError: - pack = lambda x: struct.pack(format_string, x) - unpack = lambda s: struct.unpack(format_string, s) - else: - pack = packer.pack - unpack = lambda s: packer.unpack(s)[0] + packer = struct.Struct(format_string) + pack = packer.pack + unpack = lambda s: packer.unpack(s)[0] return pack, unpack int64_pack, int64_unpack = _make_packer('>q') @@ -43,6 +38,11 @@ header_struct = struct.Struct('>BBbB') header_pack = header_struct.pack header_unpack = header_struct.unpack +# in protocol version 3 and higher, the stream ID is two bytes +v3_header_struct = struct.Struct('>BBhB') +v3_header_pack = v3_header_struct.pack +v3_header_unpack = v3_header_struct.unpack + if six.PY3: def varint_unpack(term): diff --git a/cassandra/pool.py b/cassandra/pool.py index c2cf4236..f99121d7 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -28,7 +28,7 @@ except ImportError: from cassandra.util import WeakSet # NOQA from cassandra import AuthenticationFailed -from cassandra.connection import MAX_STREAM_PER_CONNECTION, ConnectionException +from cassandra.connection import ConnectionException log = logging.getLogger(__name__) @@ -349,6 +349,7 @@ class HostConnectionPool(object): max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) least_busy = min(conns, key=lambda c: c.in_flight) + request_id = None # to avoid another thread closing this connection while # trashing it (through the return_connection process), hold # the connection lock from this point until we've incremented @@ -356,15 +357,16 @@ class HostConnectionPool(object): need_to_wait = False with least_busy.lock: - if least_busy.in_flight >= MAX_STREAM_PER_CONNECTION: + if least_busy.in_flight >= least_busy.max_request_id: # once we release the lock, wait for another connection need_to_wait = True else: least_busy.in_flight += 1 + request_id = least_busy.get_request_id() if need_to_wait: # wait_for_conn will increment in_flight on the conn - least_busy = self._wait_for_conn(timeout) + least_busy, request_id = self._wait_for_conn(timeout) # if we have too many requests on this connection but we still # have space to open a new connection against this host, go ahead @@ -372,7 +374,7 @@ class HostConnectionPool(object): if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns: self._maybe_spawn_new_connection() - return least_busy + return least_busy, request_id def _maybe_spawn_new_connection(self): with self._lock: @@ -461,9 +463,9 @@ class HostConnectionPool(object): if conns: least_busy = min(conns, key=lambda c: c.in_flight) with least_busy.lock: - if least_busy.in_flight < MAX_STREAM_PER_CONNECTION: + if least_busy.in_flight < least_busy.max_request_id: least_busy.in_flight += 1 - return least_busy + return least_busy, least_busy.get_request_id() remaining = timeout - (time.time() - start) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4171da1d..99979c32 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -23,7 +23,8 @@ from cassandra import (Unavailable, WriteTimeout, ReadTimeout, AlreadyExists, InvalidRequest, Unauthorized, UnsupportedOperation) from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, - int8_pack, int8_unpack, uint64_pack, header_pack) + int8_pack, int8_unpack, uint64_pack, header_pack, + v3_header_pack) from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, CounterColumnType, DateType, DecimalType, DoubleType, FloatType, Int32Type, @@ -80,11 +81,7 @@ class _MessageType(object): flags |= TRACING_FLAG msg = six.BytesIO() - write_header( - msg, - protocol_version | HEADER_DIRECTION_FROM_CLIENT, - flags, stream_id, self.opcode, len(body) - ) + write_header(msg, protocol_version, flags, stream_id, self.opcode, len(body)) msg.write(body) return msg.getvalue() @@ -824,7 +821,8 @@ def write_header(f, version, flags, stream_id, opcode, length): """ Write a CQL protocol frame header. """ - f.write(header_pack(version, flags, stream_id, opcode)) + pack = v3_header_pack if version >= 3 else header_pack + f.write(pack(version | HEADER_DIRECTION_FROM_CLIENT, flags, stream_id, opcode)) write_int(f, length) diff --git a/tests/unit/io/test_asyncorereactor.py b/tests/unit/io/test_asyncorereactor.py index 179fe637..8b358e6d 100644 --- a/tests/unit/io/test_asyncorereactor.py +++ b/tests/unit/io/test_asyncorereactor.py @@ -46,6 +46,9 @@ class AsyncoreConnectionTest(unittest.TestCase): cls.mock_socket = cls.socket_patcher.start() cls.mock_socket().connect_ex.return_value = 0 cls.mock_socket().getsockopt.return_value = 0 + cls.mock_socket().fileno.return_value = 100 + + AsyncoreConnection.add_channel = lambda *args, **kwargs: None @classmethod def tearDownClass(cls): @@ -97,7 +100,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ReadyMessage, stream_id=1) + header = self.make_header_prefix(ReadyMessage, stream_id=0) c.socket.recv.return_value = self.make_msg(header) c.handle_read() @@ -169,7 +172,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ServerError, stream_id=1) + header = self.make_header_prefix(ServerError, stream_id=0) body = self.make_error_body(ServerError.error_code, ServerError.summary) c.socket.recv.return_value = self.make_msg(header, body) c.handle_read() @@ -251,7 +254,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ReadyMessage, stream_id=1) + header = self.make_header_prefix(ReadyMessage, stream_id=0) c.socket.recv.return_value = self.make_msg(header) c.handle_read() @@ -278,7 +281,7 @@ class AsyncoreConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write() - header = self.make_header_prefix(ReadyMessage, stream_id=1) + header = self.make_header_prefix(ReadyMessage, stream_id=0) c.socket.recv.return_value = self.make_msg(header) c.handle_read() diff --git a/tests/unit/io/test_libevreactor.py b/tests/unit/io/test_libevreactor.py index 5fed53b4..a7b060b1 100644 --- a/tests/unit/io/test_libevreactor.py +++ b/tests/unit/io/test_libevreactor.py @@ -97,7 +97,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ReadyMessage, stream_id=1) + header = self.make_header_prefix(ReadyMessage, stream_id=0) c._socket.recv.return_value = self.make_msg(header) c.handle_read(None, 0) @@ -169,7 +169,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ServerError, stream_id=1) + header = self.make_header_prefix(ServerError, stream_id=0) body = self.make_error_body(ServerError.error_code, ServerError.summary) c._socket.recv.return_value = self.make_msg(header, body) c.handle_read(None, 0) @@ -252,7 +252,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ReadyMessage, stream_id=1) + header = self.make_header_prefix(ReadyMessage, stream_id=0) c._socket.recv.return_value = self.make_msg(header) c.handle_read(None, 0) @@ -279,7 +279,7 @@ class LibevConnectionTest(unittest.TestCase): # let it write out a StartupMessage c.handle_write(None, 0) - header = self.make_header_prefix(ReadyMessage, stream_id=1) + header = self.make_header_prefix(ReadyMessage, stream_id=0) c._socket.recv.return_value = self.make_msg(header) c.handle_read(None, 0) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 708c3e0b..6bc50c22 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -68,7 +68,6 @@ class ConnectionTest(unittest.TestCase): def test_bad_protocol_version(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = Mock() c.defunct = Mock() @@ -85,7 +84,6 @@ class ConnectionTest(unittest.TestCase): def test_bad_header_direction(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = Mock() c.defunct = Mock() @@ -107,7 +105,6 @@ class ConnectionTest(unittest.TestCase): def test_negative_body_length(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = Mock() c.defunct = Mock() @@ -124,7 +121,6 @@ class ConnectionTest(unittest.TestCase): def test_unsupported_cql_version(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = {0: c._handle_options_response} c.defunct = Mock() c.cql_version = "3.0.3" @@ -149,7 +145,6 @@ class ConnectionTest(unittest.TestCase): def test_prefer_lz4_compression(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = {0: c._handle_options_response} c.defunct = Mock() c.cql_version = "3.0.3" @@ -176,7 +171,6 @@ class ConnectionTest(unittest.TestCase): def test_requested_compression_not_available(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = {0: c._handle_options_response} c.defunct = Mock() # request lz4 compression @@ -208,7 +202,6 @@ class ConnectionTest(unittest.TestCase): def test_use_requested_compression(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = {0: c._handle_options_response} c.defunct = Mock() # request snappy compression @@ -237,7 +230,6 @@ class ConnectionTest(unittest.TestCase): def test_disable_compression(self, *args): c = self.make_connection() - c._id_queue.get_nowait() c._callbacks = {0: c._handle_options_response} c.defunct = Mock() # disable compression diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 20c8830f..6c32bc47 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -21,7 +21,7 @@ from mock import Mock, NonCallableMagicMock from threading import Thread, Event from cassandra.cluster import Session -from cassandra.connection import Connection, MAX_STREAM_PER_CONNECTION +from cassandra.connection import Connection from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable from cassandra.policies import HostDistance, SimpleConvictionPolicy @@ -38,13 +38,13 @@ class HostConnectionPoolTests(unittest.TestCase): def test_borrow_and_return(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) session.cluster.connection_factory.assert_called_once_with(host.address) - c = pool.borrow_connection(timeout=0.01) + c, request_id = pool.borrow_connection(timeout=0.01) self.assertIs(c, conn) self.assertEqual(1, conn.in_flight) conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace') @@ -56,7 +56,7 @@ class HostConnectionPoolTests(unittest.TestCase): def test_failed_wait_for_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -65,7 +65,7 @@ class HostConnectionPoolTests(unittest.TestCase): pool.borrow_connection(timeout=0.01) self.assertEqual(1, conn.in_flight) - conn.in_flight = MAX_STREAM_PER_CONNECTION + conn.in_flight = conn.max_request_id # we're already at the max number of requests for this connection, # so we this should fail @@ -74,7 +74,7 @@ class HostConnectionPoolTests(unittest.TestCase): def test_successful_wait_for_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -84,7 +84,7 @@ class HostConnectionPoolTests(unittest.TestCase): self.assertEqual(1, conn.in_flight) def get_second_conn(): - c = pool.borrow_connection(1.0) + c, request_id = pool.borrow_connection(1.0) self.assertIs(conn, c) pool.return_connection(c) @@ -98,7 +98,7 @@ class HostConnectionPoolTests(unittest.TestCase): def test_all_connections_trashed(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn session.cluster.get_core_connections_per_host.return_value = 1 @@ -118,7 +118,7 @@ class HostConnectionPoolTests(unittest.TestCase): def get_conn(): conn.reset_mock() - c = pool.borrow_connection(1.0) + c, request_id = pool.borrow_connection(1.0) self.assertIs(conn, c) self.assertEqual(1, conn.in_flight) conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace') @@ -140,7 +140,8 @@ class HostConnectionPoolTests(unittest.TestCase): def test_spawn_when_at_max(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) + conn.max_request_id = 100 session.cluster.connection_factory.return_value = conn # core conns = 1, max conns = 2 @@ -153,7 +154,7 @@ class HostConnectionPoolTests(unittest.TestCase): self.assertEqual(1, conn.in_flight) # make this conn full - conn.in_flight = MAX_STREAM_PER_CONNECTION + conn.in_flight = conn.max_request_id # we don't care about making this borrow_connection call succeed for the # purposes of this test, as long as it results in a new connection @@ -164,7 +165,7 @@ class HostConnectionPoolTests(unittest.TestCase): def test_return_defunct_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -183,7 +184,7 @@ class HostConnectionPoolTests(unittest.TestCase): def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -203,7 +204,7 @@ class HostConnectionPoolTests(unittest.TestCase): def test_return_closed_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index e9e0b686..ca97a729 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -21,7 +21,7 @@ from mock import Mock, MagicMock, ANY from cassandra import ConsistencyLevel from cassandra.cluster import Session, ResponseFuture, NoHostAvailable -from cassandra.connection import ConnectionException +from cassandra.connection import Connection, ConnectionException from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, @@ -58,13 +58,16 @@ class ResponseFutureTests(unittest.TestCase): pool = session._pools.get.return_value pool.is_shutdown = False + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + rf = self.make_response_future(session) rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection = pool.borrow_connection.return_value - connection.send_msg.assert_called_once_with(rf.message, cb=ANY) + + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) rf._set_result(self.make_mock_response([{'col': 'val'}])) result = rf.result() @@ -72,6 +75,10 @@ class ResponseFutureTests(unittest.TestCase): def test_unknown_result_class(self): session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + rf = self.make_response_future(session) rf.send_request() rf._set_result(object()) @@ -168,18 +175,21 @@ class ResponseFutureTests(unittest.TestCase): def test_retry_policy_says_retry(self): session = self.make_session() pool = session._pools.get.return_value + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") query.retry_policy = Mock() query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETRY, ConsistencyLevel.ONE) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM) + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + rf = ResponseFuture(session, message, query) rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection = pool.borrow_connection.return_value - connection.send_msg.assert_called_once_with(rf.message, cb=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) result = Mock(spec=UnavailableErrorMessage, info={}) rf._set_result(result) @@ -187,6 +197,9 @@ class ResponseFutureTests(unittest.TestCase): session.submit.assert_called_once_with(rf._retry_task, True) self.assertEqual(1, rf._query_retries) + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 2) + # simulate the executor running this rf._retry_task(True) @@ -194,21 +207,22 @@ class ResponseFutureTests(unittest.TestCase): # an UnavailableException rf.session._pools.get.assert_called_with('ip1') pool.borrow_connection.assert_called_with(timeout=ANY) - connection = pool.borrow_connection.return_value - connection.send_msg.assert_called_with(rf.message, cb=ANY) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY) def test_retry_with_different_host(self): session = self.make_session() pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + rf = self.make_response_future(session) rf.message.consistency_level = ConsistencyLevel.QUORUM rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection = pool.borrow_connection.return_value - connection.send_msg.assert_called_once_with(rf.message, cb=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) result = Mock(spec=OverloadedErrorMessage, info={}) @@ -218,20 +232,24 @@ class ResponseFutureTests(unittest.TestCase): # query_retries does not get incremented for Overloaded/Bootstrapping errors self.assertEqual(0, rf._query_retries) + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 2) # simulate the executor running this rf._retry_task(False) # it should try with a different host rf.session._pools.get.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY) - connection = pool.borrow_connection.return_value - connection.send_msg.assert_called_with(rf.message, cb=ANY) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY) # the consistency level should be the same self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) def test_all_retries_fail(self): session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) rf = self.make_response_future(session) rf.send_request() @@ -287,7 +305,11 @@ class ResponseFutureTests(unittest.TestCase): exc = NoConnectionsAvailable() first_pool = Mock(is_shutdown=False) first_pool.borrow_connection.side_effect = exc + + # the second pool will return a connection second_pool = Mock(is_shutdown=False) + connection = Mock(spec=Connection) + second_pool.borrow_connection.return_value = (connection, 1) session._pools.get.side_effect = [first_pool, second_pool] @@ -317,6 +339,10 @@ class ResponseFutureTests(unittest.TestCase): def test_errback(self): session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") query.retry_policy = Mock() query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) @@ -366,6 +392,10 @@ class ResponseFutureTests(unittest.TestCase): def test_prepared_query_not_found(self): session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + rf = self.make_response_future(session) rf.send_request() @@ -386,6 +416,10 @@ class ResponseFutureTests(unittest.TestCase): def test_prepared_query_not_found_bad_keyspace(self): session = self.make_session() + pool = session._pools.get.return_value + connection = Mock(spec=Connection) + pool.borrow_connection.return_value = (connection, 1) + rf = self.make_response_future(session) rf.send_request()