Add protocol_hander_class to allow extension, serdes specialization

This commit is contained in:
Adam Holmberg
2015-06-22 17:11:43 -05:00
parent a327ab13a5
commit 22b92df9f9
6 changed files with 151 additions and 82 deletions

View File

@@ -60,7 +60,7 @@ from cassandra.protocol import (QueryMessage, ResultMessage,
IsBootstrappingErrorMessage,
BatchMessage, RESULT_KIND_PREPARED,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION)
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION, ProtocolHandler)
from cassandra.metadata import Metadata, protect_name, murmur3
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance,
@@ -419,6 +419,15 @@ class Cluster(object):
GeventConnection will be used automatically.
"""
protocol_handler_class = ProtocolHandler
"""
Specifies a protocol handler class, which can be used to override or extend features
such as message or type deserialization.
The class must conform to the public classmethod interface defined in the default
implementation, :class:`cassandra.protocol.ProtocolHandler`
"""
control_connection_timeout = 2.0
"""
A timeout, in seconds, for queries made by the control connection, such
@@ -515,7 +524,8 @@ class Cluster(object):
idle_heartbeat_interval=30,
schema_event_refresh_window=2,
topology_event_refresh_window=10,
connect_timeout=5):
connect_timeout=5,
protocol_handler_class=None):
"""
Any of the mutable Cluster attributes may be set as keyword arguments
to the constructor.
@@ -559,6 +569,9 @@ class Cluster(object):
if connection_class is not None:
self.connection_class = connection_class
if protocol_handler_class is not None:
self.protocol_handler_class = protocol_handler_class
self.metrics_enabled = metrics_enabled
self.ssl_options = ssl_options
self.sockopts = sockopts
@@ -798,6 +811,7 @@ class Cluster(object):
kwargs_dict['cql_version'] = self.cql_version
kwargs_dict['protocol_version'] = self.protocol_version
kwargs_dict['user_type_map'] = self._user_types
kwargs_dict['protocol_handler_class'] = self.protocol_handler_class
return kwargs_dict

View File

@@ -42,7 +42,7 @@ from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int32_pack, uint8_unpack
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, decode_response,
QueryMessage, ResultMessage, ProtocolHandler,
InvalidRequestException, SupportedMessage,
AuthResponseMessage, AuthChallengeMessage,
AuthSuccessMessage, ProtocolException,
@@ -209,7 +209,7 @@ class Connection(object):
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
user_type_map=None):
user_type_map=None, protocol_handler_class=ProtocolHandler):
self.host = host
self.port = port
self.authenticator = authenticator
@@ -220,6 +220,8 @@ class Connection(object):
self.protocol_version = protocol_version
self.is_control_connection = is_control_connection
self.user_type_map = user_type_map
self.decoder = protocol_handler_class.decode_message
self.encoder = protocol_handler_class.encode_message
self._push_watchers = defaultdict(set)
self._callbacks = {}
self._iobuf = io.BytesIO()
@@ -362,7 +364,7 @@ class Connection(object):
raise ConnectionShutdown("Connection to %s is closed" % self.host)
self._callbacks[request_id] = cb
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
self.push(self.encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
return request_id
def wait_for_response(self, msg, timeout=None):
@@ -498,8 +500,8 @@ class Connection(object):
self.msg_received = True
try:
response = decode_response(header.version, self.user_type_map, stream_id,
header.flags, header.opcode, body, self.decompressor)
response = self.decoder(header.version, self.user_type_map, stream_id,
header.flags, header.opcode, body, self.decompressor)
except Exception as exc:
log.exception("Error decoding response from Cassandra. "
"opcode: %04x; message contents: %r", header.opcode, body)

View File

@@ -83,29 +83,6 @@ class _MessageType(object):
custom_payload = None
warnings = None
def to_binary(self, stream_id, protocol_version, compression=None):
flags = 0
body = io.BytesIO()
if self.custom_payload:
if protocol_version < 4:
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
flags |= CUSTOM_PAYLOAD_FLAG
write_bytesmap(body, self.custom_payload)
self.send_body(body, protocol_version)
body = body.getvalue()
if compression and len(body) > 0:
body = compression(body)
flags |= COMPRESSED_FLAG
if self.tracing:
flags |= TRACING_FLAG
msg = io.BytesIO()
write_header(msg, protocol_version, flags, stream_id, self.opcode, len(body))
msg.write(body)
return msg.getvalue()
def update_custom_payload(self, other):
if other:
if not self.custom_payload:
@@ -126,50 +103,6 @@ def _get_params(message_obj):
)
def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, body,
decompressor=None):
if flags & COMPRESSED_FLAG:
if decompressor is None:
raise Exception("No de-compressor available for compressed frame!")
body = decompressor(body)
flags ^= COMPRESSED_FLAG
body = io.BytesIO(body)
if flags & TRACING_FLAG:
trace_id = UUID(bytes=body.read(16))
flags ^= TRACING_FLAG
else:
trace_id = None
if flags & WARNING_FLAG:
warnings = read_stringlist(body)
flags ^= WARNING_FLAG
else:
warnings = None
if flags & CUSTOM_PAYLOAD_FLAG:
custom_payload = read_bytesmap(body)
flags ^= CUSTOM_PAYLOAD_FLAG
else:
custom_payload = None
if flags:
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
msg_class = _message_types_by_opcode[opcode]
msg = msg_class.recv_body(body, protocol_version, user_type_map)
msg.stream_id = stream_id
msg.trace_id = trace_id
msg.custom_payload = custom_payload
msg.warnings = warnings
if msg.warnings:
for w in msg.warnings:
log.warning("Server warning: %s", w)
return msg
error_classes = {}
@@ -609,7 +542,7 @@ class ResultMessage(_MessageType):
results = None
paging_state = None
_type_codes = {
type_codes = {
0x0000: CUSTOM_TYPE,
0x0001: AsciiType,
0x0002: LongType,
@@ -744,7 +677,7 @@ class ResultMessage(_MessageType):
def read_type(cls, f, user_type_map):
optid = read_short(f)
try:
typeclass = cls._type_codes[optid]
typeclass = cls.type_codes[optid]
except KeyError:
raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
" entire result set." % (optid,))
@@ -964,13 +897,122 @@ class EventMessage(_MessageType):
return event
def write_header(f, version, flags, stream_id, opcode, length):
class ProtocolHandler(object):
"""
Write a CQL protocol frame header.
ProtocolHander handles encoding and decoding messages.
This class can be specialized to compose Handlers which implement alternative
result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
on initialization.
Contracted class methods are :meth:`ProtocolHandler.encode_message` and :meth:`ProtocolHandler.decode_message`.
"""
pack = v3_header_pack if version >= 3 else header_pack
f.write(pack(version, flags, stream_id, opcode))
write_int(f, length)
message_types_by_opcode = _message_types_by_opcode.copy()
"""
Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized
result decoding implementations.
"""
@classmethod
def encode_message(cls, msg, stream_id, protocol_version, compressor):
"""
Encodes a message using the specified frame parameters, and compressor
:param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
:param stream_id: protocol stream id for the frame header
:param protocol_version: version for the frame header, and used encoding contents
:param compressor: optional compression function to be used on the body
:return:
"""
flags = 0
body = io.BytesIO()
if msg.custom_payload:
if protocol_version < 4:
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
flags |= CUSTOM_PAYLOAD_FLAG
write_bytesmap(body, msg.custom_payload)
msg.send_body(body, protocol_version)
body = body.getvalue()
if compressor and len(body) > 0:
body = compressor(body)
flags |= COMPRESSED_FLAG
if msg.tracing:
flags |= TRACING_FLAG
buff = io.BytesIO()
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
buff.write(body)
return buff.getvalue()
@staticmethod
def _write_header(f, version, flags, stream_id, opcode, length):
"""
Write a CQL protocol frame header.
"""
pack = v3_header_pack if version >= 3 else header_pack
f.write(pack(version, flags, stream_id, opcode))
write_int(f, length)
@classmethod
def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body,
decompressor):
"""
Decodes a native protocol message body
:param protocol_version: version to use decoding contents
:param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
:param stream_id: native protocol stream id from the frame header
:param flags: native protocol flags bitmap from the header
:param opcode: native protocol opcode from the header
:param body: frame body
:param decompressor: optional decompression function to inflate the body
:return: a message decoded from the body and frame attributes
"""
if flags & COMPRESSED_FLAG:
if decompressor is None:
raise Exception("No de-compressor available for compressed frame!")
body = decompressor(body)
flags ^= COMPRESSED_FLAG
body = io.BytesIO(body)
if flags & TRACING_FLAG:
trace_id = UUID(bytes=body.read(16))
flags ^= TRACING_FLAG
else:
trace_id = None
if flags & WARNING_FLAG:
warnings = read_stringlist(body)
flags ^= WARNING_FLAG
else:
warnings = None
if flags & CUSTOM_PAYLOAD_FLAG:
custom_payload = read_bytesmap(body)
flags ^= CUSTOM_PAYLOAD_FLAG
else:
custom_payload = None
if flags:
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
msg_class = cls.message_types_by_opcode[opcode]
msg = msg_class.recv_body(body, protocol_version, user_type_map)
msg.stream_id = stream_id
msg.trace_id = trace_id
msg.custom_payload = custom_payload
msg.warnings = warnings
if msg.warnings:
for w in msg.warnings:
log.warning("Server warning: %s", w)
return msg
def read_byte(f):

View File

@@ -27,6 +27,8 @@
.. autoattribute:: connection_class
.. autoattribute:: protocol_handler_class
.. autoattribute:: metrics_enabled
.. autoattribute:: metrics

View File

@@ -15,3 +15,12 @@ By default these are ignored by the server. They can be useful for servers imple
a custom QueryHandler.
See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`.
.. autoclass:: ProtocolHandler
.. autoattribute:: message_types_by_opcode
:annotation: = {default mapping}
.. automethod:: encode_message
.. automethod:: decode_message

View File

@@ -80,7 +80,7 @@ def _tuple_version(version_string):
USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False))
default_cassandra_version = '2.1.5'
default_cassandra_version = '2.1.6'
if USE_CASS_EXTERNAL:
if CCMClusterFactory: