130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
"""Test the various Cython-based message deserializers"""
|
|
|
|
# Based on test_custom_protocol_handler.py
|
|
|
|
try:
|
|
import unittest2 as unittest
|
|
except ImportError:
|
|
import unittest
|
|
|
|
from cassandra.query import tuple_factory
|
|
from cassandra.cluster import Cluster
|
|
from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler
|
|
|
|
from tests.integration import use_singledc, PROTOCOL_VERSION
|
|
from tests.integration.datatype_utils import update_datatypes
|
|
from tests.integration.standard.utils import (
|
|
create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes)
|
|
|
|
from tests.unit.cython.utils import cythontest, numpytest
|
|
|
|
def setup_module():
    """
    Module-level setup hook (the name nose/pytest look for at module scope).

    Runs use_singledc() and then update_datatypes(), in that order.
    """
    for fixture_step in (use_singledc, update_datatypes):
        fixture_step()
|
|
|
|
|
|
class CythonProtocolHandlerTest(unittest.TestCase):
    """
    Integration tests for the Cython- and NumPy-based result deserializers.

    setUpClass creates the ``testspace`` keyspace and ``test_table`` once for
    the class; each test then reads the table back through a different
    protocol handler via get_data() and checks the decoded values.
    """

    # Passed to create_table_with_all_types(); presumably the number of rows
    # inserted into test_table.
    N_ITEMS = 10

    # Expected numpy (kind, itemsize) for the fixed-width CQL types. Any CQL
    # type not listed here is expected to be decoded into an object ('O') array.
    _FIXED_WIDTH_DTYPES = {
        'smallint': ('i', 2),
        'int': ('i', 4),
        'bigint': ('i', 8),
        'counter': ('i', 8),
        'float': ('f', 4),
        'double': ('f', 8),
    }

    @classmethod
    def setUpClass(cls):
        """Connect, create the ``testspace`` keyspace and populate ``test_table``."""
        cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        try:
            cls.session = cls.cluster.connect()
            cls.session.execute("CREATE KEYSPACE testspace WITH replication = "
                                "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
            cls.session.set_keyspace("testspace")
            cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # release the cluster here instead of leaking it.
            cls.cluster.shutdown()
            raise

    @classmethod
    def tearDownClass(cls):
        """Drop the ``testspace`` keyspace and shut down the class-level cluster."""
        try:
            cls.session.execute("DROP KEYSPACE testspace")
        finally:
            # Shut down even if the DROP fails, so the cluster is never leaked.
            cls.cluster.shutdown()

    @cythontest
    def test_cython_parser(self):
        """
        Test Cython-based parser that returns a list of tuples
        """
        verify_iterator_data(self.assertEqual, get_data(ProtocolHandler))

    @cythontest
    def test_cython_lazy_parser(self):
        """
        Test Cython-based parser that returns an iterator of tuples
        """
        verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))

    @numpytest
    def test_numpy_parser(self):
        """
        Test Numpy-based parser that returns a mapping of column name to NumPy array
        """
        # arrays = { 'a': arr1, 'b': arr2, ... }
        arrays = get_data(NumpyProtocolHandler)

        colnames = self.colnames
        datatypes = get_primitive_datatypes()
        # NOTE(review): pairing by position assumes get_primitive_datatypes()
        # lists types in the same order as the columns returned by
        # create_table_with_all_types() — confirm.
        for colname, datatype in zip(colnames, datatypes):
            arr = arrays[colname]
            self.match_dtype(datatype, arr.dtype)

        verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames))

    def match_dtype(self, datatype, dtype):
        """
        Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype.

        Fixed-width numeric types must have the kind/itemsize listed in
        _FIXED_WIDTH_DTYPES; every other type must be an object dtype.
        """
        expected = self._FIXED_WIDTH_DTYPES.get(datatype)
        if expected is None:
            self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype))
        else:
            kind, size = expected
            self.match_dtype_props(dtype, kind, size)

    def match_dtype_props(self, dtype, kind, size, signed=None):
        """
        Assert that ``dtype`` has the given numpy ``kind`` and ``itemsize``.

        :param dtype: numpy dtype to check
        :param kind: expected ``dtype.kind`` character (e.g. 'i', 'f')
        :param size: expected ``dtype.itemsize`` in bytes
        :param signed: accepted for compatibility but currently not checked
        """
        self.assertEqual(dtype.kind, kind, msg=dtype)
        self.assertEqual(dtype.itemsize, size, msg=dtype)
|
|
|
|
|
|
def arrays_to_list_of_tuples(arrays, colnames):
    """
    Turn a column-oriented mapping into a row-oriented list of tuples.

    ``arrays`` maps each column name to an indexable sequence of values (as
    given by the numpy protocol handler). Row *i* of the result holds the
    *i*-th value of every column, in ``colnames`` order. The number of rows is
    taken from the first column named in ``colnames``.
    """
    row_count = len(arrays[colnames[0]])
    rows = []
    for row_idx in range(row_count):
        rows.append(tuple(arrays[name][row_idx] for name in colnames))
    return rows
|
|
|
|
|
|
def get_data(protocol_handler):
    """
    Get all rows of ``test_table`` in ``testspace`` using the given protocol handler.

    A dedicated Cluster is opened so that the custom protocol handler and the
    tuple row factory are only installed on this session. The Cluster is shut
    down before returning, including when connecting or querying fails.

    :param protocol_handler: protocol handler class to install on the session
        (e.g. ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler)
    :return: the result of ``session.execute("SELECT * FROM test_table")``
    """
    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    try:
        session = cluster.connect(keyspace="testspace")

        # use our custom protocol handler
        session.client_protocol_handler = protocol_handler
        session.row_factory = tuple_factory

        return session.execute("SELECT * FROM test_table")
    finally:
        # Shut down the whole Cluster, not just the session: session.shutdown()
        # alone leaves the Cluster created above open.
        cluster.shutdown()
|
|
|
|
|
|
def verify_iterator_data(assertEqual, results):
    """
    Assert that every row in ``results`` holds the expected primitive values.

    Works for any iterable of row tuples (a list or an iterator, as returned by
    get_data()). The expected values for a row come from
    get_all_primitive_params(), called with the row's first column.

    :param assertEqual: equality-assertion callable, e.g. ``TestCase.assertEqual``;
        it is called as ``assertEqual(actual, expected)``
    :param results: iterable of row tuples
    """
    for row in results:
        expected_values = get_all_primitive_params(row[0])
        assertEqual(len(expected_values), len(row),
                    msg="Not the right number of columns?")
        pairs = zip(expected_values, row)
        for want, got in pairs:
            assertEqual(got, want)
|