Merge pull request #276 from datastax/PYTHON-280

PYTHON-280 - Custom payload for protocol v4
This commit is contained in:
Adam Holmberg
2015-04-30 10:57:23 -05:00
9 changed files with 157 additions and 33 deletions

View File

@@ -78,7 +78,6 @@ Bug Fixes
---------
* Make execute_concurrent compatible with Python 2.6 (PYTHON-159)
* Handle Unauthorized message on schema_triggers query (PYTHON-155)
* Make execute_concurrent compatible with Python 2.6 (github-197)
* Pure Python sorted set in support of UDTs nested in collections (PYTON-167)
* Support CUSTOM index metadata and string export (PYTHON-165)

View File

@@ -15,6 +15,7 @@ try:
except ImportError:
SASLClient = None
class AuthProvider(object):
"""
An abstract class that defines the interface that will be used for
@@ -157,6 +158,7 @@ class SaslAuthProvider(AuthProvider):
def new_authenticator(self, host):
return SaslAuthenticator(**self.sasl_kwargs)
class SaslAuthenticator(Authenticator):
"""
A pass-through :class:`~.Authenticator` using the third party package

View File

@@ -1383,7 +1383,7 @@ class Session(object):
for future in futures:
future.result()
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False):
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None):
"""
Execute the given query and synchronously wait for the response.
@@ -1411,6 +1411,10 @@ class Session(object):
instance and not just a string. If there is an error fetching the
trace details, the :attr:`~.Statement.trace` attribute will be left as
:const:`None`.
`custom_payload` is a :ref:`custom_payload` dict to be passed to the server.
If `query` is a Statement with its own custom_payload. The message payload
will be a union of the two, with the values specified here taking precedence.
"""
if timeout is _NOT_SET:
timeout = self.default_timeout
@@ -1420,7 +1424,7 @@ class Session(object):
"The query argument must be an instance of a subclass of "
"cassandra.query.Statement when trace=True")
future = self.execute_async(query, parameters, trace)
future = self.execute_async(query, parameters, trace, custom_payload)
try:
result = future.result(timeout)
finally:
@@ -1432,7 +1436,7 @@ class Session(object):
return result
def execute_async(self, query, parameters=None, trace=False):
def execute_async(self, query, parameters=None, trace=False, custom_payload=None):
"""
Execute the given query and return a :class:`~.ResponseFuture` object
which callbacks may be attached to for asynchronous response
@@ -1444,6 +1448,14 @@ class Session(object):
:meth:`.ResponseFuture.get_query_trace()` after the request
completes to retrieve a :class:`.QueryTrace` instance.
`custom_payload` is a :ref:`custom_payload` dict to be passed to the server.
If `query` is a Statement with its own custom_payload. The message payload
will be a union of the two, with the values specified here taking precedence.
If the server sends a custom payload in the response message,
the dict can be obtained following :meth:`.ResponseFuture.result` via
:attr:`.ResponseFuture.custom_payload`
Example usage::
>>> session = cluster.connect()
@@ -1469,11 +1481,11 @@ class Session(object):
... log.exception("Operation failed:")
"""
future = self._create_response_future(query, parameters, trace)
future = self._create_response_future(query, parameters, trace, custom_payload)
future.send_request()
return future
def _create_response_future(self, query, parameters, trace):
def _create_response_future(self, query, parameters, trace, custom_payload):
""" Returns the ResponseFuture before calling send_request() on it """
prepared_statement = None
@@ -1523,13 +1535,16 @@ class Session(object):
if trace:
message.tracing = True
message.update_custom_payload(query.custom_payload)
message.update_custom_payload(custom_payload)
return ResponseFuture(
self, message, query, self.default_timeout, metrics=self._metrics,
prepared_statement=prepared_statement)
def prepare(self, query):
def prepare(self, query, custom_payload=None):
"""
Prepares a query string, returing a :class:`~cassandra.query.PreparedStatement`
Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement`
instance which can be used as follows::
>>> session = cluster.connect("mykeyspace")
@@ -1552,8 +1567,12 @@ class Session(object):
**Important**: PreparedStatements should be prepared only once.
Preparing the same query more than once will likely affect performance.
`custom_payload` is a key value map to be passed along with the prepare
message. See :ref:`custom_payload`.
"""
message = PrepareMessage(query=query)
message.custom_payload = custom_payload
future = ResponseFuture(self, message, query=None)
try:
future.send_request()
@@ -1565,6 +1584,7 @@ class Session(object):
prepared_statement = PreparedStatement.from_message(
query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace,
self._protocol_version)
prepared_statement.custom_payload = future.custom_payload
host = future._current_host
try:
@@ -2636,6 +2656,7 @@ class ResponseFuture(object):
_start_time = None
_metrics = None
_paging_state = None
_custom_payload = None
def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None):
self.session = session
@@ -2723,6 +2744,23 @@ class ResponseFuture(object):
"""
return self._paging_state is not None
@property
def custom_payload(self):
"""
The custom payload returned from the server, if any. This will only be
set by Cassandra servers implementing a custom QueryHandler, and only
for protocol_version 4+.
Ensure the future is complete before trying to access this property
(call :meth:`.result()`, or after callback is invoked).
Otherwise it may throw if the response has not been received.
:return: :ref:`custom_payload`.
"""
if not self._event.is_set():
raise Exception("custom_payload cannot be retrieved before ResponseFuture is finalized")
return self._custom_payload
def start_fetching_next_page(self):
"""
If there are more pages left in the query result, this asynchronously
@@ -2759,6 +2797,8 @@ class ResponseFuture(object):
if trace_id:
self._query_trace = QueryTrace(trace_id, self.session)
self._custom_payload = getattr(response, 'custom_payload', None)
if isinstance(response, ResultMessage):
if response.kind == RESULT_KIND_SET_KEYSPACE:
session = getattr(self, 'session', None)

View File

@@ -56,6 +56,7 @@ HEADER_DIRECTION_MASK = 0x80
COMPRESSED_FLAG = 0x01
TRACING_FLAG = 0x02
CUSTOM_PAYLOAD_FLAG = 0x04
_message_types_by_name = {}
_message_types_by_opcode = {}
@@ -72,13 +73,19 @@ class _RegisterMessageType(type):
class _MessageType(object):
tracing = False
custom_payload = None
def to_binary(self, stream_id, protocol_version, compression=None):
flags = 0
body = io.BytesIO()
if self.custom_payload:
if protocol_version < 4:
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
flags |= CUSTOM_PAYLOAD_FLAG
write_bytesmap(body, self.custom_payload)
self.send_body(body, protocol_version)
body = body.getvalue()
flags = 0
if compression and len(body) > 0:
body = compression(body)
flags |= COMPRESSED_FLAG
@@ -91,6 +98,12 @@ class _MessageType(object):
return msg.getvalue()
def update_custom_payload(self, other):
if other:
if not self.custom_payload:
self.custom_payload = {}
self.custom_payload.update(other)
def __repr__(self):
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
@@ -118,6 +131,12 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
else:
trace_id = None
if flags & CUSTOM_PAYLOAD_FLAG:
custom_payload = read_bytesmap(body)
flags ^= CUSTOM_PAYLOAD_FLAG
else:
custom_payload = None
if flags:
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
@@ -125,6 +144,7 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
msg = msg_class.recv_body(body, protocol_version, user_type_map)
msg.stream_id = stream_id
msg.trace_id = trace_id
msg.custom_payload = custom_payload
return msg
@@ -977,6 +997,11 @@ def read_binary_string(f):
return contents
def write_binary_string(f, s):
write_short(f, len(s))
f.write(s)
def write_string(f, s):
if isinstance(s, six.text_type):
s = s.encode('utf8')
@@ -1028,6 +1053,22 @@ def write_stringmap(f, strmap):
write_string(f, v)
def read_bytesmap(f):
numpairs = read_short(f)
bytesmap = {}
for _ in range(numpairs):
k = read_string(f)
bytesmap[k] = read_binary_string(f)
return bytesmap
def write_bytesmap(f, bytesmap):
write_short(f, len(bytesmap))
for k, v in bytesmap.items():
write_string(f, k)
write_binary_string(f, v)
def read_stringmultimap(f):
numkeys = read_short(f)
strmmap = {}

View File

@@ -197,11 +197,21 @@ class Statement(object):
.. versionadded:: 2.1.3
"""
custom_payload = None
"""
:ref:`custom_payload` to be passed to the server.
These are only allowed when using protocol version 4 or higher.
.. versionadded:: 3.0.0
"""
_serial_consistency_level = None
_routing_key = None
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None):
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
custom_payload=None):
self.retry_policy = retry_policy
if consistency_level is not None:
self.consistency_level = consistency_level
@@ -212,6 +222,8 @@ class Statement(object):
self.fetch_size = fetch_size
if keyspace is not None:
self.keyspace = keyspace
if custom_payload is not None:
self.custom_payload = custom_payload
def _get_routing_key(self):
return self._routing_key
@@ -290,8 +302,7 @@ class Statement(object):
class SimpleStatement(Statement):
"""
A simple, un-prepared query. All attributes of :class:`Statement` apply
to this class as well.
A simple, un-prepared query.
"""
def __init__(self, query_string, *args, **kwargs):
@@ -299,6 +310,8 @@ class SimpleStatement(Statement):
`query_string` should be a literal CQL statement with the exception
of parameter placeholders that will be filled through the
`parameters` argument of :meth:`.Session.execute()`.
All arguments to :class:`Statement` apply to this class as well
"""
Statement.__init__(self, *args, **kwargs)
self._query_string = query_string
@@ -338,19 +351,16 @@ class PreparedStatement(object):
fetch_size = FETCH_SIZE_UNSET
def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
protocol_version, consistency_level=None, serial_consistency_level=None,
fetch_size=FETCH_SIZE_UNSET):
custom_payload = None
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
keyspace, protocol_version):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
self.query_string = query
self.keyspace = keyspace
self.protocol_version = protocol_version
self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level
if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size
@classmethod
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version):
@@ -402,8 +412,6 @@ class BoundStatement(Statement):
"""
A prepared statement that has been bound to a particular set of values.
These may be created directly or through :meth:`.PreparedStatement.bind()`.
All attributes of :class:`Statement` apply to this class as well.
"""
prepared_statement = None
@@ -419,13 +427,15 @@ class BoundStatement(Statement):
def __init__(self, prepared_statement, *args, **kwargs):
"""
`prepared_statement` should be an instance of :class:`PreparedStatement`.
All other ``*args`` and ``**kwargs`` will be passed to :class:`.Statement`.
All arguments to :class:`Statement` apply to this class as well
"""
self.prepared_statement = prepared_statement
self.consistency_level = prepared_statement.consistency_level
self.serial_consistency_level = prepared_statement.serial_consistency_level
self.fetch_size = prepared_statement.fetch_size
self.custom_payload = prepared_statement.custom_payload
self.values = []
meta = prepared_statement.column_metadata
@@ -606,7 +616,8 @@ class BatchStatement(Statement):
_session = None
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
consistency_level=None, serial_consistency_level=None, session=None):
consistency_level=None, serial_consistency_level=None,
session=None, custom_payload=None):
"""
`batch_type` specifies The :class:`.BatchType` for the batch operation.
Defaults to :attr:`.BatchType.LOGGED`.
@@ -617,6 +628,11 @@ class BatchStatement(Statement):
`consistency_level` should be a :class:`~.ConsistencyLevel` value
to be used for all operations in the batch.
`custom_payload` is a :ref:`custom_payload` passed to the server.
Note: as Statement objects are added to the batch, this map is
updated with any values found in their custom payloads. These are
only allowed when using protocol version 4 or higher.
Example usage:
.. code-block:: python
@@ -642,12 +658,15 @@ class BatchStatement(Statement):
.. versionchanged:: 2.1.0
Added `serial_consistency_level` as a parameter
.. versionchanged:: 3.0.0
Added `custom_payload` as a parameter
"""
self.batch_type = batch_type
self._statements_and_parameters = []
self._session = session
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
serial_consistency_level=serial_consistency_level)
serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
def add(self, statement, parameters=None):
"""
@@ -665,7 +684,7 @@ class BatchStatement(Statement):
elif isinstance(statement, PreparedStatement):
query_id = statement.query_id
bound_statement = statement.bind(() if parameters is None else parameters)
self._maybe_set_routing_attributes(bound_statement)
self._update_state(bound_statement)
self._statements_and_parameters.append(
(True, query_id, bound_statement.values))
elif isinstance(statement, BoundStatement):
@@ -673,7 +692,7 @@ class BatchStatement(Statement):
raise ValueError(
"Parameters cannot be passed with a BoundStatement "
"to BatchStatement.add()")
self._maybe_set_routing_attributes(statement)
self._update_state(statement)
self._statements_and_parameters.append(
(True, statement.prepared_statement.query_id, statement.values))
else:
@@ -682,7 +701,7 @@ class BatchStatement(Statement):
if parameters:
encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoder)
self._maybe_set_routing_attributes(statement)
self._update_state(statement)
self._statements_and_parameters.append((False, query_string, ()))
return self
@@ -701,6 +720,16 @@ class BatchStatement(Statement):
self.routing_key = statement.routing_key
self.keyspace = statement.keyspace
def _update_custom_payload(self, statement):
if statement.custom_payload:
if self.custom_payload is None:
self.custom_payload = {}
self.custom_payload.update(statement.custom_payload)
def _update_state(self, statement):
self._maybe_set_routing_attributes(statement)
self._update_custom_payload(statement)
def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' %
@@ -850,7 +879,7 @@ class QueryTrace(object):
def _execute(self, query, parameters, time_spent, max_wait):
# in case the user switched the row factory, set it to namedtuple for this query
future = self._session._create_response_future(query, parameters, trace=False)
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None)
future.row_factory = named_tuple_factory
future.send_request()

View File

@@ -102,6 +102,8 @@
.. automethod:: get_query_trace()
.. autoattribute:: custom_payload()
.. autoattribute:: has_more_pages
.. automethod:: start_fetching_next_page()

View File

@@ -0,0 +1,11 @@
.. _custom_payload:
Custom Payload
==============
Native protocol version 4+ allows for a custom payload to be sent between clients
and custom query handlers. The payload is specified as a string:binary_type dict
holding custom key/value pairs.
By default these are ignored by the server. They can be useful for servers implementing
a custom QueryHandler.

View File

@@ -14,6 +14,7 @@ Core Driver
cassandra/metrics
cassandra/query
cassandra/pool
cassandra/protocol
cassandra/encoder
cassandra/decoder
cassandra/concurrent

View File

@@ -128,8 +128,8 @@ class BoundStatementTestCase(unittest.TestCase):
routing_key_indexes=[],
query=None,
keyspace=keyspace,
protocol_version=2,
fetch_size=1234)
protocol_version=2)
prepared_statement.fetch_size = 1234
bound_statement = BoundStatement(prepared_statement=prepared_statement)
self.assertEqual(1234, bound_statement.fetch_size)
@@ -147,10 +147,9 @@ class BoundStatementTestCase(unittest.TestCase):
routing_key_indexes=[0, 1],
query=None,
keyspace=keyspace,
protocol_version=2,
fetch_size=1234)
protocol_version=2)
self.assertRaises(ValueError, prepared_statement.bind, (1,))
bound = prepared_statement.bind((1,2))
bound = prepared_statement.bind((1, 2))
self.assertEqual(bound.keyspace, keyspace)