Files
deb-python-cassandra-driver/tests/integration/standard/test_cython_protocol_handlers.py
2016-07-22 16:30:28 -05:00

258 lines
9.5 KiB
Python

"""Test the various Cython-based message deserializers"""
# Based on test_custom_protocol_handler.py
try:
import unittest2 as unittest
except ImportError:
import unittest
from itertools import count
from cassandra.query import tuple_factory
from cassandra.cluster import Cluster, NoHostAvailable
from cassandra.concurrent import execute_concurrent_with_args
from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
from tests.integration import use_singledc, PROTOCOL_VERSION, notprotocolv1, drop_keyspace_shutdown_cluster, VERIFY_CYTHON, BasicSharedKeyspaceUnitTestCase, execute_with_retry_tolerant, greaterthancass21
from tests.integration.datatype_utils import update_datatypes
from tests.integration.standard.utils import (
create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes)
from tests.unit.cython.utils import cythontest, numpytest
def setup_module():
    """
    Module-level fixture run once before any test in this file.

    Calls use_singledc() first and update_datatypes() second, in that order.
    """
    use_singledc()
    update_datatypes()
class CythonProtocolHandlerTest(unittest.TestCase):
    """
    Integration tests for the Cython- and NumPy-based protocol handlers.

    setUpClass creates keyspace ``testspace`` containing ``test_table``
    (built by create_table_with_all_types with ``N_ITEMS`` rows). The tests
    read that table back through the different protocol handlers and check
    every returned value.
    """

    N_ITEMS = 10

    @classmethod
    def setUpClass(cls):
        cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cls.session = cls.cluster.connect()
        cls.session.execute("CREATE KEYSPACE testspace WITH replication = "
                            "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
        cls.session.set_keyspace("testspace")
        cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS)

    @classmethod
    def tearDownClass(cls):
        # The helper's third argument is the cluster it shuts down. The
        # previous code passed cls.session there, so cls.cluster itself was
        # never shut down.
        drop_keyspace_shutdown_cluster("testspace", cls.session, cls.cluster)

    @cythontest
    def test_cython_parser(self):
        """
        Test Cython-based parser that returns a list of tuples
        """
        verify_iterator_data(self.assertEqual, get_data(ProtocolHandler))

    @cythontest
    def test_cython_lazy_parser(self):
        """
        Test Cython-based parser that returns an iterator of tuples
        """
        verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))

    @notprotocolv1
    @numpytest
    def test_cython_lazy_results_paged(self):
        """
        Test Cython-based parser that returns an iterator, over multiple pages
        """
        # NOTE(review): only LazyProtocolHandler is exercised here, so
        # @cythontest may be the intended gate rather than @numpytest -- confirm.
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        try:
            session = cluster.connect(keyspace="testspace")
            session.row_factory = tuple_factory
            session.client_protocol_handler = LazyProtocolHandler
            # A fetch size below N_ITEMS forces the rows to span several pages.
            session.default_fetch_size = 2

            self.assertLess(session.default_fetch_size, self.N_ITEMS)

            results = session.execute("SELECT * FROM test_table")
            self.assertTrue(results.has_more_pages)
            # make sure we see all rows, across every page
            self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS)
        finally:
            # Runs even when an assertion fails, so the cluster is never leaked.
            cluster.shutdown()

    @notprotocolv1
    @numpytest
    def test_numpy_parser(self):
        """
        Test Numpy-based parser that returns a NumPy array
        """
        # result[0] is a dict of column arrays: { 'a': arr1, 'b': arr2, ... }
        result = get_data(NumpyProtocolHandler)
        self.assertFalse(result.has_more_pages)
        self._verify_numpy_page(result[0])

    @notprotocolv1
    @numpytest
    def test_numpy_results_paged(self):
        """
        Test Numpy-based parser that returns NumPy arrays, over multiple pages
        """
        # each page is a dict of column arrays: { 'a': arr1, 'b': arr2, ... }
        cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        try:
            session = cluster.connect(keyspace="testspace")
            session.row_factory = tuple_factory
            session.client_protocol_handler = NumpyProtocolHandler
            session.default_fetch_size = 2

            # ceil(N_ITEMS / fetch_size)
            expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size

            self.assertLess(session.default_fetch_size, self.N_ITEMS)

            results = session.execute("SELECT * FROM test_table")
            self.assertTrue(results.has_more_pages)

            # Named page_num (not `count`) to avoid shadowing itertools.count;
            # initialised so the final assertion fails cleanly if no page arrives.
            page_num = 0
            for page_num, page in enumerate(results, 1):
                self.assertIsInstance(page, dict)
                for arr in page.values():
                    if page_num <= expected_pages:
                        self.assertGreater(len(arr), 0, "page count: %d" % (page_num,))
                        self.assertLessEqual(len(arr), session.default_fetch_size)
                    else:
                        # we get one extra item out of this iteration because of the way NumpyParser returns results
                        # The last page is returned as a dict with zero-length arrays
                        self.assertEqual(len(arr), 0)
                # assumes every column array in a page has the same length, so
                # the last column's length is the page's row count
                self.assertEqual(self._verify_numpy_page(page), len(arr))
            self.assertEqual(page_num, expected_pages + 1)  # see note about extra 'page' above
        finally:
            # Runs even when an assertion fails, so the cluster is never leaked.
            cluster.shutdown()

    @numpytest
    def test_cython_numpy_are_installed_valid(self):
        """
        Test to validate that cython and numpy are installed correctly
        @since 3.3.0
        @jira_ticket PYTHON-543
        @expected_result Cython and Numpy should be present
        @test_category configuration
        """
        if VERIFY_CYTHON:
            self.assertTrue(HAVE_CYTHON)
            self.assertTrue(HAVE_NUMPY)

    def _verify_numpy_page(self, page):
        """
        Check one NumPy result page (dict of column name -> array).

        Each column's dtype is checked against its CQL type. The page is then
        converted to row tuples and the values are verified.

        :return: number of rows in the page, as counted by verify_iterator_data
        """
        colnames = self.colnames
        # assumes self.colnames and get_primitive_datatypes() line up
        # positionally -- zip pairs them by index
        datatypes = get_primitive_datatypes()
        for colname, datatype in zip(colnames, datatypes):
            arr = page[colname]
            self.match_dtype(datatype, arr.dtype)

        return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames))

    def match_dtype(self, datatype, dtype):
        """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
        if datatype == 'smallint':
            self.match_dtype_props(dtype, 'i', 2)
        elif datatype == 'int':
            self.match_dtype_props(dtype, 'i', 4)
        elif datatype in ('bigint', 'counter'):
            self.match_dtype_props(dtype, 'i', 8)
        elif datatype == 'float':
            self.match_dtype_props(dtype, 'f', 4)
        elif datatype == 'double':
            self.match_dtype_props(dtype, 'f', 8)
        else:
            # every other CQL type is expected to come back as an object array
            self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype))

    def match_dtype_props(self, dtype, kind, size, signed=None):
        """
        Assert that ``dtype`` has the given numpy ``kind`` and ``itemsize``.

        ``signed`` is accepted but not checked.
        """
        self.assertEqual(dtype.kind, kind, msg=dtype)
        self.assertEqual(dtype.itemsize, size, msg=dtype)
def arrays_to_list_of_tuples(arrays, colnames):
    """
    Turn column-oriented data (as given by the numpy protocol handler) into rows.

    :param arrays: mapping of column name -> indexable column of values
    :param colnames: column names, in the order the tuple fields should appear
    :return: list with one tuple per row; the row count is taken from the
             first column named in ``colnames``
    """
    n_rows = len(arrays[colnames[0]])
    rows = []
    for row_idx in range(n_rows):
        rows.append(tuple(arrays[name][row_idx] for name in colnames))
    return rows
def get_data(protocol_handler):
    """
    Get data from the test table.

    Opens a dedicated Cluster connected to keyspace ``testspace`` and installs
    ``protocol_handler`` as the session's client protocol handler, using
    ``tuple_factory`` rows. It then runs ``SELECT * FROM test_table``.

    The Cluster is shut down before this returns, so callers can only consume
    the page that was already fetched.

    :param protocol_handler: protocol handler class used to decode the response
    :return: the result of the SELECT
    """
    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    try:
        session = cluster.connect(keyspace="testspace")

        # use our custom protocol handler
        session.client_protocol_handler = protocol_handler
        session.row_factory = tuple_factory

        return session.execute("SELECT * FROM test_table")
    finally:
        # Shut down the whole Cluster, not just the session (the old
        # session.shutdown() left the Cluster open). This also runs when
        # connect() or execute() raises.
        cluster.shutdown()
def verify_iterator_data(assertEqual, results):
    """
    Validate the rows from get_data() when they arrive as a list or an
    iterator of tuples.

    For each row, get_all_primitive_params(row[0]) supplies the expected
    values. The row must have exactly that many columns, and every column must
    equal its expected value.

    :param assertEqual: equality-assertion callable, e.g. ``self.assertEqual``
    :param results: iterable of row tuples
    :return: number of rows checked (0 for an empty ``results``)
    """
    rows_seen = 0
    for row in results:
        rows_seen += 1
        expected_row = get_all_primitive_params(row[0])
        assertEqual(len(expected_row), len(row),
                    msg="Not the right number of columns?")
        for expected_value, actual_value in zip(expected_row, row):
            assertEqual(actual_value, expected_value)
    return rows_seen
class NumpyNullTest(BasicSharedKeyspaceUnitTestCase):
    """Null-value handling tests for NumpyProtocolHandler."""

    @numpytest
    @greaterthancass21
    def test_null_types(self):
        """
        Test to validate that the numpy protocol handler can deal with null values.
        @since 3.3.0
         - updated 3.6.0: now numeric types used masked array
        @jira_ticket PYTHON-550
        @expected_result Numpy can handle non mapped types' null values.

        @test_category data_types:serialization
        """
        s = self.session
        s.row_factory = tuple_factory
        s.client_protocol_handler = NumpyProtocolHandler

        table = "%s.%s" % (self.keyspace_name, self.function_table_name)
        create_table_with_all_types(table, s, 10)

        # With the numpy handler, [0] is a dict of column arrays, so this is
        # the largest existing primary key plus one.
        begin_unset = max(s.execute('select primkey from %s' % (table,))[0]['primkey']) + 1
        keys_null = range(begin_unset, begin_unset + 10)

        # scatter some empty rows in here
        insert = "insert into %s (primkey) values (%%s)" % (table,)
        execute_concurrent_with_args(s, insert, ((k,) for k in keys_null))

        result = s.execute("select * from %s" % (table,))[0]

        from numpy.ma import masked, MaskedArray

        result_keys = result.pop('primkey')
        # Row positions ordered by ascending primary key (an argsort of result_keys).
        mapped_index = [v[1] for v in sorted(zip(result_keys, count()))]
        # NOTE(review): slicing at begin_unset assumes the populated rows use
        # keys 0..begin_unset-1, so they sort first -- confirm against
        # create_table_with_all_types.
        populated_rows = mapped_index[:begin_unset]
        unset_rows = mapped_index[begin_unset:]

        had_masked = had_none = False
        for col_array in result.values():
            # these have to be different branches (as opposed to comparing against an 'unset value')
            # because None and `masked` have different identity and equals semantics
            if isinstance(col_array, MaskedArray):
                had_masked = True
                for i in populated_rows:
                    self.assertIsNot(col_array[i], masked)
                for i in unset_rows:
                    self.assertIs(col_array[i], masked)
            else:
                had_none = True
                for i in populated_rows:
                    self.assertIsNotNone(col_array[i])
                for i in unset_rows:
                    self.assertIsNone(col_array[i])

        self.assertTrue(had_masked)
        self.assertTrue(had_none)