Merge pull request #276 from datastax/PYTHON-280
PYTHON-280 - Custom payload for protocol v4
This commit is contained in:
@@ -78,7 +78,6 @@ Bug Fixes
|
||||
---------
|
||||
* Make execute_concurrent compatible with Python 2.6 (PYTHON-159)
|
||||
* Handle Unauthorized message on schema_triggers query (PYTHON-155)
|
||||
* Make execute_concurrent compatible with Python 2.6 (github-197)
|
||||
* Pure Python sorted set in support of UDTs nested in collections (PYTON-167)
|
||||
* Support CUSTOM index metadata and string export (PYTHON-165)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ try:
|
||||
except ImportError:
|
||||
SASLClient = None
|
||||
|
||||
|
||||
class AuthProvider(object):
|
||||
"""
|
||||
An abstract class that defines the interface that will be used for
|
||||
@@ -157,6 +158,7 @@ class SaslAuthProvider(AuthProvider):
|
||||
def new_authenticator(self, host):
|
||||
return SaslAuthenticator(**self.sasl_kwargs)
|
||||
|
||||
|
||||
class SaslAuthenticator(Authenticator):
|
||||
"""
|
||||
A pass-through :class:`~.Authenticator` using the third party package
|
||||
|
||||
@@ -1383,7 +1383,7 @@ class Session(object):
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False):
|
||||
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None):
|
||||
"""
|
||||
Execute the given query and synchronously wait for the response.
|
||||
|
||||
@@ -1411,6 +1411,10 @@ class Session(object):
|
||||
instance and not just a string. If there is an error fetching the
|
||||
trace details, the :attr:`~.Statement.trace` attribute will be left as
|
||||
:const:`None`.
|
||||
|
||||
`custom_payload` is a :ref:`custom_payload` dict to be passed to the server.
|
||||
If `query` is a Statement with its own custom_payload. The message payload
|
||||
will be a union of the two, with the values specified here taking precedence.
|
||||
"""
|
||||
if timeout is _NOT_SET:
|
||||
timeout = self.default_timeout
|
||||
@@ -1420,7 +1424,7 @@ class Session(object):
|
||||
"The query argument must be an instance of a subclass of "
|
||||
"cassandra.query.Statement when trace=True")
|
||||
|
||||
future = self.execute_async(query, parameters, trace)
|
||||
future = self.execute_async(query, parameters, trace, custom_payload)
|
||||
try:
|
||||
result = future.result(timeout)
|
||||
finally:
|
||||
@@ -1432,7 +1436,7 @@ class Session(object):
|
||||
|
||||
return result
|
||||
|
||||
def execute_async(self, query, parameters=None, trace=False):
|
||||
def execute_async(self, query, parameters=None, trace=False, custom_payload=None):
|
||||
"""
|
||||
Execute the given query and return a :class:`~.ResponseFuture` object
|
||||
which callbacks may be attached to for asynchronous response
|
||||
@@ -1444,6 +1448,14 @@ class Session(object):
|
||||
:meth:`.ResponseFuture.get_query_trace()` after the request
|
||||
completes to retrieve a :class:`.QueryTrace` instance.
|
||||
|
||||
`custom_payload` is a :ref:`custom_payload` dict to be passed to the server.
|
||||
If `query` is a Statement with its own custom_payload. The message payload
|
||||
will be a union of the two, with the values specified here taking precedence.
|
||||
|
||||
If the server sends a custom payload in the response message,
|
||||
the dict can be obtained following :meth:`.ResponseFuture.result` via
|
||||
:attr:`.ResponseFuture.custom_payload`
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> session = cluster.connect()
|
||||
@@ -1469,11 +1481,11 @@ class Session(object):
|
||||
... log.exception("Operation failed:")
|
||||
|
||||
"""
|
||||
future = self._create_response_future(query, parameters, trace)
|
||||
future = self._create_response_future(query, parameters, trace, custom_payload)
|
||||
future.send_request()
|
||||
return future
|
||||
|
||||
def _create_response_future(self, query, parameters, trace):
|
||||
def _create_response_future(self, query, parameters, trace, custom_payload):
|
||||
""" Returns the ResponseFuture before calling send_request() on it """
|
||||
|
||||
prepared_statement = None
|
||||
@@ -1523,13 +1535,16 @@ class Session(object):
|
||||
if trace:
|
||||
message.tracing = True
|
||||
|
||||
message.update_custom_payload(query.custom_payload)
|
||||
message.update_custom_payload(custom_payload)
|
||||
|
||||
return ResponseFuture(
|
||||
self, message, query, self.default_timeout, metrics=self._metrics,
|
||||
prepared_statement=prepared_statement)
|
||||
|
||||
def prepare(self, query):
|
||||
def prepare(self, query, custom_payload=None):
|
||||
"""
|
||||
Prepares a query string, returing a :class:`~cassandra.query.PreparedStatement`
|
||||
Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement`
|
||||
instance which can be used as follows::
|
||||
|
||||
>>> session = cluster.connect("mykeyspace")
|
||||
@@ -1552,8 +1567,12 @@ class Session(object):
|
||||
|
||||
**Important**: PreparedStatements should be prepared only once.
|
||||
Preparing the same query more than once will likely affect performance.
|
||||
|
||||
`custom_payload` is a key value map to be passed along with the prepare
|
||||
message. See :ref:`custom_payload`.
|
||||
"""
|
||||
message = PrepareMessage(query=query)
|
||||
message.custom_payload = custom_payload
|
||||
future = ResponseFuture(self, message, query=None)
|
||||
try:
|
||||
future.send_request()
|
||||
@@ -1565,6 +1584,7 @@ class Session(object):
|
||||
prepared_statement = PreparedStatement.from_message(
|
||||
query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace,
|
||||
self._protocol_version)
|
||||
prepared_statement.custom_payload = future.custom_payload
|
||||
|
||||
host = future._current_host
|
||||
try:
|
||||
@@ -2636,6 +2656,7 @@ class ResponseFuture(object):
|
||||
_start_time = None
|
||||
_metrics = None
|
||||
_paging_state = None
|
||||
_custom_payload = None
|
||||
|
||||
def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None):
|
||||
self.session = session
|
||||
@@ -2723,6 +2744,23 @@ class ResponseFuture(object):
|
||||
"""
|
||||
return self._paging_state is not None
|
||||
|
||||
@property
|
||||
def custom_payload(self):
|
||||
"""
|
||||
The custom payload returned from the server, if any. This will only be
|
||||
set by Cassandra servers implementing a custom QueryHandler, and only
|
||||
for protocol_version 4+.
|
||||
|
||||
Ensure the future is complete before trying to access this property
|
||||
(call :meth:`.result()`, or after callback is invoked).
|
||||
Otherwise it may throw if the response has not been received.
|
||||
|
||||
:return: :ref:`custom_payload`.
|
||||
"""
|
||||
if not self._event.is_set():
|
||||
raise Exception("custom_payload cannot be retrieved before ResponseFuture is finalized")
|
||||
return self._custom_payload
|
||||
|
||||
def start_fetching_next_page(self):
|
||||
"""
|
||||
If there are more pages left in the query result, this asynchronously
|
||||
@@ -2759,6 +2797,8 @@ class ResponseFuture(object):
|
||||
if trace_id:
|
||||
self._query_trace = QueryTrace(trace_id, self.session)
|
||||
|
||||
self._custom_payload = getattr(response, 'custom_payload', None)
|
||||
|
||||
if isinstance(response, ResultMessage):
|
||||
if response.kind == RESULT_KIND_SET_KEYSPACE:
|
||||
session = getattr(self, 'session', None)
|
||||
|
||||
@@ -56,6 +56,7 @@ HEADER_DIRECTION_MASK = 0x80
|
||||
|
||||
COMPRESSED_FLAG = 0x01
|
||||
TRACING_FLAG = 0x02
|
||||
CUSTOM_PAYLOAD_FLAG = 0x04
|
||||
|
||||
_message_types_by_name = {}
|
||||
_message_types_by_opcode = {}
|
||||
@@ -72,13 +73,19 @@ class _RegisterMessageType(type):
|
||||
class _MessageType(object):
|
||||
|
||||
tracing = False
|
||||
custom_payload = None
|
||||
|
||||
def to_binary(self, stream_id, protocol_version, compression=None):
|
||||
flags = 0
|
||||
body = io.BytesIO()
|
||||
if self.custom_payload:
|
||||
if protocol_version < 4:
|
||||
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
|
||||
flags |= CUSTOM_PAYLOAD_FLAG
|
||||
write_bytesmap(body, self.custom_payload)
|
||||
self.send_body(body, protocol_version)
|
||||
body = body.getvalue()
|
||||
|
||||
flags = 0
|
||||
if compression and len(body) > 0:
|
||||
body = compression(body)
|
||||
flags |= COMPRESSED_FLAG
|
||||
@@ -91,6 +98,12 @@ class _MessageType(object):
|
||||
|
||||
return msg.getvalue()
|
||||
|
||||
def update_custom_payload(self, other):
|
||||
if other:
|
||||
if not self.custom_payload:
|
||||
self.custom_payload = {}
|
||||
self.custom_payload.update(other)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
|
||||
|
||||
@@ -118,6 +131,12 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
|
||||
else:
|
||||
trace_id = None
|
||||
|
||||
if flags & CUSTOM_PAYLOAD_FLAG:
|
||||
custom_payload = read_bytesmap(body)
|
||||
flags ^= CUSTOM_PAYLOAD_FLAG
|
||||
else:
|
||||
custom_payload = None
|
||||
|
||||
if flags:
|
||||
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||
|
||||
@@ -125,6 +144,7 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
|
||||
msg = msg_class.recv_body(body, protocol_version, user_type_map)
|
||||
msg.stream_id = stream_id
|
||||
msg.trace_id = trace_id
|
||||
msg.custom_payload = custom_payload
|
||||
return msg
|
||||
|
||||
|
||||
@@ -977,6 +997,11 @@ def read_binary_string(f):
|
||||
return contents
|
||||
|
||||
|
||||
def write_binary_string(f, s):
|
||||
write_short(f, len(s))
|
||||
f.write(s)
|
||||
|
||||
|
||||
def write_string(f, s):
|
||||
if isinstance(s, six.text_type):
|
||||
s = s.encode('utf8')
|
||||
@@ -1028,6 +1053,22 @@ def write_stringmap(f, strmap):
|
||||
write_string(f, v)
|
||||
|
||||
|
||||
def read_bytesmap(f):
|
||||
numpairs = read_short(f)
|
||||
bytesmap = {}
|
||||
for _ in range(numpairs):
|
||||
k = read_string(f)
|
||||
bytesmap[k] = read_binary_string(f)
|
||||
return bytesmap
|
||||
|
||||
|
||||
def write_bytesmap(f, bytesmap):
|
||||
write_short(f, len(bytesmap))
|
||||
for k, v in bytesmap.items():
|
||||
write_string(f, k)
|
||||
write_binary_string(f, v)
|
||||
|
||||
|
||||
def read_stringmultimap(f):
|
||||
numkeys = read_short(f)
|
||||
strmmap = {}
|
||||
|
||||
@@ -197,11 +197,21 @@ class Statement(object):
|
||||
.. versionadded:: 2.1.3
|
||||
"""
|
||||
|
||||
custom_payload = None
|
||||
"""
|
||||
:ref:`custom_payload` to be passed to the server.
|
||||
|
||||
These are only allowed when using protocol version 4 or higher.
|
||||
|
||||
.. versionadded:: 3.0.0
|
||||
"""
|
||||
|
||||
_serial_consistency_level = None
|
||||
_routing_key = None
|
||||
|
||||
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
|
||||
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None):
|
||||
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
|
||||
custom_payload=None):
|
||||
self.retry_policy = retry_policy
|
||||
if consistency_level is not None:
|
||||
self.consistency_level = consistency_level
|
||||
@@ -212,6 +222,8 @@ class Statement(object):
|
||||
self.fetch_size = fetch_size
|
||||
if keyspace is not None:
|
||||
self.keyspace = keyspace
|
||||
if custom_payload is not None:
|
||||
self.custom_payload = custom_payload
|
||||
|
||||
def _get_routing_key(self):
|
||||
return self._routing_key
|
||||
@@ -290,8 +302,7 @@ class Statement(object):
|
||||
|
||||
class SimpleStatement(Statement):
|
||||
"""
|
||||
A simple, un-prepared query. All attributes of :class:`Statement` apply
|
||||
to this class as well.
|
||||
A simple, un-prepared query.
|
||||
"""
|
||||
|
||||
def __init__(self, query_string, *args, **kwargs):
|
||||
@@ -299,6 +310,8 @@ class SimpleStatement(Statement):
|
||||
`query_string` should be a literal CQL statement with the exception
|
||||
of parameter placeholders that will be filled through the
|
||||
`parameters` argument of :meth:`.Session.execute()`.
|
||||
|
||||
All arguments to :class:`Statement` apply to this class as well
|
||||
"""
|
||||
Statement.__init__(self, *args, **kwargs)
|
||||
self._query_string = query_string
|
||||
@@ -338,19 +351,16 @@ class PreparedStatement(object):
|
||||
|
||||
fetch_size = FETCH_SIZE_UNSET
|
||||
|
||||
def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
|
||||
protocol_version, consistency_level=None, serial_consistency_level=None,
|
||||
fetch_size=FETCH_SIZE_UNSET):
|
||||
custom_payload = None
|
||||
|
||||
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
|
||||
keyspace, protocol_version):
|
||||
self.column_metadata = column_metadata
|
||||
self.query_id = query_id
|
||||
self.routing_key_indexes = routing_key_indexes
|
||||
self.query_string = query
|
||||
self.keyspace = keyspace
|
||||
self.protocol_version = protocol_version
|
||||
self.consistency_level = consistency_level
|
||||
self.serial_consistency_level = serial_consistency_level
|
||||
if fetch_size is not FETCH_SIZE_UNSET:
|
||||
self.fetch_size = fetch_size
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version):
|
||||
@@ -402,8 +412,6 @@ class BoundStatement(Statement):
|
||||
"""
|
||||
A prepared statement that has been bound to a particular set of values.
|
||||
These may be created directly or through :meth:`.PreparedStatement.bind()`.
|
||||
|
||||
All attributes of :class:`Statement` apply to this class as well.
|
||||
"""
|
||||
|
||||
prepared_statement = None
|
||||
@@ -419,13 +427,15 @@ class BoundStatement(Statement):
|
||||
def __init__(self, prepared_statement, *args, **kwargs):
|
||||
"""
|
||||
`prepared_statement` should be an instance of :class:`PreparedStatement`.
|
||||
All other ``*args`` and ``**kwargs`` will be passed to :class:`.Statement`.
|
||||
|
||||
All arguments to :class:`Statement` apply to this class as well
|
||||
"""
|
||||
self.prepared_statement = prepared_statement
|
||||
|
||||
self.consistency_level = prepared_statement.consistency_level
|
||||
self.serial_consistency_level = prepared_statement.serial_consistency_level
|
||||
self.fetch_size = prepared_statement.fetch_size
|
||||
self.custom_payload = prepared_statement.custom_payload
|
||||
self.values = []
|
||||
|
||||
meta = prepared_statement.column_metadata
|
||||
@@ -606,7 +616,8 @@ class BatchStatement(Statement):
|
||||
_session = None
|
||||
|
||||
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
|
||||
consistency_level=None, serial_consistency_level=None, session=None):
|
||||
consistency_level=None, serial_consistency_level=None,
|
||||
session=None, custom_payload=None):
|
||||
"""
|
||||
`batch_type` specifies The :class:`.BatchType` for the batch operation.
|
||||
Defaults to :attr:`.BatchType.LOGGED`.
|
||||
@@ -617,6 +628,11 @@ class BatchStatement(Statement):
|
||||
`consistency_level` should be a :class:`~.ConsistencyLevel` value
|
||||
to be used for all operations in the batch.
|
||||
|
||||
`custom_payload` is a :ref:`custom_payload` passed to the server.
|
||||
Note: as Statement objects are added to the batch, this map is
|
||||
updated with any values found in their custom payloads. These are
|
||||
only allowed when using protocol version 4 or higher.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -642,12 +658,15 @@ class BatchStatement(Statement):
|
||||
|
||||
.. versionchanged:: 2.1.0
|
||||
Added `serial_consistency_level` as a parameter
|
||||
|
||||
.. versionchanged:: 3.0.0
|
||||
Added `custom_payload` as a parameter
|
||||
"""
|
||||
self.batch_type = batch_type
|
||||
self._statements_and_parameters = []
|
||||
self._session = session
|
||||
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
|
||||
serial_consistency_level=serial_consistency_level)
|
||||
serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
|
||||
|
||||
def add(self, statement, parameters=None):
|
||||
"""
|
||||
@@ -665,7 +684,7 @@ class BatchStatement(Statement):
|
||||
elif isinstance(statement, PreparedStatement):
|
||||
query_id = statement.query_id
|
||||
bound_statement = statement.bind(() if parameters is None else parameters)
|
||||
self._maybe_set_routing_attributes(bound_statement)
|
||||
self._update_state(bound_statement)
|
||||
self._statements_and_parameters.append(
|
||||
(True, query_id, bound_statement.values))
|
||||
elif isinstance(statement, BoundStatement):
|
||||
@@ -673,7 +692,7 @@ class BatchStatement(Statement):
|
||||
raise ValueError(
|
||||
"Parameters cannot be passed with a BoundStatement "
|
||||
"to BatchStatement.add()")
|
||||
self._maybe_set_routing_attributes(statement)
|
||||
self._update_state(statement)
|
||||
self._statements_and_parameters.append(
|
||||
(True, statement.prepared_statement.query_id, statement.values))
|
||||
else:
|
||||
@@ -682,7 +701,7 @@ class BatchStatement(Statement):
|
||||
if parameters:
|
||||
encoder = Encoder() if self._session is None else self._session.encoder
|
||||
query_string = bind_params(query_string, parameters, encoder)
|
||||
self._maybe_set_routing_attributes(statement)
|
||||
self._update_state(statement)
|
||||
self._statements_and_parameters.append((False, query_string, ()))
|
||||
return self
|
||||
|
||||
@@ -701,6 +720,16 @@ class BatchStatement(Statement):
|
||||
self.routing_key = statement.routing_key
|
||||
self.keyspace = statement.keyspace
|
||||
|
||||
def _update_custom_payload(self, statement):
|
||||
if statement.custom_payload:
|
||||
if self.custom_payload is None:
|
||||
self.custom_payload = {}
|
||||
self.custom_payload.update(statement.custom_payload)
|
||||
|
||||
def _update_state(self, statement):
|
||||
self._maybe_set_routing_attributes(statement)
|
||||
self._update_custom_payload(statement)
|
||||
|
||||
def __str__(self):
|
||||
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
||||
return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' %
|
||||
@@ -850,7 +879,7 @@ class QueryTrace(object):
|
||||
|
||||
def _execute(self, query, parameters, time_spent, max_wait):
|
||||
# in case the user switched the row factory, set it to namedtuple for this query
|
||||
future = self._session._create_response_future(query, parameters, trace=False)
|
||||
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None)
|
||||
future.row_factory = named_tuple_factory
|
||||
future.send_request()
|
||||
|
||||
|
||||
@@ -102,6 +102,8 @@
|
||||
|
||||
.. automethod:: get_query_trace()
|
||||
|
||||
.. autoattribute:: custom_payload()
|
||||
|
||||
.. autoattribute:: has_more_pages
|
||||
|
||||
.. automethod:: start_fetching_next_page()
|
||||
|
||||
11
docs/api/cassandra/protocol.rst
Normal file
11
docs/api/cassandra/protocol.rst
Normal file
@@ -0,0 +1,11 @@
|
||||
.. _custom_payload:
|
||||
|
||||
Custom Payload
|
||||
==============
|
||||
Native protocol version 4+ allows for a custom payload to be sent between clients
|
||||
and custom query handlers. The payload is specified as a string:binary_type dict
|
||||
holding custom key/value pairs.
|
||||
|
||||
By default these are ignored by the server. They can be useful for servers implementing
|
||||
a custom QueryHandler.
|
||||
|
||||
@@ -14,6 +14,7 @@ Core Driver
|
||||
cassandra/metrics
|
||||
cassandra/query
|
||||
cassandra/pool
|
||||
cassandra/protocol
|
||||
cassandra/encoder
|
||||
cassandra/decoder
|
||||
cassandra/concurrent
|
||||
|
||||
@@ -128,8 +128,8 @@ class BoundStatementTestCase(unittest.TestCase):
|
||||
routing_key_indexes=[],
|
||||
query=None,
|
||||
keyspace=keyspace,
|
||||
protocol_version=2,
|
||||
fetch_size=1234)
|
||||
protocol_version=2)
|
||||
prepared_statement.fetch_size = 1234
|
||||
bound_statement = BoundStatement(prepared_statement=prepared_statement)
|
||||
self.assertEqual(1234, bound_statement.fetch_size)
|
||||
|
||||
@@ -147,10 +147,9 @@ class BoundStatementTestCase(unittest.TestCase):
|
||||
routing_key_indexes=[0, 1],
|
||||
query=None,
|
||||
keyspace=keyspace,
|
||||
protocol_version=2,
|
||||
fetch_size=1234)
|
||||
protocol_version=2)
|
||||
|
||||
self.assertRaises(ValueError, prepared_statement.bind, (1,))
|
||||
|
||||
bound = prepared_statement.bind((1,2))
|
||||
bound = prepared_statement.bind((1, 2))
|
||||
self.assertEqual(bound.keyspace, keyspace)
|
||||
|
||||
Reference in New Issue
Block a user