diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a064fa01..deb3cb1b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,8 @@ Bug Fixes Cassandra is a multiple of the read buffer size. Previously, if no more data became available to read on the socket, the message would never be processed, resulting in an OperationTimedOut error. +* Don't break tracing when a Session's row_factory is not the default + namedtuple_factory. Other ----- diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5b6d8602..f672b836 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1039,6 +1039,12 @@ class Session(object): ... log.exception("Operation failed:") """ + future = self._create_response_future(query, parameters, trace) + future.send_request() + return future + + def _create_response_future(self, query, parameters, trace): + """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None if isinstance(query, basestring): query = SimpleStatement(query) @@ -1060,11 +1066,9 @@ class Session(object): if trace: message.tracing = True - future = ResponseFuture( + return ResponseFuture( self, message, query, self.default_timeout, metrics=self._metrics, prepared_statement=prepared_statement) - future.send_request() - return future def prepare(self, query): """ diff --git a/cassandra/query.py b/cassandra/query.py index 698eb5d6..327650d6 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -8,10 +8,10 @@ from datetime import datetime, timedelta import struct import time -from cassandra import ConsistencyLevel +from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.cqltypes import unix_time_from_uuid1 from cassandra.decoder import (cql_encoders, cql_encode_object, - cql_encode_sequence) + cql_encode_sequence, named_tuple_factory) import logging log = logging.getLogger(__name__) @@ -409,9 +409,13 @@ class QueryTrace(object): attempt = 0 start = time.time() while True: - if max_wait is not None and time.time() - start >= max_wait: + time_spent = time.time() - start + if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) - session_results = self._session.execute(self._SELECT_SESSIONS_FORMAT, (self.trace_id,)) + + session_results = self._execute( + self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait) + if not session_results or session_results[0].duration is None: time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) attempt += 1 @@ -424,11 +428,25 @@ class QueryTrace(object): self.coordinator = session_row.coordinator self.parameters = session_row.parameters - event_results = self._session.execute(self._SELECT_EVENTS_FORMAT, (self.trace_id,)) + time_spent = time.time() - start + event_results = self._execute( + self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) break + def _execute(self, query, parameters, time_spent, max_wait): + # in case the user switched the row factory, set it to namedtuple for this query + future = self._session._create_response_future(query, parameters, trace=False) + future.row_factory = named_tuple_factory + future.send_request() + + timeout = (max_wait - time_spent) if max_wait is not None else None + try: + return future.result(timeout=timeout) + except OperationTimedOut: + raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + def __str__(self): return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ % (self.request_type, self.trace_id, self.coordinator, self.started_at, diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index b6aad8a7..36e69f8b 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -5,6 +5,7 @@ except ImportError: from cassandra.query import PreparedStatement, BoundStatement, ValueSequence, SimpleStatement from cassandra.cluster import Cluster +from cassandra.decoder import dict_factory class QueryTest(unittest.TestCase): @@ -49,6 +50,20 @@ class QueryTest(unittest.TestCase): for event in statement.trace.events: str(event) + def test_trace_ignores_row_factory(self): + cluster = Cluster() + session = cluster.connect() + session.row_factory = dict_factory + + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + session.execute(statement, trace=True) + + # Ensure this does not throw an exception + str(statement.trace) + for event in statement.trace.events: + str(event) + class PreparedStatementTests(unittest.TestCase):