Isolate custom protocol handling to client requests.

This commit is contained in:
Adam Holmberg
2015-07-16 15:13:19 -05:00
parent 7dcabd744a
commit ab9f690707
10 changed files with 61 additions and 58 deletions

View File

@@ -420,15 +420,6 @@ class Cluster(object):
GeventConnection will be used automatically.
"""
protocol_handler_class = ProtocolHandler
"""
Specifies a protocol handler class, which can be used to override or extend features
such as message or type deserialization.
The class must conform to the public classmethod interface defined in the default
implementation, :class:`cassandra.protocol.ProtocolHandler`
"""
control_connection_timeout = 2.0
"""
A timeout, in seconds, for queries made by the control connection, such
@@ -525,8 +516,7 @@ class Cluster(object):
idle_heartbeat_interval=30,
schema_event_refresh_window=2,
topology_event_refresh_window=10,
connect_timeout=5,
protocol_handler_class=None):
connect_timeout=5):
"""
Any of the mutable Cluster attributes may be set as keyword arguments
to the constructor.
@@ -570,9 +560,6 @@ class Cluster(object):
if connection_class is not None:
self.connection_class = connection_class
if protocol_handler_class is not None:
self.protocol_handler_class = protocol_handler_class
self.metrics_enabled = metrics_enabled
self.ssl_options = ssl_options
self.sockopts = sockopts
@@ -812,7 +799,6 @@ class Cluster(object):
kwargs_dict.setdefault('cql_version', self.cql_version)
kwargs_dict.setdefault('protocol_version', self.protocol_version)
kwargs_dict.setdefault('user_type_map', self._user_types)
kwargs_dict.setdefault('protocol_handler_class', self.protocol_handler_class)
return kwargs_dict
@@ -1372,7 +1358,7 @@ class Cluster(object):
log.debug("Preparing all known prepared statements against host %s", host)
connection = None
try:
connection = self.connection_factory(host.address, protocol_handler_class=ProtocolHandler)
connection = self.connection_factory(host.address)
try:
self.control_connection.wait_for_schema_agreement(connection)
except Exception:
@@ -1535,6 +1521,20 @@ class Session(object):
.. versionadded:: 2.1.0
"""
client_protocol_handler = ProtocolHandler
"""
Specifies a protocol handler that will be used for client-initiated requests (i.e. no
internal driver requests). This can be used to override or extend features such as
message or type ser/des.
The class must conform to the public classmethod interface defined in the default
implementation, :class:`cassandra.protocol.ProtocolHandler`
This is not included in published documentation as it is not intended for the casual user.
It requires knowledge of the native protocol and driver internals. Only advanced, specialized
use cases should need to do anything with this.
"""
_lock = None
_pools = None
_load_balancer = None
@@ -1661,6 +1661,7 @@ class Session(object):
timeout = self.default_timeout
future = self._create_response_future(query, parameters, trace, custom_payload, timeout)
future._protocol_handler = self.client_protocol_handler
future.send_request()
return future
@@ -2131,7 +2132,7 @@ class ControlConnection(object):
while True:
try:
connection = self._cluster.connection_factory(host.address, is_control_connection=True, protocol_handler_class=ProtocolHandler)
connection = self._cluster.connection_factory(host.address, is_control_connection=True)
break
except ProtocolVersionUnsupported as e:
self._cluster.protocol_downgrade(host.address, e.startup_version)
@@ -2846,6 +2847,7 @@ class ResponseFuture(object):
_custom_payload = None
_warnings = None
_timer = None
_protocol_handler = ProtocolHandler
def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None):
self.session = session
@@ -2919,7 +2921,7 @@ class ResponseFuture(object):
# TODO get connectTimeout from cluster settings
connection, request_id = pool.borrow_connection(timeout=2.0)
self._connection = connection
connection.send_msg(message, request_id, cb=cb)
connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message)
return request_id
except NoConnectionsAvailable as exc:
log.debug("All connections for host %s are at capacity, moving to the next host", host)

View File

@@ -209,7 +209,7 @@ class Connection(object):
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
user_type_map=None, protocol_handler_class=ProtocolHandler):
user_type_map=None):
self.host = host
self.port = port
self.authenticator = authenticator
@@ -220,10 +220,8 @@ class Connection(object):
self.protocol_version = protocol_version
self.is_control_connection = is_control_connection
self.user_type_map = user_type_map
self.decoder = protocol_handler_class.decode_message
self.encoder = protocol_handler_class.encode_message
self._push_watchers = defaultdict(set)
self._callbacks = {}
self._requests = {}
self._iobuf = io.BytesIO()
if protocol_version >= 3:
@@ -320,20 +318,20 @@ class Connection(object):
self.last_error = exc
self.close()
self.error_all_callbacks(exc)
self.error_all_requests(exc)
self.connected_event.set()
return exc
def error_all_callbacks(self, exc):
def error_all_requests(self, exc):
with self.lock:
callbacks = self._callbacks
self._callbacks = {}
requests = self._requests
self._requests = {}
new_exc = ConnectionShutdown(str(exc))
for cb in callbacks.values():
for cb, _ in requests.values():
try:
cb(new_exc)
except Exception:
log.warning("Ignoring unhandled exception while erroring callbacks for a "
log.warning("Ignoring unhandled exception while erroring requests for a "
"failed connection (%s) to host %s:",
id(self), self.host, exc_info=True)
@@ -357,14 +355,16 @@ class Connection(object):
except Exception:
log.exception("Pushed event handler errored, ignoring:")
def send_msg(self, msg, request_id, cb):
def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message):
if self.is_defunct:
raise ConnectionShutdown("Connection to %s is defunct" % self.host)
elif self.is_closed:
raise ConnectionShutdown("Connection to %s is closed" % self.host)
self._callbacks[request_id] = cb
self.push(self.encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
# queue the decoder function with the request
# this allows us to inject custom functions per request to encode, decode messages
self._requests[request_id] = (cb, decoder)
self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
return request_id
def wait_for_response(self, msg, timeout=None):
@@ -492,16 +492,17 @@ class Connection(object):
stream_id = header.stream
if stream_id < 0:
callback = None
decoder = ProtocolHandler.decode_message
else:
callback = self._callbacks.pop(stream_id, None)
callback, decoder = self._requests.pop(stream_id, None)
with self.lock:
self.request_ids.append(stream_id)
self.msg_received = True
try:
response = self.decoder(header.version, self.user_type_map, stream_id,
header.flags, header.opcode, body, self.decompressor)
response = decoder(header.version, self.user_type_map, stream_id,
header.flags, header.opcode, body, self.decompressor)
except Exception as exc:
log.exception("Error decoding response from Cassandra. "
"opcode: %04x; message contents: %r", header.opcode, body)

View File

@@ -183,7 +183,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
log.debug("Closed socket to %s", self.host)
if not self.is_defunct:
self.error_all_callbacks(
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()
@@ -239,7 +239,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
if self._iobuf.tell():
self.process_io_buffer()
if not self._callbacks and not self.is_control_connection:
if not self._requests and not self.is_control_connection:
self._readable = False
def push(self, data):

View File

@@ -116,7 +116,7 @@ class EventletConnection(Connection):
log.debug("Closed socket to %s" % (self.host,))
if not self.is_defunct:
self.error_all_callbacks(
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()

View File

@@ -106,7 +106,7 @@ class GeventConnection(Connection):
log.debug("Closed socket to %s" % (self.host,))
if not self.is_defunct:
self.error_all_callbacks(
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()

View File

@@ -290,7 +290,7 @@ class LibevConnection(Connection):
# don't leave in-progress operations hanging
if not self.is_defunct:
self.error_all_callbacks(
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host))
def handle_write(self, watcher, revents, errno=None):

View File

@@ -96,7 +96,7 @@ class TwistedConnectionClientFactory(protocol.ClientFactory):
It should be safe to call defunct() here instead of just close, because
we can assume that if the connection was closed cleanly, there are no
callbacks to error out. If this assumption turns out to be false, we
requests to error out. If this assumption turns out to be false, we
can call close() instead of defunct() when "reason" is an appropriate
type.
"""
@@ -213,7 +213,7 @@ class TwistedConnection(Connection):
def close(self):
"""
Disconnect and error-out all callbacks.
Disconnect and error-out all requests.
"""
with self.lock:
if self.is_closed:
@@ -225,7 +225,7 @@ class TwistedConnection(Connection):
log.debug("Closed socket to %s", self.host)
if not self.is_defunct:
self.error_all_callbacks(
self.error_all_requests(
ConnectionShutdown("Connection to %s was closed" % self.host))
# don't leave in-progress operations hanging
self.connected_event.set()

View File

@@ -140,13 +140,13 @@ class TestTwistedConnection(unittest.TestCase):
"""
Verify that close() disconnects the connector and errors callbacks.
"""
self.obj_ut.error_all_callbacks = Mock()
self.obj_ut.error_all_requests = Mock()
self.obj_ut.add_connection()
self.obj_ut.is_closed = False
self.obj_ut.close()
self.obj_ut.connector.disconnect.assert_called_with()
self.assertTrue(self.obj_ut.connected_event.is_set())
self.assertTrue(self.obj_ut.error_all_callbacks.called)
self.assertTrue(self.obj_ut.error_all_requests.called)
def test_handle_read__incomplete(self):
"""

View File

@@ -27,7 +27,7 @@ from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, Protoc
locally_supported_compressions, ConnectionHeartbeat, _Frame)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage)
SupportedMessage, ProtocolHandler)
class ConnectionTest(unittest.TestCase):
@@ -75,7 +75,7 @@ class ConnectionTest(unittest.TestCase):
def test_bad_protocol_version(self, *args):
c = self.make_connection()
c._callbacks = Mock()
c._requests = Mock()
c.defunct = Mock()
# read in a SupportedMessage response
@@ -93,7 +93,7 @@ class ConnectionTest(unittest.TestCase):
def test_negative_body_length(self, *args):
c = self.make_connection()
c._callbacks = Mock()
c._requests = Mock()
c.defunct = Mock()
# read in a SupportedMessage response
@@ -110,7 +110,7 @@ class ConnectionTest(unittest.TestCase):
def test_unsupported_cql_version(self, *args):
c = self.make_connection()
c._callbacks = {0: c._handle_options_response}
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock()
c.cql_version = "3.0.3"
@@ -133,7 +133,7 @@ class ConnectionTest(unittest.TestCase):
def test_prefer_lz4_compression(self, *args):
c = self.make_connection()
c._callbacks = {0: c._handle_options_response}
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock()
c.cql_version = "3.0.3"
@@ -156,7 +156,7 @@ class ConnectionTest(unittest.TestCase):
def test_requested_compression_not_available(self, *args):
c = self.make_connection()
c._callbacks = {0: c._handle_options_response}
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock()
# request lz4 compression
c.compression = "lz4"
@@ -186,7 +186,7 @@ class ConnectionTest(unittest.TestCase):
def test_use_requested_compression(self, *args):
c = self.make_connection()
c._callbacks = {0: c._handle_options_response}
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock()
# request snappy compression
c.compression = "snappy"
@@ -213,7 +213,7 @@ class ConnectionTest(unittest.TestCase):
def test_disable_compression(self, *args):
c = self.make_connection()
c._callbacks = {0: c._handle_options_response}
c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)}
c.defunct = Mock()
# disable compression
c.compression = False

View File

@@ -27,7 +27,7 @@ from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessag
OverloadedErrorMessage, IsBootstrappingErrorMessage,
PreparedQueryNotFound, PrepareMessage,
RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE,
RESULT_KIND_SCHEMA_CHANGE)
RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler)
from cassandra.policies import RetryPolicy
from cassandra.pool import NoConnectionsAvailable
from cassandra.query import SimpleStatement
@@ -67,7 +67,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
rf._set_result(self.make_mock_response([{'col': 'val'}]))
result = rf.result()
@@ -189,7 +189,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
result = Mock(spec=UnavailableErrorMessage, info={})
rf._set_result(result)
@@ -207,7 +207,7 @@ class ResponseFutureTests(unittest.TestCase):
# an UnavailableException
rf.session._pools.get.assert_called_with('ip1')
pool.borrow_connection.assert_called_with(timeout=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
def test_retry_with_different_host(self):
session = self.make_session()
@@ -222,7 +222,7 @@ class ResponseFutureTests(unittest.TestCase):
rf.session._pools.get.assert_called_once_with('ip1')
pool.borrow_connection.assert_called_once_with(timeout=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY)
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
result = Mock(spec=OverloadedErrorMessage, info={})
@@ -240,7 +240,7 @@ class ResponseFutureTests(unittest.TestCase):
# it should try with a different host
rf.session._pools.get.assert_called_with('ip2')
pool.borrow_connection.assert_called_with(timeout=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY)
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message)
# the consistency level should be the same
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)