Fix disabling paging on a per-statement basis

Relates to #151
Fixes PYTHON-93
This commit is contained in:
Tyler Hobbs
2014-07-17 15:57:52 -05:00
parent 2d5754d110
commit d8973e7eca
4 changed files with 93 additions and 8 deletions

View File

@@ -6,6 +6,9 @@ Bug Fixes
--------- ---------
* Properly specify UDTs for columns in CREATE TABLE statements * Properly specify UDTs for columns in CREATE TABLE statements
* Avoid moving retries to a new host when using request ID zero (PYTHON-88) * Avoid moving retries to a new host when using request ID zero (PYTHON-88)
* Don't ignore fetch_size arguments to Statement constructors (github-151)
* Allow disabling automatic paging on a per-statement basis when it's
enabled by default for the session (PYTHON-93)
Other Other
----- -----

View File

@@ -65,7 +65,7 @@ from cassandra.pool import (_ReconnectionHandler, _HostReconnectionHandler,
NoConnectionsAvailable) NoConnectionsAvailable)
from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
BatchStatement, bind_params, QueryTrace, Statement, BatchStatement, bind_params, QueryTrace, Statement,
named_tuple_factory, dict_factory) named_tuple_factory, dict_factory, FETCH_SIZE_UNSET)
# default to gevent when we are monkey patched, otherwise if libev is available, use that as the # default to gevent when we are monkey patched, otherwise if libev is available, use that as the
# default because it's fastest. Otherwise, use asyncore. # default because it's fastest. Otherwise, use asyncore.
@@ -1246,7 +1246,7 @@ class Session(object):
cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level
fetch_size = query.fetch_size fetch_size = query.fetch_size
if not fetch_size and self._protocol_version >= 2: if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2:
fetch_size = self.default_fetch_size fetch_size = self.default_fetch_size
if self._protocol_version >= 3 and self.use_client_timestamp: if self._protocol_version >= 3 and self.use_client_timestamp:

View File

@@ -133,6 +133,9 @@ def ordered_dict_factory(colnames, rows):
return [OrderedDict(zip(colnames, row)) for row in rows] return [OrderedDict(zip(colnames, row)) for row in rows]
FETCH_SIZE_UNSET = object()
class Statement(object): class Statement(object):
""" """
An abstract class representing a single query. There are three subclasses: An abstract class representing a single query. There are three subclasses:
@@ -160,7 +163,7 @@ class Statement(object):
the Session this is executed in will be used. the Session this is executed in will be used.
""" """
fetch_size = None fetch_size = FETCH_SIZE_UNSET
""" """
How many rows will be fetched at a time. This overrides the default How many rows will be fetched at a time. This overrides the default
of :attr:`.Session.default_fetch_size` of :attr:`.Session.default_fetch_size`
@@ -175,13 +178,13 @@ class Statement(object):
_routing_key = None _routing_key = None
def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=None): serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET):
self.retry_policy = retry_policy self.retry_policy = retry_policy
if consistency_level is not None: if consistency_level is not None:
self.consistency_level = consistency_level self.consistency_level = consistency_level
if serial_consistency_level is not None: if serial_consistency_level is not None:
self.serial_consistency_level = serial_consistency_level self.serial_consistency_level = serial_consistency_level
if fetch_size is not None: if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size self.fetch_size = fetch_size
self._routing_key = routing_key self._routing_key = routing_key
@@ -315,9 +318,11 @@ class PreparedStatement(object):
_protocol_version = None _protocol_version = None
fetch_size = FETCH_SIZE_UNSET
def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace,
protocol_version, consistency_level=None, serial_consistency_level=None, protocol_version, consistency_level=None, serial_consistency_level=None,
fetch_size=None): fetch_size=FETCH_SIZE_UNSET):
self.column_metadata = column_metadata self.column_metadata = column_metadata
self.query_id = query_id self.query_id = query_id
self.routing_key_indexes = routing_key_indexes self.routing_key_indexes = routing_key_indexes
@@ -326,7 +331,8 @@ class PreparedStatement(object):
self._protocol_version = protocol_version self._protocol_version = protocol_version
self.consistency_level = consistency_level self.consistency_level = consistency_level
self.serial_consistency_level = serial_consistency_level self.serial_consistency_level = serial_consistency_level
self.fetch_size = fetch_size if fetch_size is not FETCH_SIZE_UNSET:
self.fetch_size = fetch_size
@classmethod @classmethod
def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version): def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version):

View File

@@ -26,7 +26,7 @@ from itertools import cycle, count
from six.moves import range from six.moves import range
from threading import Event from threading import Event
from cassandra.cluster import Cluster from cassandra.cluster import Cluster, PagedResult
from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args
from cassandra.policies import HostDistance from cassandra.policies import HostDistance
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
@@ -281,3 +281,79 @@ class QueryPagingTests(unittest.TestCase):
for (success, result) in results: for (success, result) in results:
self.assertTrue(success) self.assertTrue(success)
self.assertEquals(100, len(list(result))) self.assertEquals(100, len(list(result)))
def test_fetch_size(self):
"""
Ensure per-statement fetch_sizes override the default fetch size.
"""
statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]),
[(i, ) for i in range(100)])
execute_concurrent(self.session, list(statements_and_params))
prepared = self.session.prepare("SELECT * FROM test3rf.test")
self.session.default_fetch_size = 10
result = self.session.execute(prepared, [])
self.assertIsInstance(result, PagedResult)
self.session.default_fetch_size = 2000
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
self.session.default_fetch_size = None
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
self.session.default_fetch_size = 10
prepared.fetch_size = 2000
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
prepared.fetch_size = None
result = self.session.execute(prepared, [])
self.assertIsInstance(result, list)
prepared.fetch_size = 10
result = self.session.execute(prepared, [])
self.assertIsInstance(result, PagedResult)
prepared.fetch_size = 2000
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
prepared.fetch_size = None
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
prepared.fetch_size = 10
bound = prepared.bind([])
result = self.session.execute(bound, [])
self.assertIsInstance(result, PagedResult)
bound.fetch_size = 2000
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
bound.fetch_size = None
result = self.session.execute(bound, [])
self.assertIsInstance(result, list)
bound.fetch_size = 10
result = self.session.execute(bound, [])
self.assertIsInstance(result, PagedResult)
s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None)
result = self.session.execute(s, [])
self.assertIsInstance(result, list)
s = SimpleStatement("SELECT * FROM test3rf.test")
result = self.session.execute(s, [])
self.assertIsInstance(result, PagedResult)
s = SimpleStatement("SELECT * FROM test3rf.test")
s.fetch_size = None
result = self.session.execute(s, [])
self.assertIsInstance(result, list)