Merge pull request #425 from datastax/430

PYTHON-430 - fix paged results for non-standard protocol handlers
This commit is contained in:
Michael Penick
2015-10-20 11:52:35 -07:00
4 changed files with 82 additions and 10 deletions

View File

@@ -61,7 +61,7 @@ from cassandra.protocol import (QueryMessage, ResultMessage,
BatchMessage, RESULT_KIND_PREPARED, BatchMessage, RESULT_KIND_PREPARED,
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION, RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION,
ProtocolHandler) ProtocolHandler, _RESULT_SEQUENCE_TYPES)
from cassandra.metadata import Metadata, protect_name, murmur3 from cassandra.metadata import Metadata, protect_name, murmur3
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance, ExponentialReconnectionPolicy, HostDistance,
@@ -3349,7 +3349,7 @@ class ResultSet(object):
def __init__(self, response_future, initial_response): def __init__(self, response_future, initial_response):
self.response_future = response_future self.response_future = response_future
self._current_rows = initial_response or [] self._set_current_rows(initial_response)
self._page_iter = None self._page_iter = None
self._list_mode = False self._list_mode = False
@@ -3390,10 +3390,16 @@ class ResultSet(object):
if self.response_future.has_more_pages: if self.response_future.has_more_pages:
self.response_future.start_fetching_next_page() self.response_future.start_fetching_next_page()
result = self.response_future.result() result = self.response_future.result()
self._current_rows = result._current_rows self._current_rows = result._current_rows # the fetched ResultSet has already normalized its rows via _set_current_rows
else: else:
self._current_rows = [] self._current_rows = []
def _set_current_rows(self, result):
if isinstance(result, _RESULT_SEQUENCE_TYPES):
self._current_rows = result
else:
self._current_rows = [result] if result else []
def _fetch_all(self): def _fetch_all(self):
self._current_rows = list(self) self._current_rows = list(self)
self._page_iter = None self._page_iter = None

View File

@@ -38,6 +38,8 @@ cdef class LazyParser(ColumnParser):
# supported in cpdef methods # supported in cpdef methods
return parse_rows_lazy(reader, desc) return parse_rows_lazy(reader, desc)
cpdef get_cython_generator_type(self):
return get_cython_generator_type()
def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
cdef Py_ssize_t i, rowcount cdef Py_ssize_t i, rowcount
@@ -45,6 +47,8 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
cdef RowParser rowparser = TupleRowParser() cdef RowParser rowparser = TupleRowParser()
return (rowparser.unpack_row(reader, desc) for i in range(rowcount)) return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
def get_cython_generator_type():
return type((i for i in range(0)))
cdef class TupleRowParser(RowParser): cdef class TupleRowParser(RowParser):
""" """

View File

@@ -1000,6 +1000,7 @@ class ProtocolHandler(object):
return msg return msg
_RESULT_SEQUENCE_TYPES = (list, tuple) # types returned by ResultMessages
def cython_protocol_handler(colparser): def cython_protocol_handler(colparser):
""" """
@@ -1045,6 +1046,9 @@ def cython_protocol_handler(colparser):
if HAVE_CYTHON: if HAVE_CYTHON:
from cassandra.obj_parser import ListParser, LazyParser from cassandra.obj_parser import ListParser, LazyParser
ProtocolHandler = cython_protocol_handler(ListParser()) ProtocolHandler = cython_protocol_handler(ListParser())
lazy_parser = LazyParser()
_RESULT_SEQUENCE_TYPES += (lazy_parser.get_cython_generator_type(),)
LazyProtocolHandler = cython_protocol_handler(LazyParser()) LazyProtocolHandler = cython_protocol_handler(LazyParser())
else: else:
# Use Python-based ProtocolHandler # Use Python-based ProtocolHandler

View File

@@ -55,21 +55,79 @@ class CythonProtocolHandlerTest(unittest.TestCase):
""" """
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
@numpytest
def test_cython_lazy_results_paged(self):
"""
Test Cython-based parser that returns an iterator, over multiple pages
"""
# results iterate as row tuples (row_factory = tuple_factory), not as a dict of arrays
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect(keyspace="testspace")
session.row_factory = tuple_factory
session.client_protocol_handler = LazyProtocolHandler
session.default_fetch_size = 2
self.assertLess(session.default_fetch_size, self.N_ITEMS)
results = session.execute("SELECT * FROM test_table")
self.assertTrue(results.has_more_pages)
self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS) # make sure we see all rows
cluster.shutdown()
@numpytest @numpytest
def test_numpy_parser(self): def test_numpy_parser(self):
""" """
Test Numpy-based parser that returns a NumPy array Test Numpy-based parser that returns a NumPy array
""" """
# arrays = { 'a': arr1, 'b': arr2, ... } # arrays = { 'a': arr1, 'b': arr2, ... }
arrays = get_data(NumpyProtocolHandler) result = get_data(NumpyProtocolHandler)
self.assertFalse(result.has_more_pages)
self._verify_numpy_page(result[0])
@numpytest
def test_numpy_results_paged(self):
"""
Test Numpy-based parser that returns NumPy arrays, over multiple pages
"""
# each page = { 'a': arr1, 'b': arr2, ... }
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect(keyspace="testspace")
session.row_factory = tuple_factory
session.client_protocol_handler = NumpyProtocolHandler
session.default_fetch_size = 2
expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size
self.assertLess(session.default_fetch_size, self.N_ITEMS)
results = session.execute("SELECT * FROM test_table")
self.assertTrue(results.has_more_pages)
for count, page in enumerate(results, 1):
self.assertIsInstance(page, dict)
for colname, arr in page.items():
if count <= expected_pages:
self.assertGreater(len(arr), 0, "page count: %d" % (count,))
self.assertLessEqual(len(arr), session.default_fetch_size)
else:
# we get one extra item out of this iteration because of the way NumpyParser returns results
# The last page is returned as a dict with zero-length arrays
self.assertEqual(len(arr), 0)
self.assertEqual(self._verify_numpy_page(page), len(arr))
self.assertEqual(count, expected_pages + 1) # see note about extra 'page' above
cluster.shutdown()
def _verify_numpy_page(self, page):
colnames = self.colnames colnames = self.colnames
datatypes = get_primitive_datatypes() datatypes = get_primitive_datatypes()
for colname, datatype in zip(colnames, datatypes): for colname, datatype in zip(colnames, datatypes):
arr = arrays[colname] arr = page[colname]
self.match_dtype(datatype, arr.dtype) self.match_dtype(datatype, arr.dtype)
verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames)) return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames))
def match_dtype(self, datatype, dtype): def match_dtype(self, datatype, dtype):
"""Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype""" """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
@@ -100,9 +158,7 @@ def arrays_to_list_of_tuples(arrays, colnames):
def get_data(protocol_handler): def get_data(protocol_handler):
""" """
Get some data from the test table. Get data from the test table.
:param key: if None, get all results (100.000 results), otherwise get only one result
""" """
cluster = Cluster(protocol_version=PROTOCOL_VERSION) cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect(keyspace="testspace") session = cluster.connect(keyspace="testspace")
@@ -121,9 +177,11 @@ def verify_iterator_data(assertEqual, results):
Check the result of get_data() when this is a list or Check the result of get_data() when this is a list or
iterator of tuples iterator of tuples
""" """
for result in results: count = 0
for count, result in enumerate(results, 1):
params = get_all_primitive_params(result[0]) params = get_all_primitive_params(result[0])
assertEqual(len(params), len(result), assertEqual(len(params), len(result),
msg="Not the right number of columns?") msg="Not the right number of columns?")
for expected, actual in zip(params, result): for expected, actual in zip(params, result):
assertEqual(actual, expected) assertEqual(actual, expected)
return count