Isolate custom protocol handling to client requests.
This commit is contained in:
@@ -420,15 +420,6 @@ class Cluster(object):
|
||||
GeventConnection will be used automatically.
|
||||
"""
|
||||
|
||||
protocol_handler_class = ProtocolHandler
|
||||
"""
|
||||
Specifies a protocol handler class, which can be used to override or extend features
|
||||
such as message or type deserialization.
|
||||
|
||||
The class must conform to the public classmethod interface defined in the default
|
||||
implementation, :class:`cassandra.protocol.ProtocolHandler`
|
||||
"""
|
||||
|
||||
control_connection_timeout = 2.0
|
||||
"""
|
||||
A timeout, in seconds, for queries made by the control connection, such
|
||||
@@ -525,8 +516,7 @@ class Cluster(object):
|
||||
idle_heartbeat_interval=30,
|
||||
schema_event_refresh_window=2,
|
||||
topology_event_refresh_window=10,
|
||||
connect_timeout=5,
|
||||
protocol_handler_class=None):
|
||||
connect_timeout=5):
|
||||
"""
|
||||
Any of the mutable Cluster attributes may be set as keyword arguments
|
||||
to the constructor.
|
||||
@@ -570,9 +560,6 @@ class Cluster(object):
|
||||
if connection_class is not None:
|
||||
self.connection_class = connection_class
|
||||
|
||||
if protocol_handler_class is not None:
|
||||
self.protocol_handler_class = protocol_handler_class
|
||||
|
||||
self.metrics_enabled = metrics_enabled
|
||||
self.ssl_options = ssl_options
|
||||
self.sockopts = sockopts
|
||||
@@ -812,7 +799,6 @@ class Cluster(object):
|
||||
kwargs_dict.setdefault('cql_version', self.cql_version)
|
||||
kwargs_dict.setdefault('protocol_version', self.protocol_version)
|
||||
kwargs_dict.setdefault('user_type_map', self._user_types)
|
||||
kwargs_dict.setdefault('protocol_handler_class', self.protocol_handler_class)
|
||||
|
||||
return kwargs_dict
|
||||
|
||||
@@ -1372,7 +1358,7 @@ class Cluster(object):
|
||||
log.debug("Preparing all known prepared statements against host %s", host)
|
||||
connection = None
|
||||
try:
|
||||
connection = self.connection_factory(host.address, protocol_handler_class=ProtocolHandler)
|
||||
connection = self.connection_factory(host.address)
|
||||
try:
|
||||
self.control_connection.wait_for_schema_agreement(connection)
|
||||
except Exception:
|
||||
@@ -1535,6 +1521,20 @@ class Session(object):
|
||||
.. versionadded:: 2.1.0
|
||||
"""
|
||||
|
||||
client_protocol_handler = ProtocolHandler
|
||||
"""
|
||||
Specifies a protocol handler that will be used for client-initiated requests (i.e. no
|
||||
internal driver requests). This can be used to override or extend features such as
|
||||
message or type ser/des.
|
||||
|
||||
The class must conform to the public classmethod interface defined in the default
|
||||
implementation, :class:`cassandra.protocol.ProtocolHandler`
|
||||
|
||||
This is not included in published documentation as it is not intended for the casual user.
|
||||
It requires knowledge of the native protocol and driver internals. Only advanced, specialized
|
||||
use cases should need to do anything with this.
|
||||
"""
|
||||
|
||||
_lock = None
|
||||
_pools = None
|
||||
_load_balancer = None
|
||||
@@ -1661,6 +1661,7 @@ class Session(object):
|
||||
timeout = self.default_timeout
|
||||
|
||||
future = self._create_response_future(query, parameters, trace, custom_payload, timeout)
|
||||
future._protocol_handler = self.client_protocol_handler
|
||||
future.send_request()
|
||||
return future
|
||||
|
||||
@@ -2131,7 +2132,7 @@ class ControlConnection(object):
|
||||
|
||||
while True:
|
||||
try:
|
||||
connection = self._cluster.connection_factory(host.address, is_control_connection=True, protocol_handler_class=ProtocolHandler)
|
||||
connection = self._cluster.connection_factory(host.address, is_control_connection=True)
|
||||
break
|
||||
except ProtocolVersionUnsupported as e:
|
||||
self._cluster.protocol_downgrade(host.address, e.startup_version)
|
||||
@@ -2846,6 +2847,7 @@ class ResponseFuture(object):
|
||||
_custom_payload = None
|
||||
_warnings = None
|
||||
_timer = None
|
||||
_protocol_handler = ProtocolHandler
|
||||
|
||||
def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None):
|
||||
self.session = session
|
||||
@@ -2919,7 +2921,7 @@ class ResponseFuture(object):
|
||||
# TODO get connectTimeout from cluster settings
|
||||
connection, request_id = pool.borrow_connection(timeout=2.0)
|
||||
self._connection = connection
|
||||
connection.send_msg(message, request_id, cb=cb)
|
||||
connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message)
|
||||
return request_id
|
||||
except NoConnectionsAvailable as exc:
|
||||
log.debug("All connections for host %s are at capacity, moving to the next host", host)
|
||||
|
@@ -209,7 +209,7 @@ class Connection(object):
|
||||
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
|
||||
ssl_options=None, sockopts=None, compression=True,
|
||||
cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
|
||||
user_type_map=None, protocol_handler_class=ProtocolHandler):
|
||||
user_type_map=None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.authenticator = authenticator
|
||||
@@ -220,10 +220,8 @@ class Connection(object):
|
||||
self.protocol_version = protocol_version
|
||||
self.is_control_connection = is_control_connection
|
||||
self.user_type_map = user_type_map
|
||||
self.decoder = protocol_handler_class.decode_message
|
||||
self.encoder = protocol_handler_class.encode_message
|
||||
self._push_watchers = defaultdict(set)
|
||||
self._callbacks = {}
|
||||
self._requests = {}
|
||||
self._iobuf = io.BytesIO()
|
||||
|
||||
if protocol_version >= 3:
|
||||
@@ -320,20 +318,20 @@ class Connection(object):
|
||||
|
||||
self.last_error = exc
|
||||
self.close()
|
||||
self.error_all_callbacks(exc)
|
||||
self.error_all_requests(exc)
|
||||
self.connected_event.set()
|
||||
return exc
|
||||
|
||||
def error_all_callbacks(self, exc):
|
||||
def error_all_requests(self, exc):
|
||||
with self.lock:
|
||||
callbacks = self._callbacks
|
||||
self._callbacks = {}
|
||||
requests = self._requests
|
||||
self._requests = {}
|
||||
new_exc = ConnectionShutdown(str(exc))
|
||||
for cb in callbacks.values():
|
||||
for cb, _ in requests.values():
|
||||
try:
|
||||
cb(new_exc)
|
||||
except Exception:
|
||||
log.warning("Ignoring unhandled exception while erroring callbacks for a "
|
||||
log.warning("Ignoring unhandled exception while erroring requests for a "
|
||||
"failed connection (%s) to host %s:",
|
||||
id(self), self.host, exc_info=True)
|
||||
|
||||
@@ -357,14 +355,16 @@ class Connection(object):
|
||||
except Exception:
|
||||
log.exception("Pushed event handler errored, ignoring:")
|
||||
|
||||
def send_msg(self, msg, request_id, cb):
|
||||
def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message):
|
||||
if self.is_defunct:
|
||||
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
|
||||
elif self.is_closed:
|
||||
raise ConnectionShutdown("Connection to %s is closed" % self.host)
|
||||
|
||||
self._callbacks[request_id] = cb
|
||||
self.push(self.encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
|
||||
# queue the decoder function with the request
|
||||
# this allows us to inject custom functions per request to encode, decode messages
|
||||
self._requests[request_id] = (cb, decoder)
|
||||
self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
|
||||
return request_id
|
||||
|
||||
def wait_for_response(self, msg, timeout=None):
|
||||
@@ -492,16 +492,17 @@ class Connection(object):
|
||||
stream_id = header.stream
|
||||
if stream_id < 0:
|
||||
callback = None
|
||||
decoder = ProtocolHandler.decode_message
|
||||
else:
|
||||
callback = self._callbacks.pop(stream_id, None)
|
||||
callback, decoder = self._requests.pop(stream_id, None)
|
||||
with self.lock:
|
||||
self.request_ids.append(stream_id)
|
||||
|
||||
self.msg_received = True
|
||||
|
||||
try:
|
||||
response = self.decoder(header.version, self.user_type_map, stream_id,
|
||||
header.flags, header.opcode, body, self.decompressor)
|
||||
response = decoder(header.version, self.user_type_map, stream_id,
|
||||
header.flags, header.opcode, body, self.decompressor)
|
||||
except Exception as exc:
|
||||
log.exception("Error decoding response from Cassandra. "
|
||||
"opcode: %04x; message contents: %r", header.opcode, body)
|
||||
|
@@ -183,7 +183,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
log.debug("Closed socket to %s", self.host)
|
||||
|
||||
if not self.is_defunct:
|
||||
self.error_all_callbacks(
|
||||
self.error_all_requests(
|
||||
ConnectionShutdown("Connection to %s was closed" % self.host))
|
||||
# don't leave in-progress operations hanging
|
||||
self.connected_event.set()
|
||||
@@ -239,7 +239,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
|
||||
if self._iobuf.tell():
|
||||
self.process_io_buffer()
|
||||
if not self._callbacks and not self.is_control_connection:
|
||||
if not self._requests and not self.is_control_connection:
|
||||
self._readable = False
|
||||
|
||||
def push(self, data):
|
||||
|
@@ -116,7 +116,7 @@ class EventletConnection(Connection):
|
||||
log.debug("Closed socket to %s" % (self.host,))
|
||||
|
||||
if not self.is_defunct:
|
||||
self.error_all_callbacks(
|
||||
self.error_all_requests(
|
||||
ConnectionShutdown("Connection to %s was closed" % self.host))
|
||||
# don't leave in-progress operations hanging
|
||||
self.connected_event.set()
|
||||
|
@@ -106,7 +106,7 @@ class GeventConnection(Connection):
|
||||
log.debug("Closed socket to %s" % (self.host,))
|
||||
|
||||
if not self.is_defunct:
|
||||
self.error_all_callbacks(
|
||||
self.error_all_requests(
|
||||
ConnectionShutdown("Connection to %s was closed" % self.host))
|
||||
# don't leave in-progress operations hanging
|
||||
self.connected_event.set()
|
||||
|
@@ -290,7 +290,7 @@ class LibevConnection(Connection):
|
||||
|
||||
# don't leave in-progress operations hanging
|
||||
if not self.is_defunct:
|
||||
self.error_all_callbacks(
|
||||
self.error_all_requests(
|
||||
ConnectionShutdown("Connection to %s was closed" % self.host))
|
||||
|
||||
def handle_write(self, watcher, revents, errno=None):
|
||||
|
@@ -96,7 +96,7 @@ class TwistedConnectionClientFactory(protocol.ClientFactory):
|
||||
|
||||
It should be safe to call defunct() here instead of just close, because
|
||||
we can assume that if the connection was closed cleanly, there are no
|
||||
callbacks to error out. If this assumption turns out to be false, we
|
||||
requests to error out. If this assumption turns out to be false, we
|
||||
can call close() instead of defunct() when "reason" is an appropriate
|
||||
type.
|
||||
"""
|
||||
@@ -213,7 +213,7 @@ class TwistedConnection(Connection):
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Disconnect and error-out all callbacks.
|
||||
Disconnect and error-out all requests.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.is_closed:
|
||||
@@ -225,7 +225,7 @@ class TwistedConnection(Connection):
|
||||
log.debug("Closed socket to %s", self.host)
|
||||
|
||||
if not self.is_defunct:
|
||||
self.error_all_callbacks(
|
||||
self.error_all_requests(
|
||||
ConnectionShutdown("Connection to %s was closed" % self.host))
|
||||
# don't leave in-progress operations hanging
|
||||
self.connected_event.set()
|
||||
|
@@ -140,13 +140,13 @@ class TestTwistedConnection(unittest.TestCase):
|
||||
"""
|
||||
Verify that close() disconnects the connector and errors callbacks.
|
||||
"""
|
||||
self.obj_ut.error_all_callbacks = Mock()
|
||||
self.obj_ut.error_all_requests = Mock()
|
||||
self.obj_ut.add_connection()
|
||||
self.obj_ut.is_closed = False
|
||||
self.obj_ut.close()
|
||||
self.obj_ut.connector.disconnect.assert_called_with()
|
||||
self.assertTrue(self.obj_ut.connected_event.is_set())
|
||||
self.assertTrue(self.obj_ut.error_all_callbacks.called)
|
||||
self.assertTrue(self.obj_ut.error_all_requests.called)
|
||||
|
||||
def test_handle_read__incomplete(self):
|
||||
"""
|
||||
|
@@ -27,7 +27,7 @@ from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, Protoc
|
||||
locally_supported_compressions, ConnectionHeartbeat, _Frame)
|
||||
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
|
||||
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
|
||||
SupportedMessage)
|
||||
SupportedMessage, ProtocolHandler)
|
||||
|
||||
|
||||
class ConnectionTest(unittest.TestCase):
|
||||
@@ -75,7 +75,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_bad_protocol_version(self, *args):
|
||||
c = self.make_connection()
|
||||
c._callbacks = Mock()
|
||||
c._requests = Mock()
|
||||
c.defunct = Mock()
|
||||
|
||||
# read in a SupportedMessage response
|
||||
@@ -93,7 +93,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_negative_body_length(self, *args):
|
||||
c = self.make_connection()
|
||||
c._callbacks = Mock()
|
||||
c._requests = Mock()
|
||||
c.defunct = Mock()
|
||||
|
||||
# read in a SupportedMessage response
|
||||
@@ -110,7 +110,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_unsupported_cql_version(self, *args):
|
||||
c = self.make_connection()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
|
||||
c.defunct = Mock()
|
||||
c.cql_version = "3.0.3"
|
||||
|
||||
@@ -133,7 +133,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_prefer_lz4_compression(self, *args):
|
||||
c = self.make_connection()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
|
||||
c.defunct = Mock()
|
||||
c.cql_version = "3.0.3"
|
||||
|
||||
@@ -156,7 +156,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_requested_compression_not_available(self, *args):
|
||||
c = self.make_connection()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
|
||||
c.defunct = Mock()
|
||||
# request lz4 compression
|
||||
c.compression = "lz4"
|
||||
@@ -186,7 +186,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_use_requested_compression(self, *args):
|
||||
c = self.make_connection()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
|
||||
c.defunct = Mock()
|
||||
# request snappy compression
|
||||
c.compression = "snappy"
|
||||
@@ -213,7 +213,7 @@ class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def test_disable_compression(self, *args):
|
||||
c = self.make_connection()
|
||||
c._callbacks = {0: c._handle_options_response}
|
||||
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
|
||||
c.defunct = Mock()
|
||||
# disable compression
|
||||
c.compression = False
|
||||
|
@@ -27,7 +27,7 @@ from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessag
|
||||
OverloadedErrorMessage, IsBootstrappingErrorMessage,
|
||||
PreparedQueryNotFound, PrepareMessage,
|
||||
RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE,
|
||||
RESULT_KIND_SCHEMA_CHANGE)
|
||||
RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler)
|
||||
from cassandra.policies import RetryPolicy
|
||||
from cassandra.pool import NoConnectionsAvailable
|
||||
from cassandra.query import SimpleStatement
|
||||
@@ -67,7 +67,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
rf.session._pools.get.assert_called_once_with('ip1')
|
||||
pool.borrow_connection.assert_called_once_with(timeout=ANY)
|
||||
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
|
||||
|
||||
rf._set_result(self.make_mock_response([{'col': 'val'}]))
|
||||
result = rf.result()
|
||||
@@ -189,7 +189,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
rf.session._pools.get.assert_called_once_with('ip1')
|
||||
pool.borrow_connection.assert_called_once_with(timeout=ANY)
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
|
||||
|
||||
result = Mock(spec=UnavailableErrorMessage, info={})
|
||||
rf._set_result(result)
|
||||
@@ -207,7 +207,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
# an UnavailableException
|
||||
rf.session._pools.get.assert_called_with('ip1')
|
||||
pool.borrow_connection.assert_called_with(timeout=ANY)
|
||||
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
|
||||
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
|
||||
|
||||
def test_retry_with_different_host(self):
|
||||
session = self.make_session()
|
||||
@@ -222,7 +222,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
|
||||
rf.session._pools.get.assert_called_once_with('ip1')
|
||||
pool.borrow_connection.assert_called_once_with(timeout=ANY)
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
|
||||
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
|
||||
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
|
||||
|
||||
result = Mock(spec=OverloadedErrorMessage, info={})
|
||||
@@ -240,7 +240,7 @@ class ResponseFutureTests(unittest.TestCase):
|
||||
# it should try with a different host
|
||||
rf.session._pools.get.assert_called_with('ip2')
|
||||
pool.borrow_connection.assert_called_with(timeout=ANY)
|
||||
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
|
||||
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
|
||||
|
||||
# the consistency level should be the same
|
||||
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
|
||||
|
Reference in New Issue
Block a user