diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 44c91e01..5024a163 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -61,7 +61,7 @@ from cassandra.protocol import (QueryMessage, ResultMessage, BatchMessage, RESULT_KIND_PREPARED, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION, - ProtocolHandler) + ProtocolHandler, _RESULT_SEQUENCE_TYPES) from cassandra.metadata import Metadata, protect_name, murmur3 from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, @@ -3349,7 +3349,7 @@ class ResultSet(object): def __init__(self, response_future, initial_response): self.response_future = response_future - self._current_rows = initial_response or [] + self._set_current_rows(initial_response) self._page_iter = None self._list_mode = False @@ -3390,10 +3390,16 @@ class ResultSet(object): if self.response_future.has_more_pages: self.response_future.start_fetching_next_page() result = self.response_future.result() - self._current_rows = result._current_rows + self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form else: self._current_rows = [] + def _set_current_rows(self, result): + if isinstance(result, _RESULT_SEQUENCE_TYPES): + self._current_rows = result + else: + self._current_rows = [result] if result else [] + def _fetch_all(self): self._current_rows = list(self) self._page_iter = None diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index 8aa5b394..21ce95e0 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -38,6 +38,8 @@ cdef class LazyParser(ColumnParser): # supported in cpdef methods return parse_rows_lazy(reader, desc) + cpdef get_cython_generator_type(self): + return get_cython_generator_type() def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): cdef Py_ssize_t i, rowcount @@ -45,6 +47,8 @@ def 
parse_rows_lazy(BytesIOReader reader, ParseDesc desc): cdef RowParser rowparser = TupleRowParser() return (rowparser.unpack_row(reader, desc) for i in range(rowcount)) +def get_cython_generator_type(): + return type((i for i in range(0))) cdef class TupleRowParser(RowParser): """ diff --git a/cassandra/protocol.py b/cassandra/protocol.py index bc87f9e4..eec9ac37 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -1000,6 +1000,7 @@ class ProtocolHandler(object): return msg +_RESULT_SEQUENCE_TYPES = (list, tuple) # types returned by ResultMessages def cython_protocol_handler(colparser): """ @@ -1045,6 +1046,9 @@ def cython_protocol_handler(colparser): if HAVE_CYTHON: from cassandra.obj_parser import ListParser, LazyParser ProtocolHandler = cython_protocol_handler(ListParser()) + + lazy_parser = LazyParser() + _RESULT_SEQUENCE_TYPES += (lazy_parser.get_cython_generator_type(),) LazyProtocolHandler = cython_protocol_handler(LazyParser()) else: # Use Python-based ProtocolHandler diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 985b7953..888f73fe 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -55,21 +55,79 @@ class CythonProtocolHandlerTest(unittest.TestCase): """ verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) + @numpytest + def test_cython_lazy_results_paged(self): + """ + Test Cython-based parser that returns an iterator, over multiple pages + """ + # arrays = { 'a': arr1, 'b': arr2, ... 
} + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + session.row_factory = tuple_factory + session.client_protocol_handler = LazyProtocolHandler + session.default_fetch_size = 2 + + self.assertLess(session.default_fetch_size, self.N_ITEMS) + + results = session.execute("SELECT * FROM test_table") + + self.assertTrue(results.has_more_pages) + self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS) # make sure we see all rows + + cluster.shutdown() + @numpytest def test_numpy_parser(self): """ Test Numpy-based parser that returns a NumPy array """ # arrays = { 'a': arr1, 'b': arr2, ... } - arrays = get_data(NumpyProtocolHandler) + result = get_data(NumpyProtocolHandler) + self.assertFalse(result.has_more_pages) + self._verify_numpy_page(result[0]) + @numpytest + def test_numpy_results_paged(self): + """ + Test Numpy-based parser that returns a NumPy array + """ + # arrays = { 'a': arr1, 'b': arr2, ... } + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + session.row_factory = tuple_factory + session.client_protocol_handler = NumpyProtocolHandler + session.default_fetch_size = 2 + + expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size + + self.assertLess(session.default_fetch_size, self.N_ITEMS) + + results = session.execute("SELECT * FROM test_table") + + self.assertTrue(results.has_more_pages) + for count, page in enumerate(results, 1): + self.assertIsInstance(page, dict) + for colname, arr in page.items(): + if count <= expected_pages: + self.assertGreater(len(arr), 0, "page count: %d" % (count,)) + self.assertLessEqual(len(arr), session.default_fetch_size) + else: + # we get one extra item out of this iteration because of the way NumpyParser returns results + # The last page is returned as a dict with zero-length arrays + self.assertEqual(len(arr), 0) + 
self.assertEqual(self._verify_numpy_page(page), len(arr)) + self.assertEqual(count, expected_pages + 1) # see note about extra 'page' above + + cluster.shutdown() + + def _verify_numpy_page(self, page): colnames = self.colnames datatypes = get_primitive_datatypes() for colname, datatype in zip(colnames, datatypes): - arr = arrays[colname] + arr = page[colname] self.match_dtype(datatype, arr.dtype) - verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames)) + return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames)) def match_dtype(self, datatype, dtype): """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype""" @@ -100,9 +158,7 @@ def arrays_to_list_of_tuples(arrays, colnames): def get_data(protocol_handler): """ - Get some data from the test table. - - :param key: if None, get all results (100.000 results), otherwise get only one result + Get data from the test table. """ cluster = Cluster(protocol_version=PROTOCOL_VERSION) session = cluster.connect(keyspace="testspace") @@ -121,9 +177,11 @@ def verify_iterator_data(assertEqual, results): Check the result of get_data() when this is a list or iterator of tuples """ - for result in results: + count = 0 + for count, result in enumerate(results, 1): params = get_all_primitive_params(result[0]) assertEqual(len(params), len(result), msg="Not the right number of columns?") for expected, actual in zip(params, result): assertEqual(actual, expected) + return count