diff --git a/cassandra/cluster.py b/cassandra/cluster.py index c26e10f1..5f25e178 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -420,15 +420,6 @@ class Cluster(object): GeventConnection will be used automatically. """ - protocol_handler_class = ProtocolHandler - """ - Specifies a protocol handler class, which can be used to override or extend features - such as message or type deserialization. - - The class must conform to the public classmethod interface defined in the default - implementation, :class:`cassandra.protocol.ProtocolHandler` - """ - control_connection_timeout = 2.0 """ A timeout, in seconds, for queries made by the control connection, such @@ -525,8 +516,7 @@ class Cluster(object): idle_heartbeat_interval=30, schema_event_refresh_window=2, topology_event_refresh_window=10, - connect_timeout=5, - protocol_handler_class=None): + connect_timeout=5): """ Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. @@ -570,9 +560,6 @@ class Cluster(object): if connection_class is not None: self.connection_class = connection_class - if protocol_handler_class is not None: - self.protocol_handler_class = protocol_handler_class - self.metrics_enabled = metrics_enabled self.ssl_options = ssl_options self.sockopts = sockopts @@ -812,7 +799,6 @@ class Cluster(object): kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) - kwargs_dict.setdefault('protocol_handler_class', self.protocol_handler_class) return kwargs_dict @@ -1372,7 +1358,7 @@ class Cluster(object): log.debug("Preparing all known prepared statements against host %s", host) connection = None try: - connection = self.connection_factory(host.address, protocol_handler_class=ProtocolHandler) + connection = self.connection_factory(host.address) try: self.control_connection.wait_for_schema_agreement(connection) except Exception: @@ -1535,6 +1521,20 @@ class Session(object): .. versionadded:: 2.1.0 """ + client_protocol_handler = ProtocolHandler + """ + Specifies a protocol handler that will be used for client-initiated requests (i.e. no + internal driver requests). This can be used to override or extend features such as + message or type ser/des. + + The class must conform to the public classmethod interface defined in the default + implementation, :class:`cassandra.protocol.ProtocolHandler` + + This is not included in published documentation as it is not intended for the casual user. + It requires knowledge of the native protocol and driver internals. Only advanced, specialized + use cases should need to do anything with this. + """ + _lock = None _pools = None _load_balancer = None @@ -1661,6 +1661,7 @@ class Session(object): timeout = self.default_timeout future = self._create_response_future(query, parameters, trace, custom_payload, timeout) + future._protocol_handler = self.client_protocol_handler future.send_request() return future @@ -2131,7 +2132,7 @@ class ControlConnection(object): while True: try: - connection = self._cluster.connection_factory(host.address, is_control_connection=True, protocol_handler_class=ProtocolHandler) + connection = self._cluster.connection_factory(host.address, is_control_connection=True) break except ProtocolVersionUnsupported as e: self._cluster.protocol_downgrade(host.address, e.startup_version) @@ -2846,6 +2847,7 @@ class ResponseFuture(object): _custom_payload = None _warnings = None _timer = None + _protocol_handler = ProtocolHandler def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None): self.session = session @@ -2919,7 +2921,7 @@ class ResponseFuture(object): # TODO get connectTimeout from cluster settings connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection - connection.send_msg(message, request_id, cb=cb) + connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message) return request_id except NoConnectionsAvailable as exc: log.debug("All connections for host %s are at capacity, moving to the next host", host) diff --git a/cassandra/connection.py b/cassandra/connection.py index d1e3a9c9..5db88985 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -209,7 +209,7 @@ class Connection(object): def __init__(self, host='127.0.0.1', port=9042, authenticator=None, ssl_options=None, sockopts=None, compression=True, cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False, - user_type_map=None, protocol_handler_class=ProtocolHandler): + user_type_map=None): self.host = host self.port = port self.authenticator = authenticator @@ -220,10 +220,8 @@ class Connection(object): self.protocol_version = protocol_version self.is_control_connection = is_control_connection self.user_type_map = user_type_map - self.decoder = protocol_handler_class.decode_message - self.encoder = protocol_handler_class.encode_message self._push_watchers = defaultdict(set) - self._callbacks = {} + self._requests = {} self._iobuf = io.BytesIO() if protocol_version >= 3: @@ -320,20 +318,20 @@ class Connection(object): self.last_error = exc self.close() - self.error_all_callbacks(exc) + self.error_all_requests(exc) self.connected_event.set() return exc - def error_all_callbacks(self, exc): + def error_all_requests(self, exc): with self.lock: - callbacks = self._callbacks - self._callbacks = {} + requests = self._requests + self._requests = {} new_exc = ConnectionShutdown(str(exc)) - for cb in callbacks.values(): + for cb, _ in requests.values(): try: cb(new_exc) except Exception: - log.warning("Ignoring unhandled exception while erroring callbacks for a " + log.warning("Ignoring unhandled exception while erroring requests for a " "failed connection (%s) to host %s:", id(self), self.host, exc_info=True) @@ -357,14 +355,16 @@ class Connection(object): except Exception: log.exception("Pushed event handler errored, ignoring:") - def send_msg(self, msg, request_id, cb): + def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message): if self.is_defunct: raise ConnectionShutdown("Connection to %s is defunct" % self.host) elif self.is_closed: raise ConnectionShutdown("Connection to %s is closed" % self.host) - self._callbacks[request_id] = cb - self.push(self.encoder(msg, request_id, self.protocol_version, compressor=self.compressor)) + # queue the decoder function with the request + # this allows us to inject custom functions per request to encode, decode messages + self._requests[request_id] = (cb, decoder) + self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor)) return request_id def wait_for_response(self, msg, timeout=None): @@ -492,16 +492,17 @@ class Connection(object): stream_id = header.stream if stream_id < 0: callback = None + decoder = ProtocolHandler.decode_message else: - callback = self._callbacks.pop(stream_id, None) + callback, decoder = self._requests.pop(stream_id, None) with self.lock: self.request_ids.append(stream_id) self.msg_received = True try: - response = self.decoder(header.version, self.user_type_map, stream_id, - header.flags, header.opcode, body, self.decompressor) + response = decoder(header.version, self.user_type_map, stream_id, + header.flags, header.opcode, body, self.decompressor) except Exception as exc: log.exception("Error decoding response from Cassandra. " "opcode: %04x; message contents: %r", header.opcode, body) diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index fae87a73..dc9d26c6 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -183,7 +183,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): log.debug("Closed socket to %s", self.host) if not self.is_defunct: - self.error_all_callbacks( + self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() @@ -239,7 +239,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): if self._iobuf.tell(): self.process_io_buffer() - if not self._callbacks and not self.is_control_connection: + if not self._requests and not self.is_control_connection: self._readable = False def push(self, data): diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index b0206f43..aa90cc9a 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -116,7 +116,7 @@ class EventletConnection(Connection): log.debug("Closed socket to %s" % (self.host,)) if not self.is_defunct: - self.error_all_callbacks( + self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index 60a3f2d0..f26e6152 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -106,7 +106,7 @@ class GeventConnection(Connection): log.debug("Closed socket to %s" % (self.host,)) if not self.is_defunct: - self.error_all_callbacks( + self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index 9114af13..6b2036e2 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -290,7 +290,7 @@ class LibevConnection(Connection): # don't leave in-progress operations hanging if not self.is_defunct: - self.error_all_callbacks( + self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) def handle_write(self, watcher, revents, errno=None): diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index b02fb6ab..967a968f 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -96,7 +96,7 @@ class TwistedConnectionClientFactory(protocol.ClientFactory): It should be safe to call defunct() here instead of just close, because we can assume that if the connection was closed cleanly, there are no - callbacks to error out. If this assumption turns out to be false, we + requests to error out. If this assumption turns out to be false, we can call close() instead of defunct() when "reason" is an appropriate type. """ @@ -213,7 +213,7 @@ class TwistedConnection(Connection): def close(self): """ - Disconnect and error-out all callbacks. + Disconnect and error-out all requests. """ with self.lock: if self.is_closed: @@ -225,7 +225,7 @@ class TwistedConnection(Connection): log.debug("Closed socket to %s", self.host) if not self.is_defunct: - self.error_all_callbacks( + self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py index d2142b09..928436e0 100644 --- a/tests/unit/io/test_twistedreactor.py +++ b/tests/unit/io/test_twistedreactor.py @@ -140,13 +140,13 @@ class TestTwistedConnection(unittest.TestCase): """ Verify that close() disconnects the connector and errors callbacks. """ - self.obj_ut.error_all_callbacks = Mock() + self.obj_ut.error_all_requests = Mock() self.obj_ut.add_connection() self.obj_ut.is_closed = False self.obj_ut.close() self.obj_ut.connector.disconnect.assert_called_with() self.assertTrue(self.obj_ut.connected_event.is_set()) - self.assertTrue(self.obj_ut.error_all_callbacks.called) + self.assertTrue(self.obj_ut.error_all_requests.called) def test_handle_read__incomplete(self): """ diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 268e19d5..521fab61 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -27,7 +27,7 @@ from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, Protoc locally_supported_compressions, ConnectionHeartbeat, _Frame) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, - SupportedMessage) + SupportedMessage, ProtocolHandler) class ConnectionTest(unittest.TestCase): @@ -75,7 +75,7 @@ class ConnectionTest(unittest.TestCase): def test_bad_protocol_version(self, *args): c = self.make_connection() - c._callbacks = Mock() + c._requests = Mock() c.defunct = Mock() # read in a SupportedMessage response @@ -93,7 +93,7 @@ class ConnectionTest(unittest.TestCase): def test_negative_body_length(self, *args): c = self.make_connection() - c._callbacks = Mock() + c._requests = Mock() c.defunct = Mock() # read in a SupportedMessage response @@ -110,7 +110,7 @@ class ConnectionTest(unittest.TestCase): def test_unsupported_cql_version(self, *args): c = self.make_connection() - c._callbacks = {0: c._handle_options_response} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() c.cql_version = "3.0.3" @@ -133,7 +133,7 @@ class ConnectionTest(unittest.TestCase): def test_prefer_lz4_compression(self, *args): c = self.make_connection() - c._callbacks = {0: c._handle_options_response} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() c.cql_version = "3.0.3" @@ -156,7 +156,7 @@ class ConnectionTest(unittest.TestCase): def test_requested_compression_not_available(self, *args): c = self.make_connection() - c._callbacks = {0: c._handle_options_response} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() # request lz4 compression c.compression = "lz4" @@ -186,7 +186,7 @@ class ConnectionTest(unittest.TestCase): def test_use_requested_compression(self, *args): c = self.make_connection() - c._callbacks = {0: c._handle_options_response} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() # request snappy compression c.compression = "snappy" @@ -213,7 +213,7 @@ class ConnectionTest(unittest.TestCase): def test_disable_compression(self, *args): c = self.make_connection() - c._callbacks = {0: c._handle_options_response} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() # disable compression c.compression = False diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index eea43f75..d266ab77 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -27,7 +27,7 @@ from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessag OverloadedErrorMessage, IsBootstrappingErrorMessage, PreparedQueryNotFound, PrepareMessage, RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, - RESULT_KIND_SCHEMA_CHANGE) + RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler) from cassandra.policies import RetryPolicy from cassandra.pool import NoConnectionsAvailable from cassandra.query import SimpleStatement @@ -67,7 +67,7 @@ class ResponseFutureTests(unittest.TestCase): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) rf._set_result(self.make_mock_response([{'col': 'val'}])) result = rf.result() @@ -189,7 +189,7 @@ class ResponseFutureTests(unittest.TestCase): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) result = Mock(spec=UnavailableErrorMessage, info={}) rf._set_result(result) @@ -207,7 +207,7 @@ class ResponseFutureTests(unittest.TestCase): # an UnavailableException rf.session._pools.get.assert_called_with('ip1') pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) def test_retry_with_different_host(self): session = self.make_session() @@ -222,7 +222,7 @@ class ResponseFutureTests(unittest.TestCase): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) result = Mock(spec=OverloadedErrorMessage, info={}) @@ -240,7 +240,7 @@ class ResponseFutureTests(unittest.TestCase): # it should try with a different host rf.session._pools.get.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) # the consistency level should be the same self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)