114 lines
4.3 KiB
Python
114 lines
4.3 KiB
Python
from tests.integration import PROTOCOL_VERSION
|
|
|
|
import logging
|
|
log = logging.getLogger(__name__)
|
|
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
import unittest # noqa
|
|
|
|
from itertools import cycle, count
|
|
from six.moves import range
|
|
from threading import Event
|
|
|
|
from cassandra.cluster import Cluster
|
|
from cassandra.concurrent import execute_concurrent
|
|
from cassandra.policies import HostDistance
|
|
from cassandra.query import SimpleStatement
|
|
|
|
|
|
class QueryPagingTests(unittest.TestCase):
    """Integration tests for result paging against the ``test3rf.test`` table.

    Each test inserts ``_ROW_COUNT`` rows and then reads them all back with a
    range of ``default_fetch_size`` values, using a plain query string, a
    ``SimpleStatement`` and a prepared statement.
    """

    # Number of rows written before each read-back.
    _ROW_COUNT = 100

    # Page sizes exercised by every test: smaller than, equal to, and larger
    # than the number of rows in the table.
    _FETCH_SIZES = (2, 3, 7, 10, 99, 100, 101, 10000)

    _INSERT = "INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"
    _SELECT = "SELECT * FROM test3rf.test"

    def setUp(self):
        """Connect to the cluster and empty the test table.

        Raises ``unittest.SkipTest`` when the protocol version under test is
        lower than 2.
        """
        if PROTOCOL_VERSION < 2:
            raise unittest.SkipTest(
                "Protocol 2.0+ is required for paging, currently testing against %r"
                % (PROTOCOL_VERSION,))

        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1)
        self.session = self.cluster.connect()
        self.session.execute("TRUNCATE test3rf.test")

    def tearDown(self):
        """Shut down the cluster opened in ``setUp`` so connections do not leak."""
        self.cluster.shutdown()

    def _insert_rows(self):
        """Insert ``_ROW_COUNT`` rows with keys ``0 .. _ROW_COUNT - 1``."""
        statements_and_params = [(self._INSERT, (key,))
                                 for key in range(self._ROW_COUNT)]
        execute_concurrent(self.session, statements_and_params)

    def _select_variants(self, prepared):
        """Return the select query as a string, a SimpleStatement and ``prepared``.

        The ``SimpleStatement`` is built on every call, i.e. after the caller
        has set ``default_fetch_size`` for the current iteration.
        """
        return (self._SELECT, SimpleStatement(self._SELECT), prepared)

    def _count_rows_via_callbacks(self, future):
        """Consume every page of ``future`` through callbacks; return the row count.

        An error reported through the errback is recorded and re-raised here,
        in the calling thread, via ``self.fail`` so the test runner sees it.
        """
        finished = Event()
        counter = count()
        errors = []

        def handle_page(rows, future, counter):
            for _ in rows:
                next(counter)

            if future.has_more_pages:
                future.start_fetching_next_page()
            else:
                finished.set()

        def handle_error(err):
            # Record first, then wake the waiting test thread.
            errors.append(err)
            finished.set()

        future.add_callbacks(callback=handle_page,
                             callback_args=(future, counter),
                             errback=handle_error)
        finished.wait()

        if errors:
            self.fail("paged query failed: %r" % (errors[0],))

        # ``counter`` was advanced once per row, so its next value is the total.
        return next(counter)

    def test_paging(self):
        """Synchronous ``execute`` returns every row for each fetch size."""
        self._insert_rows()
        prepared = self.session.prepare(self._SELECT)

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in self._select_variants(prepared):
                rows = list(self.session.execute(query))
                self.assertEqual(self._ROW_COUNT, len(rows))

    def test_async_paging(self):
        """``execute_async(...).result()`` returns every row for each fetch size."""
        self._insert_rows()
        prepared = self.session.prepare(self._SELECT)

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in self._select_variants(prepared):
                rows = list(self.session.execute_async(query).result())
                self.assertEqual(self._ROW_COUNT, len(rows))

    def test_paging_callbacks(self):
        """Callback-driven page fetching sees every row for each fetch size."""
        self._insert_rows()
        prepared = self.session.prepare(self._SELECT)

        for fetch_size in self._FETCH_SIZES:
            self.session.default_fetch_size = fetch_size
            for query in self._select_variants(prepared):
                future = self.session.execute_async(query)
                self.assertEqual(self._ROW_COUNT,
                                 self._count_rows_via_callbacks(future))
|