Merge pull request #425 from datastax/430
PYTHON-430 - fix paged results for non-standard protocol handlers
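In short, the bug: ResultSet normalized its first page with `initial_response or []`, which assumes the protocol handler returns rows as a plain list. Non-standard handlers do not: the lazy Cython parser returns a generator and the NumPy handler returns a dict of column arrays, so pages from those handlers were kept or dropped by truthiness rather than by type. A minimal sketch of the normalization the fix introduces, with hypothetical values (not driver code):

    # Old behavior: `page or []` keeps any truthy object as "rows", so a
    # dict page would later be iterated as column names instead of rows.
    page = {"a": [1, 2], "b": [3, 4]}   # NumPy-style page: dict of arrays

    def set_current_rows(result, sequence_types=(list, tuple)):
        # Mirrors ResultSet._set_current_rows from the diff below:
        # known row sequences pass through, anything else is wrapped.
        if isinstance(result, sequence_types):
            return result
        return [result] if result else []

    print(set_current_rows([("x",)]))   # [('x',)] -- lists pass through
    print(set_current_rows(page))       # [{'a': [1, 2], 'b': [3, 4]}]
    print(set_current_rows(None))       # []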
@@ -61,7 +61,7 @@ from cassandra.protocol import (QueryMessage, ResultMessage,
                                 BatchMessage, RESULT_KIND_PREPARED,
                                 RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
                                 RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION,
-                                ProtocolHandler)
+                                ProtocolHandler, _RESULT_SEQUENCE_TYPES)
 from cassandra.metadata import Metadata, protect_name, murmur3
 from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
                                 ExponentialReconnectionPolicy, HostDistance,
@@ -3349,7 +3349,7 @@ class ResultSet(object):
 
     def __init__(self, response_future, initial_response):
         self.response_future = response_future
-        self._current_rows = initial_response or []
+        self._set_current_rows(initial_response)
         self._page_iter = None
         self._list_mode = False
 
@@ -3390,10 +3390,16 @@
         if self.response_future.has_more_pages:
             self.response_future.start_fetching_next_page()
             result = self.response_future.result()
-            self._current_rows = result._current_rows
+            self._current_rows = result._current_rows  # ResultSet has already _set_current_rows to the appropriate form
         else:
             self._current_rows = []
 
+    def _set_current_rows(self, result):
+        if isinstance(result, _RESULT_SEQUENCE_TYPES):
+            self._current_rows = result
+        else:
+            self._current_rows = [result] if result else []
+
     def _fetch_all(self):
         self._current_rows = list(self)
         self._page_iter = None
@@ -38,6 +38,8 @@ cdef class LazyParser(ColumnParser):
         # supported in cpdef methods
         return parse_rows_lazy(reader, desc)
 
+    cpdef get_cython_generator_type(self):
+        return get_cython_generator_type()
 
 def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
     cdef Py_ssize_t i, rowcount
@@ -45,6 +47,8 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
     cdef RowParser rowparser = TupleRowParser()
     return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
 
+def get_cython_generator_type():
+    return type((i for i in range(0)))
 
 cdef class TupleRowParser(RowParser):
     """
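A side note on the hunk above: `type((i for i in range(0)))` works because every generator, whether from a generator expression or a generator function, shares a single built-in type, so an empty throwaway generator is enough to capture it for later isinstance checks. A quick stdlib-only illustration (not driver code):

    import types

    # All generators share one type, so an empty generator captures it.
    gen_type = type((i for i in range(0)))
    print(gen_type is types.GeneratorType)        # True: same as the stdlib alias

    def gen_fn():
        yield 1

    print(type(x * 2 for x in "ab") is gen_type)  # True
    print(type(gen_fn()) is gen_type)             # True

    # Generators are truthy even when they will yield nothing, which is
    # why the old `initial_response or []` idiom could not handle them.
    print(bool(i for i in range(0)))              # True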
@@ -1000,6 +1000,7 @@ class ProtocolHandler(object):
 
         return msg
 
+_RESULT_SEQUENCE_TYPES = (list, tuple)  # types returned by ResultMessages
 
 def cython_protocol_handler(colparser):
     """
@@ -1045,6 +1046,9 @@ def cython_protocol_handler(colparser):
 if HAVE_CYTHON:
     from cassandra.obj_parser import ListParser, LazyParser
     ProtocolHandler = cython_protocol_handler(ListParser())
+
+    lazy_parser = LazyParser()
+    _RESULT_SEQUENCE_TYPES += (lazy_parser.get_cython_generator_type(),)
     LazyProtocolHandler = cython_protocol_handler(LazyParser())
 else:
     # Use Python-based ProtocolHandler
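Putting the two protocol.py hunks above together: `_RESULT_SEQUENCE_TYPES` starts as `(list, tuple)` and, when Cython is available, grows to include the generator type. Lazily parsed pages then pass `ResultSet._set_current_rows` untouched, while other shapes (such as the NumPy handler's dict of arrays) get wrapped as a single-item page. A stdlib-only sketch of the resulting check (illustrative values, not driver code):

    # Rebuild the tuple the way the diff does.
    _RESULT_SEQUENCE_TYPES = (list, tuple)
    _RESULT_SEQUENCE_TYPES += (type((i for i in range(0))),)  # Cython/lazy case

    lazy_page = (row for row in [("a",), ("b",)])   # lazy parser shape
    numpy_page = {"col": [1, 2]}                    # NumPy handler shape

    print(isinstance(lazy_page, _RESULT_SEQUENCE_TYPES))    # True -> kept as-is
    print(isinstance(numpy_page, _RESULT_SEQUENCE_TYPES))   # False -> wrapped as [page]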
@@ -55,21 +55,79 @@ class CythonProtocolHandlerTest(unittest.TestCase):
         """
         verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
 
+    @numpytest
+    def test_cython_lazy_results_paged(self):
+        """
+        Test Cython-based parser that returns an iterator, over multiple pages
+        """
+        # arrays = { 'a': arr1, 'b': arr2, ... }
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
+        session = cluster.connect(keyspace="testspace")
+        session.row_factory = tuple_factory
+        session.client_protocol_handler = LazyProtocolHandler
+        session.default_fetch_size = 2
+
+        self.assertLess(session.default_fetch_size, self.N_ITEMS)
+
+        results = session.execute("SELECT * FROM test_table")
+
+        self.assertTrue(results.has_more_pages)
+        self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS)  # make sure we see all rows
+
+        cluster.shutdown()
+
     @numpytest
     def test_numpy_parser(self):
         """
         Test Numpy-based parser that returns a NumPy array
         """
-        # arrays = { 'a': arr1, 'b': arr2, ... }
-        arrays = get_data(NumpyProtocolHandler)
+        result = get_data(NumpyProtocolHandler)
+        self.assertFalse(result.has_more_pages)
+        self._verify_numpy_page(result[0])
 
+    @numpytest
+    def test_numpy_results_paged(self):
+        """
+        Test Numpy-based parser that returns a NumPy array, over multiple pages
+        """
+        # arrays = { 'a': arr1, 'b': arr2, ... }
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
+        session = cluster.connect(keyspace="testspace")
+        session.row_factory = tuple_factory
+        session.client_protocol_handler = NumpyProtocolHandler
+        session.default_fetch_size = 2
+
+        expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size
+
+        self.assertLess(session.default_fetch_size, self.N_ITEMS)
+
+        results = session.execute("SELECT * FROM test_table")
+
+        self.assertTrue(results.has_more_pages)
+        for count, page in enumerate(results, 1):
+            self.assertIsInstance(page, dict)
+            for colname, arr in page.items():
+                if count <= expected_pages:
+                    self.assertGreater(len(arr), 0, "page count: %d" % (count,))
+                    self.assertLessEqual(len(arr), session.default_fetch_size)
+                else:
+                    # we get one extra item out of this iteration because of the way NumpyParser returns results
+                    # The last page is returned as a dict with zero-length arrays
+                    self.assertEqual(len(arr), 0)
+            self.assertEqual(self._verify_numpy_page(page), len(arr))
+        self.assertEqual(count, expected_pages + 1)  # see note about extra 'page' above
+
+        cluster.shutdown()
+
+    def _verify_numpy_page(self, page):
         colnames = self.colnames
         datatypes = get_primitive_datatypes()
         for colname, datatype in zip(colnames, datatypes):
-            arr = arrays[colname]
+            arr = page[colname]
             self.match_dtype(datatype, arr.dtype)
 
-        verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames))
+        return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames))
 
     def match_dtype(self, datatype, dtype):
         """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
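Two details of the test hunks above are worth pinning down. First, a NumPy page is a dict of column arrays, so turning a page back into row tuples is a zip across columns; `page_to_rows` below is a hypothetical stand-in for the test module's `arrays_to_list_of_tuples`. Second, the "extra page" the comments mention: the final fetch yields zero-length arrays, so iteration sees `expected_pages` non-empty pages plus one empty one. A self-contained sketch under those assumptions:

    # Hypothetical stand-in for arrays_to_list_of_tuples: zip column
    # arrays back into row tuples.
    def page_to_rows(page, colnames):
        return list(zip(*(page[name] for name in colnames)))

    # Simulated page stream mirroring the paged-test assertions:
    # full pages for N_ITEMS=5 at fetch size 2, plus a trailing empty page.
    N_ITEMS, FETCH = 5, 2
    expected_pages = (N_ITEMS + FETCH - 1) // FETCH   # 3, as in the test
    pages = [{"a": [1, 2]}, {"a": [3, 4]}, {"a": [5]}, {"a": []}]

    total = 0
    for count, page in enumerate(pages, 1):
        rows = page_to_rows(page, ["a"])
        total += len(rows)
        if count <= expected_pages:
            assert 0 < len(rows) <= FETCH
        else:
            assert rows == []             # the empty final 'page'

    assert count == expected_pages + 1
    assert total == N_ITEMS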
@@ -100,9 +158,7 @@ def arrays_to_list_of_tuples(arrays, colnames):
 
 def get_data(protocol_handler):
     """
-    Get some data from the test table.
-
-    :param key: if None, get all results (100.000 results), otherwise get only one result
+    Get data from the test table.
     """
     cluster = Cluster(protocol_version=PROTOCOL_VERSION)
     session = cluster.connect(keyspace="testspace")
@@ -121,9 +177,11 @@ def verify_iterator_data(assertEqual, results):
     Check the result of get_data() when this is a list or
     iterator of tuples
     """
-    for result in results:
+    count = 0
+    for count, result in enumerate(results, 1):
         params = get_all_primitive_params(result[0])
         assertEqual(len(params), len(result),
                     msg="Not the right number of columns?")
         for expected, actual in zip(params, result):
             assertEqual(actual, expected)
+    return count
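One idiom in this last hunk deserves a note: `count = 0` before `for count, result in enumerate(results, 1)` guards the empty-results case, where the loop body never runs and `count` would otherwise be unbound at `return count`. In isolation:

    # Why `count = 0` precedes the loop: an empty iterable would
    # otherwise leave `count` unbound when it is returned.
    def count_rows(rows):
        count = 0
        for count, _row in enumerate(rows, 1):
            pass
        return count

    print(count_rows([]))                # 0 (no NameError)
    print(count_rows([("a",), ("b",)]))  # 2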