Custom payloads for protocol v4
PYTHON-280
This commit is contained in:
@@ -78,7 +78,6 @@ Bug Fixes
|
||||
---------
|
||||
* Make execute_concurrent compatible with Python 2.6 (PYTHON-159)
|
||||
* Handle Unauthorized message on schema_triggers query (PYTHON-155)
|
||||
* Make execute_concurrent compatible with Python 2.6 (github-197)
|
||||
* Pure Python sorted set in support of UDTs nested in collections (PYTON-167)
|
||||
* Support CUSTOM index metadata and string export (PYTHON-165)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ try:
|
||||
except ImportError:
|
||||
SASLClient = None
|
||||
|
||||
|
||||
class AuthProvider(object):
|
||||
"""
|
||||
An abstract class that defines the interface that will be used for
|
||||
@@ -157,6 +158,7 @@ class SaslAuthProvider(AuthProvider):
|
||||
def new_authenticator(self, host):
|
||||
return SaslAuthenticator(**self.sasl_kwargs)
|
||||
|
||||
|
||||
class SaslAuthenticator(Authenticator):
|
||||
"""
|
||||
A pass-through :class:`~.Authenticator` using the third party package
|
||||
|
||||
@@ -1361,7 +1361,7 @@ class Session(object):
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False):
|
||||
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None):
|
||||
"""
|
||||
Execute the given query and synchronously wait for the response.
|
||||
|
||||
@@ -1389,6 +1389,10 @@ class Session(object):
|
||||
instance and not just a string. If there is an error fetching the
|
||||
trace details, the :attr:`~.Statement.trace` attribute will be left as
|
||||
:const:`None`.
|
||||
|
||||
`custom_payload` is a dict as described in TODO section. If `query` is a Statement
|
||||
with its own custom_payload. the message will be a union of the two,
|
||||
with the values specified here taking precedence.
|
||||
"""
|
||||
if timeout is _NOT_SET:
|
||||
timeout = self.default_timeout
|
||||
@@ -1398,7 +1402,7 @@ class Session(object):
|
||||
"The query argument must be an instance of a subclass of "
|
||||
"cassandra.query.Statement when trace=True")
|
||||
|
||||
future = self.execute_async(query, parameters, trace)
|
||||
future = self.execute_async(query, parameters, trace, custom_payload)
|
||||
try:
|
||||
result = future.result(timeout)
|
||||
finally:
|
||||
@@ -1410,7 +1414,7 @@ class Session(object):
|
||||
|
||||
return result
|
||||
|
||||
def execute_async(self, query, parameters=None, trace=False):
|
||||
def execute_async(self, query, parameters=None, trace=False, custom_payload=None):
|
||||
"""
|
||||
Execute the given query and return a :class:`~.ResponseFuture` object
|
||||
which callbacks may be attached to for asynchronous response
|
||||
@@ -1422,6 +1426,13 @@ class Session(object):
|
||||
:meth:`.ResponseFuture.get_query_trace()` after the request
|
||||
completes to retrieve a :class:`.QueryTrace` instance.
|
||||
|
||||
`custom_payload` is a dict as described in TODO section. If `query` is
|
||||
a Statement with a custom_payload specified. the message will be a
|
||||
union of the two, with the values specified here taking precedence.
|
||||
|
||||
If the server sends a custom payload in the response message,
|
||||
the dict can be obtained via :attr:`.ResponseFuture.custom_payload`
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> session = cluster.connect()
|
||||
@@ -1447,11 +1458,11 @@ class Session(object):
|
||||
... log.exception("Operation failed:")
|
||||
|
||||
"""
|
||||
future = self._create_response_future(query, parameters, trace)
|
||||
future = self._create_response_future(query, parameters, trace, custom_payload)
|
||||
future.send_request()
|
||||
return future
|
||||
|
||||
def _create_response_future(self, query, parameters, trace):
|
||||
def _create_response_future(self, query, parameters, trace, custom_payload):
|
||||
""" Returns the ResponseFuture before calling send_request() on it """
|
||||
|
||||
prepared_statement = None
|
||||
@@ -1501,13 +1512,16 @@ class Session(object):
|
||||
if trace:
|
||||
message.tracing = True
|
||||
|
||||
message.update_custom_payload(query.custom_payload)
|
||||
message.update_custom_payload(custom_payload)
|
||||
|
||||
return ResponseFuture(
|
||||
self, message, query, self.default_timeout, metrics=self._metrics,
|
||||
prepared_statement=prepared_statement)
|
||||
|
||||
def prepare(self, query):
|
||||
def prepare(self, query, custom_payload=None):
|
||||
"""
|
||||
Prepares a query string, returing a :class:`~cassandra.query.PreparedStatement`
|
||||
Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement`
|
||||
instance which can be used as follows::
|
||||
|
||||
>>> session = cluster.connect("mykeyspace")
|
||||
@@ -1530,8 +1544,12 @@ class Session(object):
|
||||
|
||||
**Important**: PreparedStatements should be prepared only once.
|
||||
Preparing the same query more than once will likely affect performance.
|
||||
|
||||
`custom_payload` is a key value map to be passed along with the prepare
|
||||
message. See TODO: refer to doc section
|
||||
"""
|
||||
message = PrepareMessage(query=query)
|
||||
message.custom_payload = custom_payload
|
||||
future = ResponseFuture(self, message, query=None)
|
||||
try:
|
||||
future.send_request()
|
||||
@@ -1543,6 +1561,7 @@ class Session(object):
|
||||
prepared_statement = PreparedStatement.from_message(
|
||||
query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace,
|
||||
self._protocol_version)
|
||||
prepared_statement.custom_payload = future.custom_payload
|
||||
|
||||
host = future._current_host
|
||||
try:
|
||||
@@ -2567,6 +2586,7 @@ class ResponseFuture(object):
|
||||
_start_time = None
|
||||
_metrics = None
|
||||
_paging_state = None
|
||||
_custom_payload = None
|
||||
|
||||
def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None):
|
||||
self.session = session
|
||||
@@ -2654,6 +2674,12 @@ class ResponseFuture(object):
|
||||
"""
|
||||
return self._paging_state is not None
|
||||
|
||||
@property
|
||||
def custom_payload(self):
|
||||
if not self._event.is_set():
|
||||
raise Exception("custom_payload cannot be retrieved before ResponseFuture is finalized")
|
||||
return self._custom_payload
|
||||
|
||||
def start_fetching_next_page(self):
|
||||
"""
|
||||
If there are more pages left in the query result, this asynchronously
|
||||
@@ -2690,6 +2716,8 @@ class ResponseFuture(object):
|
||||
if trace_id:
|
||||
self._query_trace = QueryTrace(trace_id, self.session)
|
||||
|
||||
self._custom_payload = getattr(response, 'custom_payload', None)
|
||||
|
||||
if isinstance(response, ResultMessage):
|
||||
if response.kind == RESULT_KIND_SET_KEYSPACE:
|
||||
session = getattr(self, 'session', None)
|
||||
|
||||
@@ -54,6 +54,7 @@ HEADER_DIRECTION_MASK = 0x80
|
||||
|
||||
COMPRESSED_FLAG = 0x01
|
||||
TRACING_FLAG = 0x02
|
||||
CUSTOM_PAYLOAD_FLAG = 0x04
|
||||
|
||||
_message_types_by_name = {}
|
||||
_message_types_by_opcode = {}
|
||||
@@ -70,13 +71,19 @@ class _RegisterMessageType(type):
|
||||
class _MessageType(object):
|
||||
|
||||
tracing = False
|
||||
custom_payload = None
|
||||
|
||||
def to_binary(self, stream_id, protocol_version, compression=None):
|
||||
flags = 0
|
||||
body = io.BytesIO()
|
||||
if self.custom_payload:
|
||||
if protocol_version < 4:
|
||||
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
|
||||
flags |= CUSTOM_PAYLOAD_FLAG
|
||||
write_bytesmap(body, self.custom_payload)
|
||||
self.send_body(body, protocol_version)
|
||||
body = body.getvalue()
|
||||
|
||||
flags = 0
|
||||
if compression and len(body) > 0:
|
||||
body = compression(body)
|
||||
flags |= COMPRESSED_FLAG
|
||||
@@ -89,6 +96,12 @@ class _MessageType(object):
|
||||
|
||||
return msg.getvalue()
|
||||
|
||||
def update_custom_payload(self, other):
|
||||
if other:
|
||||
if not self.custom_payload:
|
||||
self.custom_payload = {}
|
||||
self.custom_payload.update(other)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
|
||||
|
||||
@@ -116,6 +129,12 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
|
||||
else:
|
||||
trace_id = None
|
||||
|
||||
if flags & CUSTOM_PAYLOAD_FLAG:
|
||||
custom_payload = read_bytesmap(body)
|
||||
flags ^= CUSTOM_PAYLOAD_FLAG
|
||||
else:
|
||||
custom_payload = None
|
||||
|
||||
if flags:
|
||||
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||
|
||||
@@ -123,6 +142,7 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
|
||||
msg = msg_class.recv_body(body, protocol_version, user_type_map)
|
||||
msg.stream_id = stream_id
|
||||
msg.trace_id = trace_id
|
||||
msg.custom_payload = custom_payload
|
||||
return msg
|
||||
|
||||
|
||||
@@ -918,6 +938,11 @@ def read_binary_string(f):
|
||||
return contents
|
||||
|
||||
|
||||
def write_binary_string(f, s):
|
||||
write_short(f, len(s))
|
||||
f.write(s)
|
||||
|
||||
|
||||
def write_string(f, s):
|
||||
if isinstance(s, six.text_type):
|
||||
s = s.encode('utf8')
|
||||
@@ -969,6 +994,22 @@ def write_stringmap(f, strmap):
|
||||
write_string(f, v)
|
||||
|
||||
|
||||
def read_bytesmap(f):
|
||||
numpairs = read_short(f)
|
||||
bytesmap = {}
|
||||
for _ in range(numpairs):
|
||||
k = read_string(f)
|
||||
bytesmap[k] = read_binary_string(f)
|
||||
return bytesmap
|
||||
|
||||
|
||||
def write_bytesmap(f, bytesmap):
|
||||
write_short(f, len(bytesmap))
|
||||
for k, v in bytesmap.items():
|
||||
write_string(f, k)
|
||||
write_binary_string(f, v)
|
||||
|
||||
|
||||
def read_stringmultimap(f):
|
||||
numkeys = read_short(f)
|
||||
strmmap = {}
|
||||
|
||||
@@ -197,11 +197,25 @@ class Statement(object):
|
||||
.. versionadded:: 2.1.3
|
||||
"""
|
||||
|
||||
custom_payload = None
|
||||
"""
|
||||
TODO: refer to custom proto doc section
|
||||
A string:binary_type dict holding custom key/value pairs to be passed
|
||||
in the frame to a custom QueryHandler on the server side.
|
||||
|
||||
By default these values are ignored by the server.
|
||||
|
||||
These are only allowed when using protocol version 4 or higher.
|
||||
|
||||
.. versionadded:: 3.0.0
|
||||
"""
|
||||
|
||||
_serial_consistency_level = None
|
||||
_routing_key = None
|
||||
|
||||
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
|
||||
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None):
|
||||
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
|
||||
custom_payload=None):
|
||||
self.retry_policy = retry_policy
|
||||
if consistency_level is not None:
|
||||
self.consistency_level = consistency_level
|
||||
@@ -212,6 +226,8 @@ class Statement(object):
|
||||
self.fetch_size = fetch_size
|
||||
if keyspace is not None:
|
||||
self.keyspace = keyspace
|
||||
if custom_payload is not None:
|
||||
self.custom_payload = custom_payload
|
||||
|
||||
def _get_routing_key(self):
|
||||
return self._routing_key
|
||||
@@ -290,8 +306,7 @@ class Statement(object):
|
||||
|
||||
class SimpleStatement(Statement):
|
||||
"""
|
||||
A simple, un-prepared query. All attributes of :class:`Statement` apply
|
||||
to this class as well.
|
||||
A simple, un-prepared query.
|
||||
"""
|
||||
|
||||
def __init__(self, query_string, *args, **kwargs):
|
||||
@@ -299,6 +314,8 @@ class SimpleStatement(Statement):
|
||||
`query_string` should be a literal CQL statement with the exception
|
||||
of parameter placeholders that will be filled through the
|
||||
`parameters` argument of :meth:`.Session.execute()`.
|
||||
|
||||
All arguments to :class:`Statement` apply to this class as well
|
||||
"""
|
||||
Statement.__init__(self, *args, **kwargs)
|
||||
self._query_string = query_string
|
||||
@@ -338,6 +355,8 @@ class PreparedStatement(object):
|
||||
|
||||
fetch_size = FETCH_SIZE_UNSET
|
||||
|
||||
custom_payload = None
|
||||
|
||||
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
|
||||
keyspace, protocol_version):
|
||||
self.column_metadata = column_metadata
|
||||
@@ -397,8 +416,6 @@ class BoundStatement(Statement):
|
||||
"""
|
||||
A prepared statement that has been bound to a particular set of values.
|
||||
These may be created directly or through :meth:`.PreparedStatement.bind()`.
|
||||
|
||||
All attributes of :class:`Statement` apply to this class as well.
|
||||
"""
|
||||
|
||||
prepared_statement = None
|
||||
@@ -414,13 +431,15 @@ class BoundStatement(Statement):
|
||||
def __init__(self, prepared_statement, *args, **kwargs):
|
||||
"""
|
||||
`prepared_statement` should be an instance of :class:`PreparedStatement`.
|
||||
All other ``*args`` and ``**kwargs`` will be passed to :class:`.Statement`.
|
||||
|
||||
All arguments to :class:`Statement` apply to this class as well
|
||||
"""
|
||||
self.prepared_statement = prepared_statement
|
||||
|
||||
self.consistency_level = prepared_statement.consistency_level
|
||||
self.serial_consistency_level = prepared_statement.serial_consistency_level
|
||||
self.fetch_size = prepared_statement.fetch_size
|
||||
self.custom_payload = prepared_statement.custom_payload
|
||||
self.values = []
|
||||
|
||||
meta = prepared_statement.column_metadata
|
||||
@@ -601,7 +620,8 @@ class BatchStatement(Statement):
|
||||
_session = None
|
||||
|
||||
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
|
||||
consistency_level=None, serial_consistency_level=None, session=None):
|
||||
consistency_level=None, serial_consistency_level=None,
|
||||
session=None, custom_payload=None):
|
||||
"""
|
||||
`batch_type` specifies The :class:`.BatchType` for the batch operation.
|
||||
Defaults to :attr:`.BatchType.LOGGED`.
|
||||
@@ -612,6 +632,10 @@ class BatchStatement(Statement):
|
||||
`consistency_level` should be a :class:`~.ConsistencyLevel` value
|
||||
to be used for all operations in the batch.
|
||||
|
||||
`custom_payload` is a key-value map TODO: refer to doc section
|
||||
Note: as Statement objects are added to the batch, this map is
|
||||
updated with values from their custom payloads.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -637,12 +661,15 @@ class BatchStatement(Statement):
|
||||
|
||||
.. versionchanged:: 2.1.0
|
||||
Added `serial_consistency_level` as a parameter
|
||||
|
||||
.. versionchanged:: 3.0.0
|
||||
Added `custom_payload` as a parameter
|
||||
"""
|
||||
self.batch_type = batch_type
|
||||
self._statements_and_parameters = []
|
||||
self._session = session
|
||||
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
|
||||
serial_consistency_level=serial_consistency_level)
|
||||
serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
|
||||
|
||||
def add(self, statement, parameters=None):
|
||||
"""
|
||||
@@ -660,7 +687,7 @@ class BatchStatement(Statement):
|
||||
elif isinstance(statement, PreparedStatement):
|
||||
query_id = statement.query_id
|
||||
bound_statement = statement.bind(() if parameters is None else parameters)
|
||||
self._maybe_set_routing_attributes(bound_statement)
|
||||
self._update_state(bound_statement)
|
||||
self._statements_and_parameters.append(
|
||||
(True, query_id, bound_statement.values))
|
||||
elif isinstance(statement, BoundStatement):
|
||||
@@ -668,7 +695,7 @@ class BatchStatement(Statement):
|
||||
raise ValueError(
|
||||
"Parameters cannot be passed with a BoundStatement "
|
||||
"to BatchStatement.add()")
|
||||
self._maybe_set_routing_attributes(statement)
|
||||
self._update_state(statement)
|
||||
self._statements_and_parameters.append(
|
||||
(True, statement.prepared_statement.query_id, statement.values))
|
||||
else:
|
||||
@@ -677,7 +704,7 @@ class BatchStatement(Statement):
|
||||
if parameters:
|
||||
encoder = Encoder() if self._session is None else self._session.encoder
|
||||
query_string = bind_params(query_string, parameters, encoder)
|
||||
self._maybe_set_routing_attributes(statement)
|
||||
self._update_state(statement)
|
||||
self._statements_and_parameters.append((False, query_string, ()))
|
||||
return self
|
||||
|
||||
@@ -696,6 +723,16 @@ class BatchStatement(Statement):
|
||||
self.routing_key = statement.routing_key
|
||||
self.keyspace = statement.keyspace
|
||||
|
||||
def _update_custom_payload(self, statement):
|
||||
if statement.custom_payload:
|
||||
if self.custom_payload is None:
|
||||
self.custom_payload = {}
|
||||
self.custom_payload.update(statement.custom_payload)
|
||||
|
||||
def _update_state(self, statement):
|
||||
self._maybe_set_routing_attributes(statement)
|
||||
self._update_custom_payload(statement)
|
||||
|
||||
def __str__(self):
|
||||
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
||||
return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' %
|
||||
@@ -836,7 +873,7 @@ class QueryTrace(object):
|
||||
|
||||
def _execute(self, query, parameters, time_spent, max_wait):
|
||||
# in case the user switched the row factory, set it to namedtuple for this query
|
||||
future = self._session._create_response_future(query, parameters, trace=False)
|
||||
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None)
|
||||
future.row_factory = named_tuple_factory
|
||||
future.send_request()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user