Fix disabling paging on a per-statement basis

Relates to #151
Fixes PYTHON-93
This commit is contained in:
Tyler Hobbs
2014-07-17 15:57:52 -05:00
parent 2d5754d110
commit d8973e7eca
4 changed files with 93 additions and 8 deletions

View File

@@ -6,6 +6,9 @@ Bug Fixes
---------
* Properly specify UDTs for columns in CREATE TABLE statements
* Avoid moving retries to a new host when using request ID zero (PYTHON-88)
* Don't ignore fetch_size arguments to Statement constructors (github-151)
* Allow disabling automatic paging on a per-statement basis when it's
enabled by default for the session (PYTHON-93)
Other
-----

View File

@@ -65,7 +65,7 @@ from cassandra.pool import (_ReconnectionHandler, _HostReconnectionHandler,
NoConnectionsAvailable)
from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, Statement,
named_tuple_factory, dict_factory)
named_tuple_factory, dict_factory, FETCH_SIZE_UNSET)
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's fastest. Otherwise, use asyncore.
@@ -1246,7 +1246,7 @@ class Session(object):
cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level
fetch_size = query.fetch_size
if not fetch_size and self._protocol_version >= 2:
if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2:
fetch_size = self.default_fetch_size
if self._protocol_version >= 3 and self.use_client_timestamp:

View File

@@ -133,6 +133,9 @@ def ordered_dict_factory(colnames, rows):
return [OrderedDict(zip(colnames, row)) for row in rows]
FETCH_SIZE_UNSET = object()
class Statement(object):
"""
An abstract class representing a single query. There are three subclasses:
@@ -160,7 +163,7 @@ class Statement(object):
the Session this is executed in will be used.
"""
fetch_size = None
fetch_size = FETCH_SIZE_UNSET
"""
How many rows will be fetched at a time. This overrides the default
of :attr:`.Session.default_fetch_size`
@@ -175,13 +178,13 @@ class Statement(object):
_routing_key = None
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=None):
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET):
self.retry_policy = retry_policy
if consistency_level is not None:
self.consistency_level = consistency_level
if serial_consistency_level is not None:
self.serial_consistency_level = serial_consistency_level
if fetch_size is not None:
if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size
self._routing_key = routing_key
@@ -315,9 +318,11 @@ class PreparedStatement(object):
_protocol_version = None
fetch_size = FETCH_SIZE_UNSET
def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
protocol_version, consistency_level=None, serial_consistency_level=None,
fetch_size=None):
fetch_size=FETCH_SIZE_UNSET):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
@@ -326,7 +331,8 @@ class PreparedStatement(object):
self._protocol_version = protocol_version
self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level
self.fetch_size = fetch_size
if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size
@classmethod
def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version):

View File

@@ -26,7 +26,7 @@ from itertools import cycle, count
from six.moves import range
from threading import Event
from cassandra.cluster import Cluster
from cassandra.cluster import Cluster, PagedResult
from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args
from cassandra.policies import HostDistance
from cassandra.query import SimpleStatement
@@ -281,3 +281,79 @@ class QueryPagingTests(unittest.TestCase):
for (success, result) in results:
self.assertTrue(success)
self.assertEquals(100, len(list(result)))
def test_fetch_size(self):
"""
Ensure per-statement fetch_sizes override the default fetch size.
"""
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test")
self.session.default_fetch_size = 10
result = self.session.execute(prepared, [])
self.assertIsInstance(result, PagedResult)
self.session.default_fetch_size = 2000
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
self.session.default_fetch_size = None
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
self.session.default_fetch_size = 10
prepared.fetch_size = 2000
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
prepared.fetch_size = None
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
prepared.fetch_size = 10
result = self.session.execute(prepared, [])
self.assertIsInstance(result, PagedResult)
prepared.fetch_size = 2000
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
prepared.fetch_size = None
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
prepared.fetch_size = 10
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, PagedResult)
bound.fetch_size = 2000
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
bound.fetch_size = None
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
bound.fetch_size = 10
result = self.session.execute(bound, [])
self.assertIsInstance(result, PagedResult)
s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None)
result = self.session.execute(s, [])
self.assertIsInstance(result, list)
s = SimpleStatement("SELECT * FROM test3rf.test")
result = self.session.execute(s, [])
self.assertIsInstance(result, PagedResult)
s = SimpleStatement("SELECT * FROM test3rf.test")
s.fetch_size = None
result = self.session.execute(s, [])
self.assertIsInstance(result, list)