Isolate custom protocol handling to client requests.

This commit is contained in:
Adam Holmberg
2015-07-16 15:13:19 -05:00
parent 7dcabd744a
commit ab9f690707
10 changed files with 61 additions and 58 deletions

View File

@@ -420,15 +420,6 @@ class Cluster(object):
GeventConnection will be used automatically. GeventConnection will be used automatically.
""" """
protocol_handler_class = ProtocolHandler
"""
Specifies a protocol handler class, which can be used to override or extend features
such as message or type deserialization.
The class must conform to the public classmethod interface defined in the default
implementation, :class:`cassandra.protocol.ProtocolHandler`
"""
control_connection_timeout = 2.0 control_connection_timeout = 2.0
""" """
A timeout, in seconds, for queries made by the control connection, such A timeout, in seconds, for queries made by the control connection, such
@@ -525,8 +516,7 @@ class Cluster(object):
idle_heartbeat_interval=30, idle_heartbeat_interval=30,
schema_event_refresh_window=2, schema_event_refresh_window=2,
topology_event_refresh_window=10, topology_event_refresh_window=10,
connect_timeout=5, connect_timeout=5):
protocol_handler_class=None):
""" """
Any of the mutable Cluster attributes may be set as keyword arguments Any of the mutable Cluster attributes may be set as keyword arguments
to the constructor. to the constructor.
@@ -570,9 +560,6 @@ class Cluster(object):
if connection_class is not None: if connection_class is not None:
self.connection_class = connection_class self.connection_class = connection_class
if protocol_handler_class is not None:
self.protocol_handler_class = protocol_handler_class
self.metrics_enabled = metrics_enabled self.metrics_enabled = metrics_enabled
self.ssl_options = ssl_options self.ssl_options = ssl_options
self.sockopts = sockopts self.sockopts = sockopts
@@ -812,7 +799,6 @@ class Cluster(object):
kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('cql_version', self.cql_version)
kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('protocol_version', self.protocol_version)
kwargs_dict.setdefault('user_type_map', self._user_types) kwargs_dict.setdefault('user_type_map', self._user_types)
kwargs_dict.setdefault('protocol_handler_class', self.protocol_handler_class)
return kwargs_dict return kwargs_dict
@@ -1372,7 +1358,7 @@ class Cluster(object):
log.debug("Preparing all known prepared statements against host %s", host) log.debug("Preparing all known prepared statements against host %s", host)
connection = None connection = None
try: try:
connection = self.connection_factory(host.address, protocol_handler_class=ProtocolHandler) connection = self.connection_factory(host.address)
try: try:
self.control_connection.wait_for_schema_agreement(connection) self.control_connection.wait_for_schema_agreement(connection)
except Exception: except Exception:
@@ -1535,6 +1521,20 @@ class Session(object):
.. versionadded:: 2.1.0 .. versionadded:: 2.1.0
""" """
client_protocol_handler = ProtocolHandler
"""
Specifies a protocol handler that will be used for client-initiated requests (i.e. no
internal driver requests). This can be used to override or extend features such as
message or type ser/des.
The class must conform to the public classmethod interface defined in the default
implementation, :class:`cassandra.protocol.ProtocolHandler`
This is not included in published documentation as it is not intended for the casual user.
It requires knowledge of the native protocol and driver internals. Only advanced, specialized
use cases should need to do anything with this.
"""
_lock = None _lock = None
_pools = None _pools = None
_load_balancer = None _load_balancer = None
@@ -1661,6 +1661,7 @@ class Session(object):
timeout = self.default_timeout timeout = self.default_timeout
future = self._create_response_future(query, parameters, trace, custom_payload, timeout) future = self._create_response_future(query, parameters, trace, custom_payload, timeout)
future._protocol_handler = self.client_protocol_handler
future.send_request() future.send_request()
return future return future
@@ -2131,7 +2132,7 @@ class ControlConnection(object):
while True: while True:
try: try:
connection = self._cluster.connection_factory(host.address, is_control_connection=True, protocol_handler_class=ProtocolHandler) connection = self._cluster.connection_factory(host.address, is_control_connection=True)
break break
except ProtocolVersionUnsupported as e: except ProtocolVersionUnsupported as e:
self._cluster.protocol_downgrade(host.address, e.startup_version) self._cluster.protocol_downgrade(host.address, e.startup_version)
@@ -2846,6 +2847,7 @@ class ResponseFuture(object):
_custom_payload = None _custom_payload = None
_warnings = None _warnings = None
_timer = None _timer = None
_protocol_handler = ProtocolHandler
def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None): def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None):
self.session = session self.session = session
@@ -2919,7 +2921,7 @@ class ResponseFuture(object):
# TODO get connectTimeout from cluster settings # TODO get connectTimeout from cluster settings
connection, request_id = pool.borrow_connection(timeout=2.0) connection, request_id = pool.borrow_connection(timeout=2.0)
self._connection = connection self._connection = connection
connection.send_msg(message, request_id, cb=cb) connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message)
return request_id return request_id
except NoConnectionsAvailable as exc: except NoConnectionsAvailable as exc:
log.debug("All connections for host %s are at capacity, moving to the next host", host) log.debug("All connections for host %s are at capacity, moving to the next host", host)

View File

@@ -209,7 +209,7 @@ class Connection(object):
def __init__(self, host='127.0.0.1', port=9042, authenticator=None, def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True, ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False, cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
user_type_map=None, protocol_handler_class=ProtocolHandler): user_type_map=None):
self.host = host self.host = host
self.port = port self.port = port
self.authenticator = authenticator self.authenticator = authenticator
@@ -220,10 +220,8 @@ class Connection(object):
self.protocol_version = protocol_version self.protocol_version = protocol_version
self.is_control_connection = is_control_connection self.is_control_connection = is_control_connection
self.user_type_map = user_type_map self.user_type_map = user_type_map
self.decoder = protocol_handler_class.decode_message
self.encoder = protocol_handler_class.encode_message
self._push_watchers = defaultdict(set) self._push_watchers = defaultdict(set)
self._callbacks = {} self._requests = {}
self._iobuf = io.BytesIO() self._iobuf = io.BytesIO()
if protocol_version >= 3: if protocol_version >= 3:
@@ -320,20 +318,20 @@ class Connection(object):
self.last_error = exc self.last_error = exc
self.close() self.close()
self.error_all_callbacks(exc) self.error_all_requests(exc)
self.connected_event.set() self.connected_event.set()
return exc return exc
def error_all_callbacks(self, exc): def error_all_requests(self, exc):
with self.lock: with self.lock:
callbacks = self._callbacks requests = self._requests
self._callbacks = {} self._requests = {}
new_exc = ConnectionShutdown(str(exc)) new_exc = ConnectionShutdown(str(exc))
for cb in callbacks.values(): for cb, _ in requests.values():
try: try:
cb(new_exc) cb(new_exc)
except Exception: except Exception:
log.warning("Ignoring unhandled exception while erroring callbacks for a " log.warning("Ignoring unhandled exception while erroring requests for a "
"failed connection (%s) to host %s:", "failed connection (%s) to host %s:",
id(self), self.host, exc_info=True) id(self), self.host, exc_info=True)
@@ -357,14 +355,16 @@ class Connection(object):
except Exception: except Exception:
log.exception("Pushed event handler errored, ignoring:") log.exception("Pushed event handler errored, ignoring:")
def send_msg(self, msg, request_id, cb): def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message):
if self.is_defunct: if self.is_defunct:
raise ConnectionShutdown("Connection to %s is defunct" % self.host) raise ConnectionShutdown("Connection to %s is defunct" % self.host)
elif self.is_closed: elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.host) raise ConnectionShutdown("Connection to %s is closed" % self.host)
self._callbacks[request_id] = cb # queue the decoder function with the request
self.push(self.encoder(msg, request_id, self.protocol_version, compressor=self.compressor)) # this allows us to inject custom functions per request to encode, decode messages
self._requests[request_id] = (cb, decoder)
self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
return request_id return request_id
def wait_for_response(self, msg, timeout=None): def wait_for_response(self, msg, timeout=None):
@@ -492,15 +492,16 @@ class Connection(object):
stream_id = header.stream stream_id = header.stream
if stream_id < 0: if stream_id < 0:
callback = None callback = None
decoder = ProtocolHandler.decode_message
else: else:
callback = self._callbacks.pop(stream_id, None) callback, decoder = self._requests.pop(stream_id, None)
with self.lock: with self.lock:
self.request_ids.append(stream_id) self.request_ids.append(stream_id)
self.msg_received = True self.msg_received = True
try: try:
response = self.decoder(header.version, self.user_type_map, stream_id, response = decoder(header.version, self.user_type_map, stream_id,
header.flags, header.opcode, body, self.decompressor) header.flags, header.opcode, body, self.decompressor)
except Exception as exc: except Exception as exc:
log.exception("Error decoding response from Cassandra. " log.exception("Error decoding response from Cassandra. "

View File

@@ -183,7 +183,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
log.debug("Closed socket to %s", self.host) log.debug("Closed socket to %s", self.host)
if not self.is_defunct: if not self.is_defunct:
self.error_all_callbacks( self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host)) ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging # don't leave in-progress operations hanging
self.connected_event.set() self.connected_event.set()
@@ -239,7 +239,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
if self._iobuf.tell(): if self._iobuf.tell():
self.process_io_buffer() self.process_io_buffer()
if not self._callbacks and not self.is_control_connection: if not self._requests and not self.is_control_connection:
self._readable = False self._readable = False
def push(self, data): def push(self, data):

View File

@@ -116,7 +116,7 @@ class EventletConnection(Connection):
log.debug("Closed socket to %s" % (self.host,)) log.debug("Closed socket to %s" % (self.host,))
if not self.is_defunct: if not self.is_defunct:
self.error_all_callbacks( self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host)) ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging # don't leave in-progress operations hanging
self.connected_event.set() self.connected_event.set()

View File

@@ -106,7 +106,7 @@ class GeventConnection(Connection):
log.debug("Closed socket to %s" % (self.host,)) log.debug("Closed socket to %s" % (self.host,))
if not self.is_defunct: if not self.is_defunct:
self.error_all_callbacks( self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host)) ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging # don't leave in-progress operations hanging
self.connected_event.set() self.connected_event.set()

View File

@@ -290,7 +290,7 @@ class LibevConnection(Connection):
# don't leave in-progress operations hanging # don't leave in-progress operations hanging
if not self.is_defunct: if not self.is_defunct:
self.error_all_callbacks( self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host)) ConnectionShutdown("Connection to %s was closed" % self.host))
def handle_write(self, watcher, revents, errno=None): def handle_write(self, watcher, revents, errno=None):

View File

@@ -96,7 +96,7 @@ class TwistedConnectionClientFactory(protocol.ClientFactory):
It should be safe to call defunct() here instead of just close, because It should be safe to call defunct() here instead of just close, because
we can assume that if the connection was closed cleanly, there are no we can assume that if the connection was closed cleanly, there are no
callbacks to error out. If this assumption turns out to be false, we requests to error out. If this assumption turns out to be false, we
can call close() instead of defunct() when "reason" is an appropriate can call close() instead of defunct() when "reason" is an appropriate
type. type.
""" """
@@ -213,7 +213,7 @@ class TwistedConnection(Connection):
def close(self): def close(self):
""" """
Disconnect and error-out all callbacks. Disconnect and error-out all requests.
""" """
with self.lock: with self.lock:
if self.is_closed: if self.is_closed:
@@ -225,7 +225,7 @@ class TwistedConnection(Connection):
log.debug("Closed socket to %s", self.host) log.debug("Closed socket to %s", self.host)
if not self.is_defunct: if not self.is_defunct:
self.error_all_callbacks( self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host)) ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging # don't leave in-progress operations hanging
self.connected_event.set() self.connected_event.set()

View File

@@ -140,13 +140,13 @@ class TestTwistedConnection(unittest.TestCase):
""" """
Verify that close() disconnects the connector and errors callbacks. Verify that close() disconnects the connector and errors callbacks.
""" """
self.obj_ut.error_all_callbacks = Mock() self.obj_ut.error_all_requests = Mock()
self.obj_ut.add_connection() self.obj_ut.add_connection()
self.obj_ut.is_closed = False self.obj_ut.is_closed = False
self.obj_ut.close() self.obj_ut.close()
self.obj_ut.connector.disconnect.assert_called_with() self.obj_ut.connector.disconnect.assert_called_with()
self.assertTrue(self.obj_ut.connected_event.is_set()) self.assertTrue(self.obj_ut.connected_event.is_set())
self.assertTrue(self.obj_ut.error_all_callbacks.called) self.assertTrue(self.obj_ut.error_all_requests.called)
def test_handle_read__incomplete(self): def test_handle_read__incomplete(self):
""" """

View File

@@ -27,7 +27,7 @@ from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, Protoc
locally_supported_compressions, ConnectionHeartbeat, _Frame) locally_supported_compressions, ConnectionHeartbeat, _Frame)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string, from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage) SupportedMessage, ProtocolHandler)
class ConnectionTest(unittest.TestCase): class ConnectionTest(unittest.TestCase):
@@ -75,7 +75,7 @@ class ConnectionTest(unittest.TestCase):
def test_bad_protocol_version(self, *args): def test_bad_protocol_version(self, *args):
c = self.make_connection() c = self.make_connection()
c._callbacks = Mock() c._requests = Mock()
c.defunct = Mock() c.defunct = Mock()
# read in a SupportedMessage response # read in a SupportedMessage response
@@ -93,7 +93,7 @@ class ConnectionTest(unittest.TestCase):
def test_negative_body_length(self, *args): def test_negative_body_length(self, *args):
c = self.make_connection() c = self.make_connection()
c._callbacks = Mock() c._requests = Mock()
c.defunct = Mock() c.defunct = Mock()
# read in a SupportedMessage response # read in a SupportedMessage response
@@ -110,7 +110,7 @@ class ConnectionTest(unittest.TestCase):
def test_unsupported_cql_version(self, *args): def test_unsupported_cql_version(self, *args):
c = self.make_connection() c = self.make_connection()
c._callbacks = {0: c._handle_options_response} c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock() c.defunct = Mock()
c.cql_version = "3.0.3" c.cql_version = "3.0.3"
@@ -133,7 +133,7 @@ class ConnectionTest(unittest.TestCase):
def test_prefer_lz4_compression(self, *args): def test_prefer_lz4_compression(self, *args):
c = self.make_connection() c = self.make_connection()
c._callbacks = {0: c._handle_options_response} c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock() c.defunct = Mock()
c.cql_version = "3.0.3" c.cql_version = "3.0.3"
@@ -156,7 +156,7 @@ class ConnectionTest(unittest.TestCase):
def test_requested_compression_not_available(self, *args): def test_requested_compression_not_available(self, *args):
c = self.make_connection() c = self.make_connection()
c._callbacks = {0: c._handle_options_response} c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock() c.defunct = Mock()
# request lz4 compression # request lz4 compression
c.compression = "lz4" c.compression = "lz4"
@@ -186,7 +186,7 @@ class ConnectionTest(unittest.TestCase):
def test_use_requested_compression(self, *args): def test_use_requested_compression(self, *args):
c = self.make_connection() c = self.make_connection()
c._callbacks = {0: c._handle_options_response} c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock() c.defunct = Mock()
# request snappy compression # request snappy compression
c.compression = "snappy" c.compression = "snappy"
@@ -213,7 +213,7 @@ class ConnectionTest(unittest.TestCase):
def test_disable_compression(self, *args): def test_disable_compression(self, *args):
c = self.make_connection() c = self.make_connection()
c._callbacks = {0: c._handle_options_response} c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock() c.defunct = Mock()
# disable compression # disable compression
c.compression = False c.compression = False

View File

@@ -27,7 +27,7 @@ from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessag
OverloadedErrorMessage, IsBootstrappingErrorMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage,
PreparedQueryNotFound, PrepareMessage, PreparedQueryNotFound, PrepareMessage,
RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE,
RESULT_KIND_SCHEMA_CHANGE) RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler)
from cassandra.policies import RetryPolicy from cassandra.policies import RetryPolicy
from cassandra.pool import NoConnectionsAvailable from cassandra.pool import NoConnectionsAvailable
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
@@ -67,7 +67,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.session._pools.get.assert_called_once_with('ip1') rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY) pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
rf._set_result(self.make_mock_response([{'col': 'val'}])) rf._set_result(self.make_mock_response([{'col': 'val'}]))
result = rf.result() result = rf.result()
@@ -189,7 +189,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.session._pools.get.assert_called_once_with('ip1') rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY) pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
result = Mock(spec=UnavailableErrorMessage, info={}) result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result) rf._set_result(result)
@@ -207,7 +207,7 @@ class ResponseFutureTests(unittest.TestCase):
# an UnavailableException # an UnavailableException
rf.session._pools.get.assert_called_with('ip1') rf.session._pools.get.assert_called_with('ip1')
pool.borrow_connection.assert_called_with(timeout=ANY) pool.borrow_connection.assert_called_with(timeout=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
def test_retry_with_different_host(self): def test_retry_with_different_host(self):
session = self.make_session() session = self.make_session()
@@ -222,7 +222,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.session._pools.get.assert_called_once_with('ip1') rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY) pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
result = Mock(spec=OverloadedErrorMessage, info={}) result = Mock(spec=OverloadedErrorMessage, info={})
@@ -240,7 +240,7 @@ class ResponseFutureTests(unittest.TestCase):
# it should try with a different host # it should try with a different host
rf.session._pools.get.assert_called_with('ip2') rf.session._pools.get.assert_called_with('ip2')
pool.borrow_connection.assert_called_with(timeout=ANY) pool.borrow_connection.assert_called_with(timeout=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
# the consistency level should be the same # the consistency level should be the same
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)