Custom payloads for protocol v4
PYTHON-280
This commit is contained in:
@@ -78,7 +78,6 @@ Bug Fixes
|
|||||||
---------
|
---------
|
||||||
* Make execute_concurrent compatible with Python 2.6 (PYTHON-159)
|
* Make execute_concurrent compatible with Python 2.6 (PYTHON-159)
|
||||||
* Handle Unauthorized message on schema_triggers query (PYTHON-155)
|
* Handle Unauthorized message on schema_triggers query (PYTHON-155)
|
||||||
* Make execute_concurrent compatible with Python 2.6 (github-197)
|
|
||||||
* Pure Python sorted set in support of UDTs nested in collections (PYTON-167)
|
* Pure Python sorted set in support of UDTs nested in collections (PYTON-167)
|
||||||
* Support CUSTOM index metadata and string export (PYTHON-165)
|
* Support CUSTOM index metadata and string export (PYTHON-165)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
SASLClient = None
|
SASLClient = None
|
||||||
|
|
||||||
|
|
||||||
class AuthProvider(object):
|
class AuthProvider(object):
|
||||||
"""
|
"""
|
||||||
An abstract class that defines the interface that will be used for
|
An abstract class that defines the interface that will be used for
|
||||||
@@ -157,6 +158,7 @@ class SaslAuthProvider(AuthProvider):
|
|||||||
def new_authenticator(self, host):
|
def new_authenticator(self, host):
|
||||||
return SaslAuthenticator(**self.sasl_kwargs)
|
return SaslAuthenticator(**self.sasl_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class SaslAuthenticator(Authenticator):
|
class SaslAuthenticator(Authenticator):
|
||||||
"""
|
"""
|
||||||
A pass-through :class:`~.Authenticator` using the third party package
|
A pass-through :class:`~.Authenticator` using the third party package
|
||||||
|
|||||||
@@ -1361,7 +1361,7 @@ class Session(object):
|
|||||||
for future in futures:
|
for future in futures:
|
||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False):
|
def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None):
|
||||||
"""
|
"""
|
||||||
Execute the given query and synchronously wait for the response.
|
Execute the given query and synchronously wait for the response.
|
||||||
|
|
||||||
@@ -1389,6 +1389,10 @@ class Session(object):
|
|||||||
instance and not just a string. If there is an error fetching the
|
instance and not just a string. If there is an error fetching the
|
||||||
trace details, the :attr:`~.Statement.trace` attribute will be left as
|
trace details, the :attr:`~.Statement.trace` attribute will be left as
|
||||||
:const:`None`.
|
:const:`None`.
|
||||||
|
|
||||||
|
`custom_payload` is a dict as described in TODO section. If `query` is a Statement
|
||||||
|
with its own custom_payload. the message will be a union of the two,
|
||||||
|
with the values specified here taking precedence.
|
||||||
"""
|
"""
|
||||||
if timeout is _NOT_SET:
|
if timeout is _NOT_SET:
|
||||||
timeout = self.default_timeout
|
timeout = self.default_timeout
|
||||||
@@ -1398,7 +1402,7 @@ class Session(object):
|
|||||||
"The query argument must be an instance of a subclass of "
|
"The query argument must be an instance of a subclass of "
|
||||||
"cassandra.query.Statement when trace=True")
|
"cassandra.query.Statement when trace=True")
|
||||||
|
|
||||||
future = self.execute_async(query, parameters, trace)
|
future = self.execute_async(query, parameters, trace, custom_payload)
|
||||||
try:
|
try:
|
||||||
result = future.result(timeout)
|
result = future.result(timeout)
|
||||||
finally:
|
finally:
|
||||||
@@ -1410,7 +1414,7 @@ class Session(object):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def execute_async(self, query, parameters=None, trace=False):
|
def execute_async(self, query, parameters=None, trace=False, custom_payload=None):
|
||||||
"""
|
"""
|
||||||
Execute the given query and return a :class:`~.ResponseFuture` object
|
Execute the given query and return a :class:`~.ResponseFuture` object
|
||||||
which callbacks may be attached to for asynchronous response
|
which callbacks may be attached to for asynchronous response
|
||||||
@@ -1422,6 +1426,13 @@ class Session(object):
|
|||||||
:meth:`.ResponseFuture.get_query_trace()` after the request
|
:meth:`.ResponseFuture.get_query_trace()` after the request
|
||||||
completes to retrieve a :class:`.QueryTrace` instance.
|
completes to retrieve a :class:`.QueryTrace` instance.
|
||||||
|
|
||||||
|
`custom_payload` is a dict as described in TODO section. If `query` is
|
||||||
|
a Statement with a custom_payload specified. the message will be a
|
||||||
|
union of the two, with the values specified here taking precedence.
|
||||||
|
|
||||||
|
If the server sends a custom payload in the response message,
|
||||||
|
the dict can be obtained via :attr:`.ResponseFuture.custom_payload`
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
>>> session = cluster.connect()
|
>>> session = cluster.connect()
|
||||||
@@ -1447,11 +1458,11 @@ class Session(object):
|
|||||||
... log.exception("Operation failed:")
|
... log.exception("Operation failed:")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
future = self._create_response_future(query, parameters, trace)
|
future = self._create_response_future(query, parameters, trace, custom_payload)
|
||||||
future.send_request()
|
future.send_request()
|
||||||
return future
|
return future
|
||||||
|
|
||||||
def _create_response_future(self, query, parameters, trace):
|
def _create_response_future(self, query, parameters, trace, custom_payload):
|
||||||
""" Returns the ResponseFuture before calling send_request() on it """
|
""" Returns the ResponseFuture before calling send_request() on it """
|
||||||
|
|
||||||
prepared_statement = None
|
prepared_statement = None
|
||||||
@@ -1501,13 +1512,16 @@ class Session(object):
|
|||||||
if trace:
|
if trace:
|
||||||
message.tracing = True
|
message.tracing = True
|
||||||
|
|
||||||
|
message.update_custom_payload(query.custom_payload)
|
||||||
|
message.update_custom_payload(custom_payload)
|
||||||
|
|
||||||
return ResponseFuture(
|
return ResponseFuture(
|
||||||
self, message, query, self.default_timeout, metrics=self._metrics,
|
self, message, query, self.default_timeout, metrics=self._metrics,
|
||||||
prepared_statement=prepared_statement)
|
prepared_statement=prepared_statement)
|
||||||
|
|
||||||
def prepare(self, query):
|
def prepare(self, query, custom_payload=None):
|
||||||
"""
|
"""
|
||||||
Prepares a query string, returing a :class:`~cassandra.query.PreparedStatement`
|
Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement`
|
||||||
instance which can be used as follows::
|
instance which can be used as follows::
|
||||||
|
|
||||||
>>> session = cluster.connect("mykeyspace")
|
>>> session = cluster.connect("mykeyspace")
|
||||||
@@ -1530,8 +1544,12 @@ class Session(object):
|
|||||||
|
|
||||||
**Important**: PreparedStatements should be prepared only once.
|
**Important**: PreparedStatements should be prepared only once.
|
||||||
Preparing the same query more than once will likely affect performance.
|
Preparing the same query more than once will likely affect performance.
|
||||||
|
|
||||||
|
`custom_payload` is a key value map to be passed along with the prepare
|
||||||
|
message. See TODO: refer to doc section
|
||||||
"""
|
"""
|
||||||
message = PrepareMessage(query=query)
|
message = PrepareMessage(query=query)
|
||||||
|
message.custom_payload = custom_payload
|
||||||
future = ResponseFuture(self, message, query=None)
|
future = ResponseFuture(self, message, query=None)
|
||||||
try:
|
try:
|
||||||
future.send_request()
|
future.send_request()
|
||||||
@@ -1543,6 +1561,7 @@ class Session(object):
|
|||||||
prepared_statement = PreparedStatement.from_message(
|
prepared_statement = PreparedStatement.from_message(
|
||||||
query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace,
|
query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace,
|
||||||
self._protocol_version)
|
self._protocol_version)
|
||||||
|
prepared_statement.custom_payload = future.custom_payload
|
||||||
|
|
||||||
host = future._current_host
|
host = future._current_host
|
||||||
try:
|
try:
|
||||||
@@ -2567,6 +2586,7 @@ class ResponseFuture(object):
|
|||||||
_start_time = None
|
_start_time = None
|
||||||
_metrics = None
|
_metrics = None
|
||||||
_paging_state = None
|
_paging_state = None
|
||||||
|
_custom_payload = None
|
||||||
|
|
||||||
def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None):
|
def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None):
|
||||||
self.session = session
|
self.session = session
|
||||||
@@ -2654,6 +2674,12 @@ class ResponseFuture(object):
|
|||||||
"""
|
"""
|
||||||
return self._paging_state is not None
|
return self._paging_state is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def custom_payload(self):
|
||||||
|
if not self._event.is_set():
|
||||||
|
raise Exception("custom_payload cannot be retrieved before ResponseFuture is finalized")
|
||||||
|
return self._custom_payload
|
||||||
|
|
||||||
def start_fetching_next_page(self):
|
def start_fetching_next_page(self):
|
||||||
"""
|
"""
|
||||||
If there are more pages left in the query result, this asynchronously
|
If there are more pages left in the query result, this asynchronously
|
||||||
@@ -2690,6 +2716,8 @@ class ResponseFuture(object):
|
|||||||
if trace_id:
|
if trace_id:
|
||||||
self._query_trace = QueryTrace(trace_id, self.session)
|
self._query_trace = QueryTrace(trace_id, self.session)
|
||||||
|
|
||||||
|
self._custom_payload = getattr(response, 'custom_payload', None)
|
||||||
|
|
||||||
if isinstance(response, ResultMessage):
|
if isinstance(response, ResultMessage):
|
||||||
if response.kind == RESULT_KIND_SET_KEYSPACE:
|
if response.kind == RESULT_KIND_SET_KEYSPACE:
|
||||||
session = getattr(self, 'session', None)
|
session = getattr(self, 'session', None)
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ HEADER_DIRECTION_MASK = 0x80
|
|||||||
|
|
||||||
COMPRESSED_FLAG = 0x01
|
COMPRESSED_FLAG = 0x01
|
||||||
TRACING_FLAG = 0x02
|
TRACING_FLAG = 0x02
|
||||||
|
CUSTOM_PAYLOAD_FLAG = 0x04
|
||||||
|
|
||||||
_message_types_by_name = {}
|
_message_types_by_name = {}
|
||||||
_message_types_by_opcode = {}
|
_message_types_by_opcode = {}
|
||||||
@@ -70,13 +71,19 @@ class _RegisterMessageType(type):
|
|||||||
class _MessageType(object):
|
class _MessageType(object):
|
||||||
|
|
||||||
tracing = False
|
tracing = False
|
||||||
|
custom_payload = None
|
||||||
|
|
||||||
def to_binary(self, stream_id, protocol_version, compression=None):
|
def to_binary(self, stream_id, protocol_version, compression=None):
|
||||||
|
flags = 0
|
||||||
body = io.BytesIO()
|
body = io.BytesIO()
|
||||||
|
if self.custom_payload:
|
||||||
|
if protocol_version < 4:
|
||||||
|
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
|
||||||
|
flags |= CUSTOM_PAYLOAD_FLAG
|
||||||
|
write_bytesmap(body, self.custom_payload)
|
||||||
self.send_body(body, protocol_version)
|
self.send_body(body, protocol_version)
|
||||||
body = body.getvalue()
|
body = body.getvalue()
|
||||||
|
|
||||||
flags = 0
|
|
||||||
if compression and len(body) > 0:
|
if compression and len(body) > 0:
|
||||||
body = compression(body)
|
body = compression(body)
|
||||||
flags |= COMPRESSED_FLAG
|
flags |= COMPRESSED_FLAG
|
||||||
@@ -89,6 +96,12 @@ class _MessageType(object):
|
|||||||
|
|
||||||
return msg.getvalue()
|
return msg.getvalue()
|
||||||
|
|
||||||
|
def update_custom_payload(self, other):
|
||||||
|
if other:
|
||||||
|
if not self.custom_payload:
|
||||||
|
self.custom_payload = {}
|
||||||
|
self.custom_payload.update(other)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
|
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
|
||||||
|
|
||||||
@@ -116,6 +129,12 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
|
|||||||
else:
|
else:
|
||||||
trace_id = None
|
trace_id = None
|
||||||
|
|
||||||
|
if flags & CUSTOM_PAYLOAD_FLAG:
|
||||||
|
custom_payload = read_bytesmap(body)
|
||||||
|
flags ^= CUSTOM_PAYLOAD_FLAG
|
||||||
|
else:
|
||||||
|
custom_payload = None
|
||||||
|
|
||||||
if flags:
|
if flags:
|
||||||
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
|
||||||
|
|
||||||
@@ -123,6 +142,7 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b
|
|||||||
msg = msg_class.recv_body(body, protocol_version, user_type_map)
|
msg = msg_class.recv_body(body, protocol_version, user_type_map)
|
||||||
msg.stream_id = stream_id
|
msg.stream_id = stream_id
|
||||||
msg.trace_id = trace_id
|
msg.trace_id = trace_id
|
||||||
|
msg.custom_payload = custom_payload
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
@@ -918,6 +938,11 @@ def read_binary_string(f):
|
|||||||
return contents
|
return contents
|
||||||
|
|
||||||
|
|
||||||
|
def write_binary_string(f, s):
|
||||||
|
write_short(f, len(s))
|
||||||
|
f.write(s)
|
||||||
|
|
||||||
|
|
||||||
def write_string(f, s):
|
def write_string(f, s):
|
||||||
if isinstance(s, six.text_type):
|
if isinstance(s, six.text_type):
|
||||||
s = s.encode('utf8')
|
s = s.encode('utf8')
|
||||||
@@ -969,6 +994,22 @@ def write_stringmap(f, strmap):
|
|||||||
write_string(f, v)
|
write_string(f, v)
|
||||||
|
|
||||||
|
|
||||||
|
def read_bytesmap(f):
|
||||||
|
numpairs = read_short(f)
|
||||||
|
bytesmap = {}
|
||||||
|
for _ in range(numpairs):
|
||||||
|
k = read_string(f)
|
||||||
|
bytesmap[k] = read_binary_string(f)
|
||||||
|
return bytesmap
|
||||||
|
|
||||||
|
|
||||||
|
def write_bytesmap(f, bytesmap):
|
||||||
|
write_short(f, len(bytesmap))
|
||||||
|
for k, v in bytesmap.items():
|
||||||
|
write_string(f, k)
|
||||||
|
write_binary_string(f, v)
|
||||||
|
|
||||||
|
|
||||||
def read_stringmultimap(f):
|
def read_stringmultimap(f):
|
||||||
numkeys = read_short(f)
|
numkeys = read_short(f)
|
||||||
strmmap = {}
|
strmmap = {}
|
||||||
|
|||||||
@@ -197,11 +197,25 @@ class Statement(object):
|
|||||||
.. versionadded:: 2.1.3
|
.. versionadded:: 2.1.3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
custom_payload = None
|
||||||
|
"""
|
||||||
|
TODO: refer to custom proto doc section
|
||||||
|
A string:binary_type dict holding custom key/value pairs to be passed
|
||||||
|
in the frame to a custom QueryHandler on the server side.
|
||||||
|
|
||||||
|
By default these values are ignored by the server.
|
||||||
|
|
||||||
|
These are only allowed when using protocol version 4 or higher.
|
||||||
|
|
||||||
|
.. versionadded:: 3.0.0
|
||||||
|
"""
|
||||||
|
|
||||||
_serial_consistency_level = None
|
_serial_consistency_level = None
|
||||||
_routing_key = None
|
_routing_key = None
|
||||||
|
|
||||||
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
|
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
|
||||||
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None):
|
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None,
|
||||||
|
custom_payload=None):
|
||||||
self.retry_policy = retry_policy
|
self.retry_policy = retry_policy
|
||||||
if consistency_level is not None:
|
if consistency_level is not None:
|
||||||
self.consistency_level = consistency_level
|
self.consistency_level = consistency_level
|
||||||
@@ -212,6 +226,8 @@ class Statement(object):
|
|||||||
self.fetch_size = fetch_size
|
self.fetch_size = fetch_size
|
||||||
if keyspace is not None:
|
if keyspace is not None:
|
||||||
self.keyspace = keyspace
|
self.keyspace = keyspace
|
||||||
|
if custom_payload is not None:
|
||||||
|
self.custom_payload = custom_payload
|
||||||
|
|
||||||
def _get_routing_key(self):
|
def _get_routing_key(self):
|
||||||
return self._routing_key
|
return self._routing_key
|
||||||
@@ -290,8 +306,7 @@ class Statement(object):
|
|||||||
|
|
||||||
class SimpleStatement(Statement):
|
class SimpleStatement(Statement):
|
||||||
"""
|
"""
|
||||||
A simple, un-prepared query. All attributes of :class:`Statement` apply
|
A simple, un-prepared query.
|
||||||
to this class as well.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, query_string, *args, **kwargs):
|
def __init__(self, query_string, *args, **kwargs):
|
||||||
@@ -299,6 +314,8 @@ class SimpleStatement(Statement):
|
|||||||
`query_string` should be a literal CQL statement with the exception
|
`query_string` should be a literal CQL statement with the exception
|
||||||
of parameter placeholders that will be filled through the
|
of parameter placeholders that will be filled through the
|
||||||
`parameters` argument of :meth:`.Session.execute()`.
|
`parameters` argument of :meth:`.Session.execute()`.
|
||||||
|
|
||||||
|
All arguments to :class:`Statement` apply to this class as well
|
||||||
"""
|
"""
|
||||||
Statement.__init__(self, *args, **kwargs)
|
Statement.__init__(self, *args, **kwargs)
|
||||||
self._query_string = query_string
|
self._query_string = query_string
|
||||||
@@ -338,6 +355,8 @@ class PreparedStatement(object):
|
|||||||
|
|
||||||
fetch_size = FETCH_SIZE_UNSET
|
fetch_size = FETCH_SIZE_UNSET
|
||||||
|
|
||||||
|
custom_payload = None
|
||||||
|
|
||||||
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
|
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
|
||||||
keyspace, protocol_version):
|
keyspace, protocol_version):
|
||||||
self.column_metadata = column_metadata
|
self.column_metadata = column_metadata
|
||||||
@@ -397,8 +416,6 @@ class BoundStatement(Statement):
|
|||||||
"""
|
"""
|
||||||
A prepared statement that has been bound to a particular set of values.
|
A prepared statement that has been bound to a particular set of values.
|
||||||
These may be created directly or through :meth:`.PreparedStatement.bind()`.
|
These may be created directly or through :meth:`.PreparedStatement.bind()`.
|
||||||
|
|
||||||
All attributes of :class:`Statement` apply to this class as well.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prepared_statement = None
|
prepared_statement = None
|
||||||
@@ -414,13 +431,15 @@ class BoundStatement(Statement):
|
|||||||
def __init__(self, prepared_statement, *args, **kwargs):
|
def __init__(self, prepared_statement, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
`prepared_statement` should be an instance of :class:`PreparedStatement`.
|
`prepared_statement` should be an instance of :class:`PreparedStatement`.
|
||||||
All other ``*args`` and ``**kwargs`` will be passed to :class:`.Statement`.
|
|
||||||
|
All arguments to :class:`Statement` apply to this class as well
|
||||||
"""
|
"""
|
||||||
self.prepared_statement = prepared_statement
|
self.prepared_statement = prepared_statement
|
||||||
|
|
||||||
self.consistency_level = prepared_statement.consistency_level
|
self.consistency_level = prepared_statement.consistency_level
|
||||||
self.serial_consistency_level = prepared_statement.serial_consistency_level
|
self.serial_consistency_level = prepared_statement.serial_consistency_level
|
||||||
self.fetch_size = prepared_statement.fetch_size
|
self.fetch_size = prepared_statement.fetch_size
|
||||||
|
self.custom_payload = prepared_statement.custom_payload
|
||||||
self.values = []
|
self.values = []
|
||||||
|
|
||||||
meta = prepared_statement.column_metadata
|
meta = prepared_statement.column_metadata
|
||||||
@@ -601,7 +620,8 @@ class BatchStatement(Statement):
|
|||||||
_session = None
|
_session = None
|
||||||
|
|
||||||
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
|
def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None,
|
||||||
consistency_level=None, serial_consistency_level=None, session=None):
|
consistency_level=None, serial_consistency_level=None,
|
||||||
|
session=None, custom_payload=None):
|
||||||
"""
|
"""
|
||||||
`batch_type` specifies The :class:`.BatchType` for the batch operation.
|
`batch_type` specifies The :class:`.BatchType` for the batch operation.
|
||||||
Defaults to :attr:`.BatchType.LOGGED`.
|
Defaults to :attr:`.BatchType.LOGGED`.
|
||||||
@@ -612,6 +632,10 @@ class BatchStatement(Statement):
|
|||||||
`consistency_level` should be a :class:`~.ConsistencyLevel` value
|
`consistency_level` should be a :class:`~.ConsistencyLevel` value
|
||||||
to be used for all operations in the batch.
|
to be used for all operations in the batch.
|
||||||
|
|
||||||
|
`custom_payload` is a key-value map TODO: refer to doc section
|
||||||
|
Note: as Statement objects are added to the batch, this map is
|
||||||
|
updated with values from their custom payloads.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@@ -637,12 +661,15 @@ class BatchStatement(Statement):
|
|||||||
|
|
||||||
.. versionchanged:: 2.1.0
|
.. versionchanged:: 2.1.0
|
||||||
Added `serial_consistency_level` as a parameter
|
Added `serial_consistency_level` as a parameter
|
||||||
|
|
||||||
|
.. versionchanged:: 3.0.0
|
||||||
|
Added `custom_payload` as a parameter
|
||||||
"""
|
"""
|
||||||
self.batch_type = batch_type
|
self.batch_type = batch_type
|
||||||
self._statements_and_parameters = []
|
self._statements_and_parameters = []
|
||||||
self._session = session
|
self._session = session
|
||||||
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
|
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
|
||||||
serial_consistency_level=serial_consistency_level)
|
serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
|
||||||
|
|
||||||
def add(self, statement, parameters=None):
|
def add(self, statement, parameters=None):
|
||||||
"""
|
"""
|
||||||
@@ -660,7 +687,7 @@ class BatchStatement(Statement):
|
|||||||
elif isinstance(statement, PreparedStatement):
|
elif isinstance(statement, PreparedStatement):
|
||||||
query_id = statement.query_id
|
query_id = statement.query_id
|
||||||
bound_statement = statement.bind(() if parameters is None else parameters)
|
bound_statement = statement.bind(() if parameters is None else parameters)
|
||||||
self._maybe_set_routing_attributes(bound_statement)
|
self._update_state(bound_statement)
|
||||||
self._statements_and_parameters.append(
|
self._statements_and_parameters.append(
|
||||||
(True, query_id, bound_statement.values))
|
(True, query_id, bound_statement.values))
|
||||||
elif isinstance(statement, BoundStatement):
|
elif isinstance(statement, BoundStatement):
|
||||||
@@ -668,7 +695,7 @@ class BatchStatement(Statement):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Parameters cannot be passed with a BoundStatement "
|
"Parameters cannot be passed with a BoundStatement "
|
||||||
"to BatchStatement.add()")
|
"to BatchStatement.add()")
|
||||||
self._maybe_set_routing_attributes(statement)
|
self._update_state(statement)
|
||||||
self._statements_and_parameters.append(
|
self._statements_and_parameters.append(
|
||||||
(True, statement.prepared_statement.query_id, statement.values))
|
(True, statement.prepared_statement.query_id, statement.values))
|
||||||
else:
|
else:
|
||||||
@@ -677,7 +704,7 @@ class BatchStatement(Statement):
|
|||||||
if parameters:
|
if parameters:
|
||||||
encoder = Encoder() if self._session is None else self._session.encoder
|
encoder = Encoder() if self._session is None else self._session.encoder
|
||||||
query_string = bind_params(query_string, parameters, encoder)
|
query_string = bind_params(query_string, parameters, encoder)
|
||||||
self._maybe_set_routing_attributes(statement)
|
self._update_state(statement)
|
||||||
self._statements_and_parameters.append((False, query_string, ()))
|
self._statements_and_parameters.append((False, query_string, ()))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -696,6 +723,16 @@ class BatchStatement(Statement):
|
|||||||
self.routing_key = statement.routing_key
|
self.routing_key = statement.routing_key
|
||||||
self.keyspace = statement.keyspace
|
self.keyspace = statement.keyspace
|
||||||
|
|
||||||
|
def _update_custom_payload(self, statement):
|
||||||
|
if statement.custom_payload:
|
||||||
|
if self.custom_payload is None:
|
||||||
|
self.custom_payload = {}
|
||||||
|
self.custom_payload.update(statement.custom_payload)
|
||||||
|
|
||||||
|
def _update_state(self, statement):
|
||||||
|
self._maybe_set_routing_attributes(statement)
|
||||||
|
self._update_custom_payload(statement)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
|
||||||
return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' %
|
return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' %
|
||||||
@@ -836,7 +873,7 @@ class QueryTrace(object):
|
|||||||
|
|
||||||
def _execute(self, query, parameters, time_spent, max_wait):
|
def _execute(self, query, parameters, time_spent, max_wait):
|
||||||
# in case the user switched the row factory, set it to namedtuple for this query
|
# in case the user switched the row factory, set it to namedtuple for this query
|
||||||
future = self._session._create_response_future(query, parameters, trace=False)
|
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None)
|
||||||
future.row_factory = named_tuple_factory
|
future.row_factory = named_tuple_factory
|
||||||
future.send_request()
|
future.send_request()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user