Add protocol_hander_class to allow extension, serdes specialization
This commit is contained in:
@@ -60,7 +60,7 @@ from cassandra.protocol import (QueryMessage, ResultMessage,
|
||||
IsBootstrappingErrorMessage,
|
||||
BatchMessage, RESULT_KIND_PREPARED,
|
||||
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
|
||||
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION)
|
||||
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION, ProtocolHandler)
|
||||
from cassandra.metadata import Metadata, protect_name, murmur3
|
||||
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
|
||||
ExponentialReconnectionPolicy, HostDistance,
|
||||
@@ -419,6 +419,15 @@ class Cluster(object):
|
||||
GeventConnection will be used automatically.
|
||||
"""
|
||||
|
||||
protocol_handler_class = ProtocolHandler
|
||||
"""
|
||||
Specifies a protocol handler class, which can be used to override or extend features
|
||||
such as message or type deserialization.
|
||||
|
||||
The class must conform to the public classmethod interface defined in the default
|
||||
implementation, :class:`cassandra.protocol.ProtocolHandler`
|
||||
"""
|
||||
|
||||
control_connection_timeout = 2.0
|
||||
"""
|
||||
A timeout, in seconds, for queries made by the control connection, such
|
||||
@@ -515,7 +524,8 @@ class Cluster(object):
|
||||
idle_heartbeat_interval=30,
|
||||
schema_event_refresh_window=2,
|
||||
topology_event_refresh_window=10,
|
||||
connect_timeout=5):
|
||||
connect_timeout=5,
|
||||
protocol_handler_class=None):
|
||||
"""
|
||||
Any of the mutable Cluster attributes may be set as keyword arguments
|
||||
to the constructor.
|
||||
@@ -559,6 +569,9 @@ class Cluster(object):
|
||||
if connection_class is not None:
|
||||
self.connection_class = connection_class
|
||||
|
||||
if protocol_handler_class is not None:
|
||||
self.protocol_handler_class = protocol_handler_class
|
||||
|
||||
self.metrics_enabled = metrics_enabled
|
||||
self.ssl_options = ssl_options
|
||||
self.sockopts = sockopts
|
||||
@@ -798,6 +811,7 @@ class Cluster(object):
|
||||
kwargs_dict['cql_version'] = self.cql_version
|
||||
kwargs_dict['protocol_version'] = self.protocol_version
|
||||
kwargs_dict['user_type_map'] = self._user_types
|
||||
kwargs_dict['protocol_handler_class'] = self.protocol_handler_class
|
||||
|
||||
return kwargs_dict
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
||||
from cassandra.marshal import int32_pack, uint8_unpack
|
||||
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
||||
QueryMessage, ResultMessage, decode_response,
|
||||
QueryMessage, ResultMessage, ProtocolHandler,
|
||||
InvalidRequestException, SupportedMessage,
|
||||
AuthResponseMessage, AuthChallengeMessage,
|
||||
AuthSuccessMessage, ProtocolException,
|
||||
@@ -209,7 +209,7 @@ class Connection(object):
|
||||
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
|
||||
ssl_options=None, sockopts=None, compression=True,
|
||||
cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
|
||||
user_type_map=None):
|
||||
user_type_map=None, protocol_handler_class=ProtocolHandler):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.authenticator = authenticator
|
||||
@@ -220,6 +220,8 @@ class Connection(object):
|
||||
self.protocol_version = protocol_version
|
||||
self.is_control_connection = is_control_connection
|
||||
self.user_type_map = user_type_map
|
||||
self.decoder = protocol_handler_class.decode_message
|
||||
self.encoder = protocol_handler_class.encode_message
|
||||
self._push_watchers = defaultdict(set)
|
||||
self._callbacks = {}
|
||||
self._iobuf = io.BytesIO()
|
||||
@@ -362,7 +364,7 @@ class Connection(object):
|
||||
raise ConnectionShutdown("Connection to %s is closed" % self.host)
|
||||
|
||||
self._callbacks[request_id] = cb
|
||||
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
|
||||
self.push(self.encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
|
||||
return request_id
|
||||
|
||||
def wait_for_response(self, msg, timeout=None):
|
||||
@@ -498,8 +500,8 @@ class Connection(object):
|
||||
self.msg_received = True
|
||||
|
||||
try:
|
||||
response = decode_response(header.version, self.user_type_map, stream_id,
|
||||
header.flags, header.opcode, body, self.decompressor)
|
||||
response = self.decoder(header.version, self.user_type_map, stream_id,
|
||||
header.flags, header.opcode, body, self.decompressor)
|
||||
except Exception as exc:
|
||||
log.exception("Error decoding response from Cassandra. "
|
||||
"opcode: %04x; message contents: %r", header.opcode, body)
|
||||
|
||||
@@ -83,29 +83,6 @@ class _MessageType(object):
|
||||
custom_payload = None
|
||||
warnings = None
|
||||
|
||||
def to_binary(self, stream_id, protocol_version, compression=None):
|
||||
flags = 0
|
||||
body = io.BytesIO()
|
||||
if self.custom_payload:
|
||||
if protocol_version < 4:
|
||||
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
|
||||
flags |= CUSTOM_PAYLOAD_FLAG
|
||||
write_bytesmap(body, self.custom_payload)
|
||||
self.send_body(body, protocol_version)
|
||||
body = body.getvalue()
|
||||
|
||||
if compression and len(body) > 0:
|
||||
body = compression(body)
|
||||
flags |= COMPRESSED_FLAG
|
||||
if self.tracing:
|
||||
flags |= TRACING_FLAG
|
||||
|
||||
msg = io.BytesIO()
|
||||
write_header(msg, protocol_version, flags, stream_id, self.opcode, len(body))
|
||||
msg.write(body)
|
||||
|
||||
return msg.getvalue()
|
||||
|
||||
def update_custom_payload(self, other):
|
||||
if other:
|
||||
if not self.custom_payload:
|
||||
@@ -126,50 +103,6 @@ def _get_params(message_obj):
|
||||
)
|
||||
|
||||
|
||||
def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, body,
|
||||
decompressor=None):
|
||||
if flags & COMPRESSED_FLAG:
|
||||
if decompressor is None:
|
||||
raise Exception("No de-compressor available for compressed frame!")
|
||||
body = decompressor(body)
|
||||
flags ^= COMPRESSED_FLAG
|
||||
|
||||
body = io.BytesIO(body)
|
||||
if flags & TRACING_FLAG:
|
||||
trace_id = UUID(bytes=body.read(16))
|
||||
flags ^= TRACING_FLAG
|
||||
else:
|
||||
trace_id = None
|
||||
|
||||
if flags & WARNING_FLAG:
|
||||
warnings = read_stringlist(body)
|
||||
flags ^= WARNING_FLAG
|
||||
else:
|
||||
warnings = None
|
||||
|
||||
if flags & CUSTOM_PAYLOAD_FLAG:
|
||||
custom_payload = read_bytesmap(body)
|
||||
flags ^= CUSTOM_PAYLOAD_FLAG
|
||||
else:
|
||||
custom_payload = None
|
||||
|
||||
if flags:
|
||||
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||
|
||||
msg_class = _message_types_by_opcode[opcode]
|
||||
msg = msg_class.recv_body(body, protocol_version, user_type_map)
|
||||
msg.stream_id = stream_id
|
||||
msg.trace_id = trace_id
|
||||
msg.custom_payload = custom_payload
|
||||
msg.warnings = warnings
|
||||
|
||||
if msg.warnings:
|
||||
for w in msg.warnings:
|
||||
log.warning("Server warning: %s", w)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
error_classes = {}
|
||||
|
||||
|
||||
@@ -609,7 +542,7 @@ class ResultMessage(_MessageType):
|
||||
results = None
|
||||
paging_state = None
|
||||
|
||||
_type_codes = {
|
||||
type_codes = {
|
||||
0x0000: CUSTOM_TYPE,
|
||||
0x0001: AsciiType,
|
||||
0x0002: LongType,
|
||||
@@ -744,7 +677,7 @@ class ResultMessage(_MessageType):
|
||||
def read_type(cls, f, user_type_map):
|
||||
optid = read_short(f)
|
||||
try:
|
||||
typeclass = cls._type_codes[optid]
|
||||
typeclass = cls.type_codes[optid]
|
||||
except KeyError:
|
||||
raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
|
||||
" entire result set." % (optid,))
|
||||
@@ -964,13 +897,122 @@ class EventMessage(_MessageType):
|
||||
return event
|
||||
|
||||
|
||||
def write_header(f, version, flags, stream_id, opcode, length):
|
||||
class ProtocolHandler(object):
|
||||
"""
|
||||
Write a CQL protocol frame header.
|
||||
ProtocolHander handles encoding and decoding messages.
|
||||
|
||||
This class can be specialized to compose Handlers which implement alternative
|
||||
result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
|
||||
on initialization.
|
||||
|
||||
Contracted class methods are :meth:`ProtocolHandler.encode_message` and :meth:`ProtocolHandler.decode_message`.
|
||||
"""
|
||||
pack = v3_header_pack if version >= 3 else header_pack
|
||||
f.write(pack(version, flags, stream_id, opcode))
|
||||
write_int(f, length)
|
||||
|
||||
message_types_by_opcode = _message_types_by_opcode.copy()
|
||||
"""
|
||||
Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
|
||||
this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized
|
||||
result decoding implementations.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def encode_message(cls, msg, stream_id, protocol_version, compressor):
|
||||
"""
|
||||
Encodes a message using the specified frame parameters, and compressor
|
||||
|
||||
:param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
|
||||
:param stream_id: protocol stream id for the frame header
|
||||
:param protocol_version: version for the frame header, and used encoding contents
|
||||
:param compressor: optional compression function to be used on the body
|
||||
:return:
|
||||
"""
|
||||
flags = 0
|
||||
body = io.BytesIO()
|
||||
if msg.custom_payload:
|
||||
if protocol_version < 4:
|
||||
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
|
||||
flags |= CUSTOM_PAYLOAD_FLAG
|
||||
write_bytesmap(body, msg.custom_payload)
|
||||
msg.send_body(body, protocol_version)
|
||||
body = body.getvalue()
|
||||
|
||||
if compressor and len(body) > 0:
|
||||
body = compressor(body)
|
||||
flags |= COMPRESSED_FLAG
|
||||
|
||||
if msg.tracing:
|
||||
flags |= TRACING_FLAG
|
||||
|
||||
buff = io.BytesIO()
|
||||
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
|
||||
buff.write(body)
|
||||
|
||||
return buff.getvalue()
|
||||
|
||||
@staticmethod
|
||||
def _write_header(f, version, flags, stream_id, opcode, length):
|
||||
"""
|
||||
Write a CQL protocol frame header.
|
||||
"""
|
||||
pack = v3_header_pack if version >= 3 else header_pack
|
||||
f.write(pack(version, flags, stream_id, opcode))
|
||||
write_int(f, length)
|
||||
|
||||
@classmethod
|
||||
def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body,
|
||||
decompressor):
|
||||
"""
|
||||
Decodes a native protocol message body
|
||||
|
||||
:param protocol_version: version to use decoding contents
|
||||
:param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
|
||||
:param stream_id: native protocol stream id from the frame header
|
||||
:param flags: native protocol flags bitmap from the header
|
||||
:param opcode: native protocol opcode from the header
|
||||
:param body: frame body
|
||||
:param decompressor: optional decompression function to inflate the body
|
||||
:return: a message decoded from the body and frame attributes
|
||||
"""
|
||||
if flags & COMPRESSED_FLAG:
|
||||
if decompressor is None:
|
||||
raise Exception("No de-compressor available for compressed frame!")
|
||||
body = decompressor(body)
|
||||
flags ^= COMPRESSED_FLAG
|
||||
|
||||
body = io.BytesIO(body)
|
||||
if flags & TRACING_FLAG:
|
||||
trace_id = UUID(bytes=body.read(16))
|
||||
flags ^= TRACING_FLAG
|
||||
else:
|
||||
trace_id = None
|
||||
|
||||
if flags & WARNING_FLAG:
|
||||
warnings = read_stringlist(body)
|
||||
flags ^= WARNING_FLAG
|
||||
else:
|
||||
warnings = None
|
||||
|
||||
if flags & CUSTOM_PAYLOAD_FLAG:
|
||||
custom_payload = read_bytesmap(body)
|
||||
flags ^= CUSTOM_PAYLOAD_FLAG
|
||||
else:
|
||||
custom_payload = None
|
||||
|
||||
if flags:
|
||||
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||
|
||||
msg_class = cls.message_types_by_opcode[opcode]
|
||||
msg = msg_class.recv_body(body, protocol_version, user_type_map)
|
||||
msg.stream_id = stream_id
|
||||
msg.trace_id = trace_id
|
||||
msg.custom_payload = custom_payload
|
||||
msg.warnings = warnings
|
||||
|
||||
if msg.warnings:
|
||||
for w in msg.warnings:
|
||||
log.warning("Server warning: %s", w)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def read_byte(f):
|
||||
|
||||
@@ -27,6 +27,8 @@
|
||||
|
||||
.. autoattribute:: connection_class
|
||||
|
||||
.. autoattribute:: protocol_handler_class
|
||||
|
||||
.. autoattribute:: metrics_enabled
|
||||
|
||||
.. autoattribute:: metrics
|
||||
|
||||
@@ -15,3 +15,12 @@ By default these are ignored by the server. They can be useful for servers imple
|
||||
a custom QueryHandler.
|
||||
|
||||
See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`.
|
||||
|
||||
.. autoclass:: ProtocolHandler
|
||||
|
||||
.. autoattribute:: message_types_by_opcode
|
||||
:annotation: = {default mapping}
|
||||
|
||||
.. automethod:: encode_message
|
||||
|
||||
.. automethod:: decode_message
|
||||
|
||||
@@ -80,7 +80,7 @@ def _tuple_version(version_string):
|
||||
|
||||
USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False))
|
||||
|
||||
default_cassandra_version = '2.1.5'
|
||||
default_cassandra_version = '2.1.6'
|
||||
|
||||
if USE_CASS_EXTERNAL:
|
||||
if CCMClusterFactory:
|
||||
|
||||
Reference in New Issue
Block a user