Fix tracing with non-default row_factory

For PYTHON-49
This commit is contained in:
Tyler Hobbs
2014-02-18 17:10:25 -06:00
parent 4698155981
commit 95ced181a1
4 changed files with 47 additions and 8 deletions

View File

@@ -13,6 +13,8 @@ Bug Fixes
Cassandra is a multiple of the read buffer size. Previously, if no more data
became available to read on the socket, the message would never be processed,
resulting in an OperationTimedOut error.
* Don't break tracing when a Session's row_factory is not the default
namedtuple_factory.
Other
-----

View File

@@ -1039,6 +1039,12 @@ class Session(object):
... log.exception("Operation failed:")
"""
future = self._create_response_future(query, parameters, trace)
future.send_request()
return future
def _create_response_future(self, query, parameters, trace):
""" Returns the ResponseFuture before calling send_request() on it """
prepared_statement = None
if isinstance(query, basestring):
query = SimpleStatement(query)
@@ -1060,11 +1066,9 @@ class Session(object):
if trace:
message.tracing = True
future = ResponseFuture(
return ResponseFuture(
self, message, query, self.default_timeout, metrics=self._metrics,
prepared_statement=prepared_statement)
future.send_request()
return future
def prepare(self, query):
"""

View File

@@ -8,10 +8,10 @@ from datetime import datetime, timedelta
import struct
import time
from cassandra import ConsistencyLevel
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cqltypes import unix_time_from_uuid1
from cassandra.decoder import (cql_encoders, cql_encode_object,
cql_encode_sequence)
cql_encode_sequence, named_tuple_factory)
import logging
log = logging.getLogger(__name__)
@@ -409,9 +409,13 @@ class QueryTrace(object):
attempt = 0
start = time.time()
while True:
if max_wait is not None and time.time() - start >= max_wait:
time_spent = time.time() - start
if max_wait is not None and time_spent >= max_wait:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
session_results = self._session.execute(self._SELECT_SESSIONS_FORMAT, (self.trace_id,))
session_results = self._execute(
self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait)
if not session_results or session_results[0].duration is None:
time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt))
attempt += 1
@@ -424,11 +428,25 @@ class QueryTrace(object):
self.coordinator = session_row.coordinator
self.parameters = session_row.parameters
event_results = self._session.execute(self._SELECT_EVENTS_FORMAT, (self.trace_id,))
time_spent = time.time() - start
event_results = self._execute(
self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait)
self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread)
for r in event_results)
break
def _execute(self, query, parameters, time_spent, max_wait):
# in case the user switched the row factory, set it to namedtuple for this query
future = self._session._create_response_future(query, parameters, trace=False)
future.row_factory = named_tuple_factory
future.send_request()
timeout = (max_wait - time_spent) if max_wait is not None else None
try:
return future.result(timeout=timeout)
except OperationTimedOut:
raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,))
def __str__(self):
return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \
% (self.request_type, self.trace_id, self.coordinator, self.started_at,

View File

@@ -5,6 +5,7 @@ except ImportError:
from cassandra.query import PreparedStatement, BoundStatement, ValueSequence, SimpleStatement
from cassandra.cluster import Cluster
from cassandra.decoder import dict_factory
class QueryTest(unittest.TestCase):
@@ -49,6 +50,20 @@ class QueryTest(unittest.TestCase):
for event in statement.trace.events:
str(event)
def test_trace_ignores_row_factory(self):
cluster = Cluster()
session = cluster.connect()
session.row_factory = dict_factory
query = "SELECT * FROM system.local"
statement = SimpleStatement(query)
session.execute(statement, trace=True)
# Ensure this does not throw an exception
str(statement.trace)
for event in statement.trace.events:
str(event)
class PreparedStatementTests(unittest.TestCase):