diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c8a0b24a..32083041 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,9 @@ Bug Fixes --------- * Properly specify UDTs for columns in CREATE TABLE statements * Avoid moving retries to a new host when using request ID zero (PYTHON-88) +* Don't ignore fetch_size arguments to Statement constructors (github-151) +* Allow disabling automatic paging on a per-statement basis when it's + enabled by default for the session (PYTHON-93) Other ----- diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 08e61deb..5747fa2f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -65,7 +65,7 @@ from cassandra.pool import (_ReconnectionHandler, _HostReconnectionHandler, NoConnectionsAvailable) from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, BatchStatement, bind_params, QueryTrace, Statement, - named_tuple_factory, dict_factory) + named_tuple_factory, dict_factory, FETCH_SIZE_UNSET) # default to gevent when we are monkey patched, otherwise if libev is available, use that as the # default because it's fastest. Otherwise, use asyncore. @@ -1246,7 +1246,7 @@ class Session(object): cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level fetch_size = query.fetch_size - if not fetch_size and self._protocol_version >= 2: + if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2: fetch_size = self.default_fetch_size if self._protocol_version >= 3 and self.use_client_timestamp: diff --git a/cassandra/query.py b/cassandra/query.py index ecdad388..a79a2c26 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -133,6 +133,9 @@ def ordered_dict_factory(colnames, rows): return [OrderedDict(zip(colnames, row)) for row in rows] +FETCH_SIZE_UNSET = object() + + class Statement(object): """ An abstract class representing a single query. There are three subclasses: @@ -160,7 +163,7 @@ class Statement(object): the Session this is executed in will be used. """ - fetch_size = None + fetch_size = FETCH_SIZE_UNSET """ How many rows will be fetched at a time. This overrides the default of :attr:`.Session.default_fetch_size` @@ -175,13 +178,13 @@ class Statement(object): _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=None): + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET): self.retry_policy = retry_policy if consistency_level is not None: self.consistency_level = consistency_level if serial_consistency_level is not None: self.serial_consistency_level = serial_consistency_level - if fetch_size is not None: + if fetch_size is not FETCH_SIZE_UNSET: self.fetch_size = fetch_size self._routing_key = routing_key @@ -315,9 +318,11 @@ class PreparedStatement(object): _protocol_version = None + fetch_size = FETCH_SIZE_UNSET + def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, protocol_version, consistency_level=None, serial_consistency_level=None, - fetch_size=None): + fetch_size=FETCH_SIZE_UNSET): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes @@ -326,7 +331,8 @@ class PreparedStatement(object): self._protocol_version = protocol_version self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level - self.fetch_size = fetch_size + if fetch_size is not FETCH_SIZE_UNSET: + self.fetch_size = fetch_size @classmethod def from_message(cls, query_id, column_metadata, cluster_metadata, query, keyspace, protocol_version): diff --git a/tests/integration/standard/test_query_paging.py b/tests/integration/standard/test_query_paging.py index 09438abc..448daf0c 100644 --- a/tests/integration/standard/test_query_paging.py +++ b/tests/integration/standard/test_query_paging.py @@ -26,7 +26,7 @@ from itertools import cycle, count from six.moves import range from threading import Event -from cassandra.cluster import Cluster +from cassandra.cluster import Cluster, PagedResult from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args from cassandra.policies import HostDistance from cassandra.query import SimpleStatement @@ -281,3 +281,79 @@ class QueryPagingTests(unittest.TestCase): for (success, result) in results: self.assertTrue(success) self.assertEquals(100, len(list(result))) + + def test_fetch_size(self): + """ + Ensure per-statement fetch_sizes override the default fetch size. + """ + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + prepared = self.session.prepare("SELECT * FROM test3rf.test") + + self.session.default_fetch_size = 10 + result = self.session.execute(prepared, []) + self.assertIsInstance(result, PagedResult) + + self.session.default_fetch_size = 2000 + result = self.session.execute(prepared, []) + self.assertIsInstance(result, list) + + self.session.default_fetch_size = None + result = self.session.execute(prepared, []) + self.assertIsInstance(result, list) + + self.session.default_fetch_size = 10 + + prepared.fetch_size = 2000 + result = self.session.execute(prepared, []) + self.assertIsInstance(result, list) + + prepared.fetch_size = None + result = self.session.execute(prepared, []) + self.assertIsInstance(result, list) + + prepared.fetch_size = 10 + result = self.session.execute(prepared, []) + self.assertIsInstance(result, PagedResult) + + prepared.fetch_size = 2000 + bound = prepared.bind([]) + result = self.session.execute(bound, []) + self.assertIsInstance(result, list) + + prepared.fetch_size = None + bound = prepared.bind([]) + result = self.session.execute(bound, []) + self.assertIsInstance(result, list) + + prepared.fetch_size = 10 + bound = prepared.bind([]) + result = self.session.execute(bound, []) + self.assertIsInstance(result, PagedResult) + + bound.fetch_size = 2000 + result = self.session.execute(bound, []) + self.assertIsInstance(result, list) + + bound.fetch_size = None + result = self.session.execute(bound, []) + self.assertIsInstance(result, list) + + bound.fetch_size = 10 + result = self.session.execute(bound, []) + self.assertIsInstance(result, PagedResult) + + s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None) + result = self.session.execute(s, []) + self.assertIsInstance(result, list) + + s = SimpleStatement("SELECT * FROM test3rf.test") + result = self.session.execute(s, []) + self.assertIsInstance(result, PagedResult) + + s = SimpleStatement("SELECT * FROM test3rf.test") + s.fetch_size = None + result = self.session.execute(s, []) + self.assertIsInstance(result, list)