Merge pull request #425 from datastax/430
PYTHON-430 - fix paged results for non-standard protocol handlers
This commit is contained in:
@@ -61,7 +61,7 @@ from cassandra.protocol import (QueryMessage, ResultMessage,
|
|||||||
BatchMessage, RESULT_KIND_PREPARED,
|
BatchMessage, RESULT_KIND_PREPARED,
|
||||||
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
|
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
|
||||||
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION,
|
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION,
|
||||||
ProtocolHandler)
|
ProtocolHandler, _RESULT_SEQUENCE_TYPES)
|
||||||
from cassandra.metadata import Metadata, protect_name, murmur3
|
from cassandra.metadata import Metadata, protect_name, murmur3
|
||||||
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
|
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
|
||||||
ExponentialReconnectionPolicy, HostDistance,
|
ExponentialReconnectionPolicy, HostDistance,
|
||||||
@@ -3349,7 +3349,7 @@ class ResultSet(object):
|
|||||||
|
|
||||||
def __init__(self, response_future, initial_response):
|
def __init__(self, response_future, initial_response):
|
||||||
self.response_future = response_future
|
self.response_future = response_future
|
||||||
self._current_rows = initial_response or []
|
self._set_current_rows(initial_response)
|
||||||
self._page_iter = None
|
self._page_iter = None
|
||||||
self._list_mode = False
|
self._list_mode = False
|
||||||
|
|
||||||
@@ -3390,10 +3390,16 @@ class ResultSet(object):
|
|||||||
if self.response_future.has_more_pages:
|
if self.response_future.has_more_pages:
|
||||||
self.response_future.start_fetching_next_page()
|
self.response_future.start_fetching_next_page()
|
||||||
result = self.response_future.result()
|
result = self.response_future.result()
|
||||||
self._current_rows = result._current_rows
|
self._current_rows = result._current_rows # the ResultSet from result() has already normalized its rows via _set_current_rows
|
||||||
else:
|
else:
|
||||||
self._current_rows = []
|
self._current_rows = []
|
||||||
|
|
||||||
|
def _set_current_rows(self, result):
|
||||||
|
if isinstance(result, _RESULT_SEQUENCE_TYPES):
|
||||||
|
self._current_rows = result
|
||||||
|
else:
|
||||||
|
self._current_rows = [result] if result else []
|
||||||
|
|
||||||
def _fetch_all(self):
|
def _fetch_all(self):
|
||||||
self._current_rows = list(self)
|
self._current_rows = list(self)
|
||||||
self._page_iter = None
|
self._page_iter = None
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ cdef class LazyParser(ColumnParser):
|
|||||||
# supported in cpdef methods
|
# supported in cpdef methods
|
||||||
return parse_rows_lazy(reader, desc)
|
return parse_rows_lazy(reader, desc)
|
||||||
|
|
||||||
|
cpdef get_cython_generator_type(self):
|
||||||
|
return get_cython_generator_type()
|
||||||
|
|
||||||
def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
|
def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
|
||||||
cdef Py_ssize_t i, rowcount
|
cdef Py_ssize_t i, rowcount
|
||||||
@@ -45,6 +47,8 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
|
|||||||
cdef RowParser rowparser = TupleRowParser()
|
cdef RowParser rowparser = TupleRowParser()
|
||||||
return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
|
return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
|
||||||
|
|
||||||
|
def get_cython_generator_type():
|
||||||
|
return type((i for i in range(0)))
|
||||||
|
|
||||||
cdef class TupleRowParser(RowParser):
|
cdef class TupleRowParser(RowParser):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1000,6 +1000,7 @@ class ProtocolHandler(object):
|
|||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
_RESULT_SEQUENCE_TYPES = (list, tuple) # types returned by ResultMessages
|
||||||
|
|
||||||
def cython_protocol_handler(colparser):
|
def cython_protocol_handler(colparser):
|
||||||
"""
|
"""
|
||||||
@@ -1045,6 +1046,9 @@ def cython_protocol_handler(colparser):
|
|||||||
if HAVE_CYTHON:
|
if HAVE_CYTHON:
|
||||||
from cassandra.obj_parser import ListParser, LazyParser
|
from cassandra.obj_parser import ListParser, LazyParser
|
||||||
ProtocolHandler = cython_protocol_handler(ListParser())
|
ProtocolHandler = cython_protocol_handler(ListParser())
|
||||||
|
|
||||||
|
lazy_parser = LazyParser()
|
||||||
|
_RESULT_SEQUENCE_TYPES += (lazy_parser.get_cython_generator_type(),)
|
||||||
LazyProtocolHandler = cython_protocol_handler(LazyParser())
|
LazyProtocolHandler = cython_protocol_handler(LazyParser())
|
||||||
else:
|
else:
|
||||||
# Use Python-based ProtocolHandler
|
# Use Python-based ProtocolHandler
|
||||||
|
|||||||
@@ -55,21 +55,79 @@ class CythonProtocolHandlerTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
|
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
|
||||||
|
|
||||||
|
@numpytest
|
||||||
|
def test_cython_lazy_results_paged(self):
|
||||||
|
"""
|
||||||
|
Test Cython-based parser that returns an iterator, over multiple pages
|
||||||
|
"""
|
||||||
|
# arrays = { 'a': arr1, 'b': arr2, ... }
|
||||||
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
|
session = cluster.connect(keyspace="testspace")
|
||||||
|
session.row_factory = tuple_factory
|
||||||
|
session.client_protocol_handler = LazyProtocolHandler
|
||||||
|
session.default_fetch_size = 2
|
||||||
|
|
||||||
|
self.assertLess(session.default_fetch_size, self.N_ITEMS)
|
||||||
|
|
||||||
|
results = session.execute("SELECT * FROM test_table")
|
||||||
|
|
||||||
|
self.assertTrue(results.has_more_pages)
|
||||||
|
self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS) # make sure we see all rows
|
||||||
|
|
||||||
|
cluster.shutdown()
|
||||||
|
|
||||||
@numpytest
|
@numpytest
|
||||||
def test_numpy_parser(self):
|
def test_numpy_parser(self):
|
||||||
"""
|
"""
|
||||||
Test Numpy-based parser that returns a NumPy array
|
Test Numpy-based parser that returns a NumPy array
|
||||||
"""
|
"""
|
||||||
# arrays = { 'a': arr1, 'b': arr2, ... }
|
# arrays = { 'a': arr1, 'b': arr2, ... }
|
||||||
arrays = get_data(NumpyProtocolHandler)
|
result = get_data(NumpyProtocolHandler)
|
||||||
|
self.assertFalse(result.has_more_pages)
|
||||||
|
self._verify_numpy_page(result[0])
|
||||||
|
|
||||||
|
@numpytest
|
||||||
|
def test_numpy_results_paged(self):
|
||||||
|
"""
|
||||||
|
Test Numpy-based parser that returns NumPy arrays, over multiple pages
|
||||||
|
"""
|
||||||
|
# arrays = { 'a': arr1, 'b': arr2, ... }
|
||||||
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
|
session = cluster.connect(keyspace="testspace")
|
||||||
|
session.row_factory = tuple_factory
|
||||||
|
session.client_protocol_handler = NumpyProtocolHandler
|
||||||
|
session.default_fetch_size = 2
|
||||||
|
|
||||||
|
expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size
|
||||||
|
|
||||||
|
self.assertLess(session.default_fetch_size, self.N_ITEMS)
|
||||||
|
|
||||||
|
results = session.execute("SELECT * FROM test_table")
|
||||||
|
|
||||||
|
self.assertTrue(results.has_more_pages)
|
||||||
|
for count, page in enumerate(results, 1):
|
||||||
|
self.assertIsInstance(page, dict)
|
||||||
|
for colname, arr in page.items():
|
||||||
|
if count <= expected_pages:
|
||||||
|
self.assertGreater(len(arr), 0, "page count: %d" % (count,))
|
||||||
|
self.assertLessEqual(len(arr), session.default_fetch_size)
|
||||||
|
else:
|
||||||
|
# we get one extra item out of this iteration because of the way NumpyParser returns results
|
||||||
|
# The last page is returned as a dict with zero-length arrays
|
||||||
|
self.assertEqual(len(arr), 0)
|
||||||
|
self.assertEqual(self._verify_numpy_page(page), len(arr))
|
||||||
|
self.assertEqual(count, expected_pages + 1) # see note about extra 'page' above
|
||||||
|
|
||||||
|
cluster.shutdown()
|
||||||
|
|
||||||
|
def _verify_numpy_page(self, page):
|
||||||
colnames = self.colnames
|
colnames = self.colnames
|
||||||
datatypes = get_primitive_datatypes()
|
datatypes = get_primitive_datatypes()
|
||||||
for colname, datatype in zip(colnames, datatypes):
|
for colname, datatype in zip(colnames, datatypes):
|
||||||
arr = arrays[colname]
|
arr = page[colname]
|
||||||
self.match_dtype(datatype, arr.dtype)
|
self.match_dtype(datatype, arr.dtype)
|
||||||
|
|
||||||
verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames))
|
return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames))
|
||||||
|
|
||||||
def match_dtype(self, datatype, dtype):
|
def match_dtype(self, datatype, dtype):
|
||||||
"""Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
|
"""Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
|
||||||
@@ -100,9 +158,7 @@ def arrays_to_list_of_tuples(arrays, colnames):
|
|||||||
|
|
||||||
def get_data(protocol_handler):
|
def get_data(protocol_handler):
|
||||||
"""
|
"""
|
||||||
Get some data from the test table.
|
Get data from the test table.
|
||||||
|
|
||||||
:param key: if None, get all results (100.000 results), otherwise get only one result
|
|
||||||
"""
|
"""
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
session = cluster.connect(keyspace="testspace")
|
session = cluster.connect(keyspace="testspace")
|
||||||
@@ -121,9 +177,11 @@ def verify_iterator_data(assertEqual, results):
|
|||||||
Check the result of get_data() when this is a list or
|
Check the result of get_data() when this is a list or
|
||||||
iterator of tuples
|
iterator of tuples
|
||||||
"""
|
"""
|
||||||
for result in results:
|
count = 0
|
||||||
|
for count, result in enumerate(results, 1):
|
||||||
params = get_all_primitive_params(result[0])
|
params = get_all_primitive_params(result[0])
|
||||||
assertEqual(len(params), len(result),
|
assertEqual(len(params), len(result),
|
||||||
msg="Not the right number of columns?")
|
msg="Not the right number of columns?")
|
||||||
for expected, actual in zip(params, result):
|
for expected, actual in zip(params, result):
|
||||||
assertEqual(actual, expected)
|
assertEqual(actual, expected)
|
||||||
|
return count
|
||||||
|
|||||||
Reference in New Issue
Block a user