More comprehensive cython and numpy deserializer tests
This commit is contained in:
@@ -23,7 +23,7 @@ from cassandra import cqltypes
|
|||||||
from cassandra.util import is_little_endian
|
from cassandra.util import is_little_endian
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
# import pandas as pd
|
||||||
|
|
||||||
cdef extern from "numpyFlags.h":
|
cdef extern from "numpyFlags.h":
|
||||||
# Include 'numpyFlags.h' into the generated C code to disable the
|
# Include 'numpyFlags.h' into the generated C code to disable the
|
||||||
@@ -74,8 +74,10 @@ cdef class NumpyParser(ColumnParser):
|
|||||||
for i in range(rowcount):
|
for i in range(rowcount):
|
||||||
unpack_row(reader, desc, arrs)
|
unpack_row(reader, desc, arrs)
|
||||||
|
|
||||||
return [make_native_byteorder(arr) for arr in arrays]
|
arrays = [make_native_byteorder(arr) for arr in arrays]
|
||||||
# return pd.DataFrame(dict(zip(desc.colnames, arrays)))
|
result = dict(zip(desc.colnames, arrays))
|
||||||
|
# return pd.DataFrame(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
### Helper functions to create NumPy arrays and array descriptors
|
### Helper functions to create NumPy arrays and array descriptors
|
||||||
|
@@ -107,10 +107,11 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
|||||||
session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked
|
session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked
|
||||||
session.row_factory = tuple_factory
|
session.row_factory = tuple_factory
|
||||||
|
|
||||||
columns_string = create_table_with_all_types("alltypes", session)
|
colnames = create_table_with_all_types("alltypes", session, 1)
|
||||||
|
columns_string = ", ".join(colnames)
|
||||||
|
|
||||||
# verify data
|
# verify data
|
||||||
params = get_all_primitive_params()
|
params = get_all_primitive_params(0)
|
||||||
results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0]
|
results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0]
|
||||||
for expected, actual in zip(params, results):
|
for expected, actual in zip(params, results):
|
||||||
self.assertEqual(actual, expected)
|
self.assertEqual(actual, expected)
|
||||||
|
@@ -7,23 +7,25 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from cassandra.query import tuple_factory
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler
|
from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler
|
||||||
|
|
||||||
from tests.integration import use_singledc, PROTOCOL_VERSION
|
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||||
from tests.integration.datatype_utils import update_datatypes
|
from tests.integration.datatype_utils import update_datatypes
|
||||||
from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params
|
from tests.integration.standard.utils import (
|
||||||
|
create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes)
|
||||||
from cassandra.cython_deps import HAVE_CYTHON
|
|
||||||
if not HAVE_CYTHON:
|
|
||||||
raise unittest.SkipTest("Skipping test, not compiled with Cython enabled")
|
|
||||||
|
|
||||||
|
from tests.unit.cython.utils import cythontest, numpytest
|
||||||
|
|
||||||
def setup_module():
|
def setup_module():
|
||||||
use_singledc()
|
use_singledc()
|
||||||
update_datatypes()
|
update_datatypes()
|
||||||
|
|
||||||
|
|
||||||
class CustomProtocolHandlerTest(unittest.TestCase):
|
class CythonProtocolHandlerTest(unittest.TestCase):
|
||||||
|
|
||||||
|
N_ITEMS = 10
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -32,39 +34,96 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
|||||||
cls.session.execute("CREATE KEYSPACE testspace WITH replication = "
|
cls.session.execute("CREATE KEYSPACE testspace WITH replication = "
|
||||||
"{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
|
"{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
|
||||||
cls.session.set_keyspace("testspace")
|
cls.session.set_keyspace("testspace")
|
||||||
create_table_with_all_types("test_table", cls.session)
|
cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
cls.session.execute("DROP KEYSPACE testspace")
|
cls.session.execute("DROP KEYSPACE testspace")
|
||||||
cls.cluster.shutdown()
|
cls.cluster.shutdown()
|
||||||
|
|
||||||
|
@cythontest
|
||||||
def test_cython_parser(self):
|
def test_cython_parser(self):
|
||||||
"""
|
"""
|
||||||
Test Cython-based parser that returns a list of tuples
|
Test Cython-based parser that returns a list of tuples
|
||||||
"""
|
"""
|
||||||
self.cython_parser(ProtocolHandler)
|
verify_iterator_data(self.assertEqual, get_data(ProtocolHandler))
|
||||||
|
|
||||||
|
@cythontest
|
||||||
def test_cython_lazy_parser(self):
|
def test_cython_lazy_parser(self):
|
||||||
"""
|
"""
|
||||||
Test Cython-based parser that returns a list of tuples
|
Test Cython-based parser that returns an iterator of tuples
|
||||||
"""
|
"""
|
||||||
self.cython_parser(LazyProtocolHandler)
|
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
|
||||||
|
|
||||||
def cython_parser(self, protocol_handler):
|
@numpytest
|
||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
def test_numpy_parser(self):
|
||||||
session = cluster.connect(keyspace="testspace")
|
"""
|
||||||
|
Test Numpy-based parser that returns a NumPy array
|
||||||
|
"""
|
||||||
|
# arrays = { 'a': arr1, 'b': arr2, ... }
|
||||||
|
arrays = get_data(NumpyProtocolHandler)
|
||||||
|
|
||||||
# use our custom protocol handler
|
colnames = self.colnames
|
||||||
session.client_protocol_handler = protocol_handler
|
datatypes = get_primitive_datatypes()
|
||||||
# session.row_factory = tuple_factory
|
for colname, datatype in zip(colnames, datatypes):
|
||||||
|
arr = arrays[colname]
|
||||||
|
self.match_dtype(datatype, arr.dtype)
|
||||||
|
|
||||||
# verify data
|
verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames))
|
||||||
params = get_all_primitive_params()
|
|
||||||
[first_result] = session.execute("SELECT * FROM test_table WHERE primkey=0")
|
|
||||||
self.assertEqual(len(params), len(first_result),
|
|
||||||
msg="Not the right number of columns?")
|
|
||||||
for expected, actual in zip(params, first_result):
|
|
||||||
self.assertEqual(actual, expected)
|
|
||||||
|
|
||||||
session.shutdown()
|
def match_dtype(self, datatype, dtype):
|
||||||
|
"""Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
|
||||||
|
if datatype == 'smallint':
|
||||||
|
self.match_dtype_props(dtype, 'i', 2)
|
||||||
|
elif datatype == 'int':
|
||||||
|
self.match_dtype_props(dtype, 'i', 4)
|
||||||
|
elif datatype in ('bigint', 'counter'):
|
||||||
|
self.match_dtype_props(dtype, 'i', 8)
|
||||||
|
elif datatype == 'float':
|
||||||
|
self.match_dtype_props(dtype, 'f', 4)
|
||||||
|
elif datatype == 'double':
|
||||||
|
self.match_dtype_props(dtype, 'f', 8)
|
||||||
|
else:
|
||||||
|
self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype))
|
||||||
|
|
||||||
|
def match_dtype_props(self, dtype, kind, size, signed=None):
|
||||||
|
self.assertEqual(dtype.kind, kind, msg=dtype)
|
||||||
|
self.assertEqual(dtype.itemsize, size, msg=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def arrays_to_list_of_tuples(arrays, colnames):
|
||||||
|
"""Convert a dict of arrays (as given by the numpy protocol handler) to a list of tuples"""
|
||||||
|
first_array = arrays[colnames[0]]
|
||||||
|
return [tuple(arrays[colname][i] for colname in colnames)
|
||||||
|
for i in range(len(first_array))]
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(protocol_handler):
|
||||||
|
"""
|
||||||
|
Get some data from the test table.
|
||||||
|
|
||||||
|
:param key: if None, get all results (100.000 results), otherwise get only one result
|
||||||
|
"""
|
||||||
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
|
session = cluster.connect(keyspace="testspace")
|
||||||
|
|
||||||
|
# use our custom protocol handler
|
||||||
|
session.client_protocol_handler = protocol_handler
|
||||||
|
session.row_factory = tuple_factory
|
||||||
|
|
||||||
|
results = session.execute("SELECT * FROM test_table")
|
||||||
|
session.shutdown()
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def verify_iterator_data(assertEqual, results):
|
||||||
|
"""
|
||||||
|
Check the result of get_data() when this is a list or
|
||||||
|
iterator of tuples
|
||||||
|
"""
|
||||||
|
for result in results:
|
||||||
|
params = get_all_primitive_params(result[0])
|
||||||
|
assertEqual(len(params), len(result),
|
||||||
|
msg="Not the right number of columns?")
|
||||||
|
for expected, actual in zip(params, result):
|
||||||
|
assertEqual(actual, expected)
|
||||||
|
@@ -4,15 +4,16 @@ Helper module to populate a dummy Cassandra tables with data.
|
|||||||
|
|
||||||
from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample
|
from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample
|
||||||
|
|
||||||
def create_table_with_all_types(table_name, session):
|
def create_table_with_all_types(table_name, session, N):
|
||||||
"""
|
"""
|
||||||
Method that given a table_name and session construct a table that contains
|
Method that given a table_name and session construct a table that contains
|
||||||
all possible primitive types.
|
all possible primitive types.
|
||||||
|
|
||||||
:param table_name: Name of table to create
|
:param table_name: Name of table to create
|
||||||
:param session: session to use for table creation
|
:param session: session to use for table creation
|
||||||
:return: a string containing the names of all the columns.
|
:param N: the number of items to insert into the table
|
||||||
This can be used to query the table.
|
|
||||||
|
:return: a list of column names
|
||||||
"""
|
"""
|
||||||
# create table
|
# create table
|
||||||
alpha_type_list = ["primkey int PRIMARY KEY"]
|
alpha_type_list = ["primkey int PRIMARY KEY"]
|
||||||
@@ -26,21 +27,27 @@ def create_table_with_all_types(table_name, session):
|
|||||||
table_name, ', '.join(alpha_type_list)), timeout=120)
|
table_name, ', '.join(alpha_type_list)), timeout=120)
|
||||||
|
|
||||||
# create the input
|
# create the input
|
||||||
params = get_all_primitive_params()
|
|
||||||
|
|
||||||
# insert into table as a simple statement
|
for key in range(N):
|
||||||
columns_string = ', '.join(col_names)
|
params = get_all_primitive_params(key)
|
||||||
placeholders = ', '.join(["%s"] * len(col_names))
|
|
||||||
session.execute("INSERT INTO {0} ({1}) VALUES ({2})".format(
|
# insert into table as a simple statement
|
||||||
table_name, columns_string, placeholders), params, timeout=120)
|
columns_string = ', '.join(col_names)
|
||||||
return columns_string
|
placeholders = ', '.join(["%s"] * len(col_names))
|
||||||
|
session.execute("INSERT INTO {0} ({1}) VALUES ({2})".format(
|
||||||
|
table_name, columns_string, placeholders), params, timeout=120)
|
||||||
|
return col_names
|
||||||
|
|
||||||
|
|
||||||
def get_all_primitive_params():
|
def get_all_primitive_params(key):
|
||||||
"""
|
"""
|
||||||
Simple utility method used to give back a list of all possible primitive data sample types.
|
Simple utility method used to give back a list of all possible primitive data sample types.
|
||||||
"""
|
"""
|
||||||
params = [0]
|
params = [key]
|
||||||
for datatype in PRIMITIVE_DATATYPES:
|
for datatype in PRIMITIVE_DATATYPES:
|
||||||
params.append(get_sample(datatype))
|
params.append(get_sample(datatype))
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_primitive_datatypes():
|
||||||
|
return ['int'] + list(PRIMITIVE_DATATYPES)
|
@@ -17,6 +17,7 @@ def cyimport(import_path):
|
|||||||
raise
|
raise
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# @cythontest
|
# @cythontest
|
||||||
# def test_something(self): ...
|
# def test_something(self): ...
|
||||||
cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available')
|
cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available')
|
||||||
|
Reference in New Issue
Block a user