diff --git a/CHANGELOG.rst b/CHANGELOG.rst index af2d6113..8f914168 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -78,7 +78,6 @@ Bug Fixes --------- * Make execute_concurrent compatible with Python 2.6 (PYTHON-159) * Handle Unauthorized message on schema_triggers query (PYTHON-155) -* Make execute_concurrent compatible with Python 2.6 (github-197) * Pure Python sorted set in support of UDTs nested in collections (PYTON-167) * Support CUSTOM index metadata and string export (PYTHON-165) diff --git a/cassandra/auth.py b/cassandra/auth.py index 67d302a9..508bd150 100644 --- a/cassandra/auth.py +++ b/cassandra/auth.py @@ -15,6 +15,7 @@ try: except ImportError: SASLClient = None + class AuthProvider(object): """ An abstract class that defines the interface that will be used for @@ -157,6 +158,7 @@ class SaslAuthProvider(AuthProvider): def new_authenticator(self, host): return SaslAuthenticator(**self.sasl_kwargs) + class SaslAuthenticator(Authenticator): """ A pass-through :class:`~.Authenticator` using the third party package diff --git a/cassandra/cluster.py b/cassandra/cluster.py index cb04152a..e84afd2e 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1361,7 +1361,7 @@ class Session(object): for future in futures: future.result() - def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False): + def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None): """ Execute the given query and synchronously wait for the response. @@ -1389,6 +1389,10 @@ class Session(object): instance and not just a string. If there is an error fetching the trace details, the :attr:`~.Statement.trace` attribute will be left as :const:`None`. + + `custom_payload` is a dict as described in TODO section. If `query` is a Statement + with its own custom_payload. the message will be a union of the two, + with the values specified here taking precedence. """ if timeout is _NOT_SET: timeout = self.default_timeout @@ -1398,7 +1402,7 @@ class Session(object): "The query argument must be an instance of a subclass of " "cassandra.query.Statement when trace=True") - future = self.execute_async(query, parameters, trace) + future = self.execute_async(query, parameters, trace, custom_payload) try: result = future.result(timeout) finally: @@ -1410,7 +1414,7 @@ class Session(object): return result - def execute_async(self, query, parameters=None, trace=False): + def execute_async(self, query, parameters=None, trace=False, custom_payload=None): """ Execute the given query and return a :class:`~.ResponseFuture` object which callbacks may be attached to for asynchronous response @@ -1422,6 +1426,13 @@ class Session(object): :meth:`.ResponseFuture.get_query_trace()` after the request completes to retrieve a :class:`.QueryTrace` instance. + `custom_payload` is a dict as described in TODO section. If `query` is + a Statement with a custom_payload specified. the message will be a + union of the two, with the values specified here taking precedence. + + If the server sends a custom payload in the response message, + the dict can be obtained via :attr:`.ResponseFuture.custom_payload` + Example usage:: >>> session = cluster.connect() @@ -1447,11 +1458,11 @@ class Session(object): ... log.exception("Operation failed:") """ - future = self._create_response_future(query, parameters, trace) + future = self._create_response_future(query, parameters, trace, custom_payload) future.send_request() return future - def _create_response_future(self, query, parameters, trace): + def _create_response_future(self, query, parameters, trace, custom_payload): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None @@ -1501,13 +1512,16 @@ class Session(object): if trace: message.tracing = True + message.update_custom_payload(query.custom_payload) + message.update_custom_payload(custom_payload) + return ResponseFuture( self, message, query, self.default_timeout, metrics=self._metrics, prepared_statement=prepared_statement) - def prepare(self, query): + def prepare(self, query, custom_payload=None): """ - Prepares a query string, returing a :class:`~cassandra.query.PreparedStatement` + Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement` instance which can be used as follows:: >>> session = cluster.connect("mykeyspace") @@ -1530,8 +1544,12 @@ class Session(object): **Important**: PreparedStatements should be prepared only once. Preparing the same query more than once will likely affect performance. + + `custom_payload` is a key value map to be passed along with the prepare + message. See TODO: refer to doc section """ message = PrepareMessage(query=query) + message.custom_payload = custom_payload future = ResponseFuture(self, message, query=None) try: future.send_request() @@ -1543,6 +1561,7 @@ class Session(object): prepared_statement = PreparedStatement.from_message( query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, self._protocol_version) + prepared_statement.custom_payload = future.custom_payload host = future._current_host try: @@ -2567,6 +2586,7 @@ class ResponseFuture(object): _start_time = None _metrics = None _paging_state = None + _custom_payload = None def __init__(self, session, message, query, default_timeout=None, metrics=None, prepared_statement=None): self.session = session @@ -2654,6 +2674,12 @@ class ResponseFuture(object): """ return self._paging_state is not None + @property + def custom_payload(self): + if not self._event.is_set(): + raise Exception("custom_payload cannot be retrieved before ResponseFuture is finalized") + return self._custom_payload + def start_fetching_next_page(self): """ If there are more pages left in the query result, this asynchronously @@ -2690,6 +2716,8 @@ class ResponseFuture(object): if trace_id: self._query_trace = QueryTrace(trace_id, self.session) + self._custom_payload = getattr(response, 'custom_payload', None) + if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index d81a6507..a48af348 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -54,6 +54,7 @@ HEADER_DIRECTION_MASK = 0x80 COMPRESSED_FLAG = 0x01 TRACING_FLAG = 0x02 +CUSTOM_PAYLOAD_FLAG = 0x04 _message_types_by_name = {} _message_types_by_opcode = {} @@ -70,13 +71,19 @@ class _RegisterMessageType(type): class _MessageType(object): tracing = False + custom_payload = None def to_binary(self, stream_id, protocol_version, compression=None): + flags = 0 body = io.BytesIO() + if self.custom_payload: + if protocol_version < 4: + raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") + flags |= CUSTOM_PAYLOAD_FLAG + write_bytesmap(body, self.custom_payload) self.send_body(body, protocol_version) body = body.getvalue() - flags = 0 if compression and len(body) > 0: body = compression(body) flags |= COMPRESSED_FLAG @@ -89,6 +96,12 @@ class _MessageType(object): return msg.getvalue() + def update_custom_payload(self, other): + if other: + if not self.custom_payload: + self.custom_payload = {} + self.custom_payload.update(other) + def __repr__(self): return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self))) @@ -116,6 +129,12 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b else: trace_id = None + if flags & CUSTOM_PAYLOAD_FLAG: + custom_payload = read_bytesmap(body) + flags ^= CUSTOM_PAYLOAD_FLAG + else: + custom_payload = None + if flags: log.warning("Unknown protocol flags set: %02x. May cause problems.", flags) @@ -123,6 +142,7 @@ def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, b msg = msg_class.recv_body(body, protocol_version, user_type_map) msg.stream_id = stream_id msg.trace_id = trace_id + msg.custom_payload = custom_payload return msg @@ -918,6 +938,11 @@ def read_binary_string(f): return contents +def write_binary_string(f, s): + write_short(f, len(s)) + f.write(s) + + def write_string(f, s): if isinstance(s, six.text_type): s = s.encode('utf8') @@ -969,6 +994,22 @@ def write_stringmap(f, strmap): write_string(f, v) +def read_bytesmap(f): + numpairs = read_short(f) + bytesmap = {} + for _ in range(numpairs): + k = read_string(f) + bytesmap[k] = read_binary_string(f) + return bytesmap + + +def write_bytesmap(f, bytesmap): + write_short(f, len(bytesmap)) + for k, v in bytesmap.items(): + write_string(f, k) + write_binary_string(f, v) + + def read_stringmultimap(f): numkeys = read_short(f) strmmap = {} diff --git a/cassandra/query.py b/cassandra/query.py index 031a1684..70d16e9e 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -197,11 +197,25 @@ class Statement(object): .. versionadded:: 2.1.3 """ + custom_payload = None + """ + TODO: refer to custom proto doc section + A string:binary_type dict holding custom key/value pairs to be passed + in the frame to a custom QueryHandler on the server side. + + By default these values are ignored by the server. + + These are only allowed when using protocol version 4 or higher. + + .. versionadded:: 3.0.0 + """ + _serial_consistency_level = None _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None): + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, + custom_payload=None): self.retry_policy = retry_policy if consistency_level is not None: self.consistency_level = consistency_level @@ -212,6 +226,8 @@ class Statement(object): self.fetch_size = fetch_size if keyspace is not None: self.keyspace = keyspace + if custom_payload is not None: + self.custom_payload = custom_payload def _get_routing_key(self): return self._routing_key @@ -290,8 +306,7 @@ class Statement(object): class SimpleStatement(Statement): """ - A simple, un-prepared query. All attributes of :class:`Statement` apply - to this class as well. + A simple, un-prepared query. """ def __init__(self, query_string, *args, **kwargs): @@ -299,6 +314,8 @@ class SimpleStatement(Statement): `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the `parameters` argument of :meth:`.Session.execute()`. + + All arguments to :class:`Statement` apply to this class as well """ Statement.__init__(self, *args, **kwargs) self._query_string = query_string @@ -338,6 +355,8 @@ class PreparedStatement(object): fetch_size = FETCH_SIZE_UNSET + custom_payload = None + def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, protocol_version): self.column_metadata = column_metadata @@ -397,8 +416,6 @@ class BoundStatement(Statement): """ A prepared statement that has been bound to a particular set of values. These may be created directly or through :meth:`.PreparedStatement.bind()`. - - All attributes of :class:`Statement` apply to this class as well. """ prepared_statement = None @@ -414,13 +431,15 @@ class BoundStatement(Statement): def __init__(self, prepared_statement, *args, **kwargs): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. - All other ``*args`` and ``**kwargs`` will be passed to :class:`.Statement`. + + All arguments to :class:`Statement` apply to this class as well """ self.prepared_statement = prepared_statement self.consistency_level = prepared_statement.consistency_level self.serial_consistency_level = prepared_statement.serial_consistency_level self.fetch_size = prepared_statement.fetch_size + self.custom_payload = prepared_statement.custom_payload self.values = [] meta = prepared_statement.column_metadata @@ -601,7 +620,8 @@ class BatchStatement(Statement): _session = None def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, - consistency_level=None, serial_consistency_level=None, session=None): + consistency_level=None, serial_consistency_level=None, + session=None, custom_payload=None): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. @@ -612,6 +632,10 @@ class BatchStatement(Statement): `consistency_level` should be a :class:`~.ConsistencyLevel` value to be used for all operations in the batch. + `custom_payload` is a key-value map TODO: refer to doc section + Note: as Statement objects are added to the batch, this map is + updated with values from their custom payloads. + Example usage: .. code-block:: python @@ -637,12 +661,15 @@ class BatchStatement(Statement): .. versionchanged:: 2.1.0 Added `serial_consistency_level` as a parameter + + .. versionchanged:: 3.0.0 + Added `custom_payload` as a parameter """ self.batch_type = batch_type self._statements_and_parameters = [] self._session = session Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, - serial_consistency_level=serial_consistency_level) + serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) def add(self, statement, parameters=None): """ @@ -660,7 +687,7 @@ class BatchStatement(Statement): elif isinstance(statement, PreparedStatement): query_id = statement.query_id bound_statement = statement.bind(() if parameters is None else parameters) - self._maybe_set_routing_attributes(bound_statement) + self._update_state(bound_statement) self._statements_and_parameters.append( (True, query_id, bound_statement.values)) elif isinstance(statement, BoundStatement): @@ -668,7 +695,7 @@ class BatchStatement(Statement): raise ValueError( "Parameters cannot be passed with a BoundStatement " "to BatchStatement.add()") - self._maybe_set_routing_attributes(statement) + self._update_state(statement) self._statements_and_parameters.append( (True, statement.prepared_statement.query_id, statement.values)) else: @@ -677,7 +704,7 @@ class BatchStatement(Statement): if parameters: encoder = Encoder() if self._session is None else self._session.encoder query_string = bind_params(query_string, parameters, encoder) - self._maybe_set_routing_attributes(statement) + self._update_state(statement) self._statements_and_parameters.append((False, query_string, ())) return self @@ -696,6 +723,16 @@ class BatchStatement(Statement): self.routing_key = statement.routing_key self.keyspace = statement.keyspace + def _update_custom_payload(self, statement): + if statement.custom_payload: + if self.custom_payload is None: + self.custom_payload = {} + self.custom_payload.update(statement.custom_payload) + + def _update_state(self, statement): + self._maybe_set_routing_attributes(statement) + self._update_custom_payload(statement) + def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % @@ -836,7 +873,7 @@ class QueryTrace(object): def _execute(self, query, parameters, time_spent, max_wait): # in case the user switched the row factory, set it to namedtuple for this query - future = self._session._create_response_future(query, parameters, trace=False) + future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None) future.row_factory = named_tuple_factory future.send_request()