diff --git a/cassandra/numpyparser.pyx b/cassandra/numpyparser.pyx index 8499d938..0a4e7e3e 100644 --- a/cassandra/numpyparser.pyx +++ b/cassandra/numpyparser.pyx @@ -23,7 +23,7 @@ from cassandra import cqltypes from cassandra.util import is_little_endian import numpy as np - +# import pandas as pd cdef extern from "numpyFlags.h": # Include 'numpyFlags.h' into the generated C code to disable the @@ -74,8 +74,10 @@ cdef class NumpyParser(ColumnParser): for i in range(rowcount): unpack_row(reader, desc, arrs) - return [make_native_byteorder(arr) for arr in arrays] - # return pd.DataFrame(dict(zip(desc.colnames, arrays))) + arrays = [make_native_byteorder(arr) for arr in arrays] + result = dict(zip(desc.colnames, arrays)) + # return pd.DataFrame(result) + return result ### Helper functions to create NumPy arrays and array descriptors diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index edd066be..36965a36 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -107,10 +107,11 @@ class CustomProtocolHandlerTest(unittest.TestCase): session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked session.row_factory = tuple_factory - columns_string = create_table_with_all_types("alltypes", session) + colnames = create_table_with_all_types("alltypes", session, 1) + columns_string = ", ".join(colnames) # verify data - params = get_all_primitive_params() + params = get_all_primitive_params(0) results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0] for expected, actual in zip(params, results): self.assertEqual(actual, expected) diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index ba75cf72..985b7953 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py 
+++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -7,23 +7,25 @@ try: except ImportError: import unittest +from cassandra.query import tuple_factory from cassandra.cluster import Cluster from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler + from tests.integration import use_singledc, PROTOCOL_VERSION from tests.integration.datatype_utils import update_datatypes -from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params - -from cassandra.cython_deps import HAVE_CYTHON -if not HAVE_CYTHON: - raise unittest.SkipTest("Skipping test, not compiled with Cython enabled") +from tests.integration.standard.utils import ( + create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes) +from tests.unit.cython.utils import cythontest, numpytest def setup_module(): use_singledc() update_datatypes() -class CustomProtocolHandlerTest(unittest.TestCase): +class CythonProtocolHandlerTest(unittest.TestCase): + + N_ITEMS = 10 @classmethod def setUpClass(cls): @@ -32,39 +34,96 @@ class CustomProtocolHandlerTest(unittest.TestCase): cls.session.execute("CREATE KEYSPACE testspace WITH replication = " "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") cls.session.set_keyspace("testspace") - create_table_with_all_types("test_table", cls.session) + cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS) @classmethod def tearDownClass(cls): cls.session.execute("DROP KEYSPACE testspace") cls.cluster.shutdown() + @cythontest def test_cython_parser(self): """ Test Cython-based parser that returns a list of tuples """ - self.cython_parser(ProtocolHandler) + verify_iterator_data(self.assertEqual, get_data(ProtocolHandler)) + @cythontest def test_cython_lazy_parser(self): """ - Test Cython-based parser that returns a list of tuples + Test Cython-based parser that returns an iterator of tuples """ - self.cython_parser(LazyProtocolHandler) + 
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) - def cython_parser(self, protocol_handler): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect(keyspace="testspace") + @numpytest + def test_numpy_parser(self): + """ + Test Numpy-based parser that returns a dict of NumPy arrays keyed by column name + """ + # arrays = { 'a': arr1, 'b': arr2, ... } + arrays = get_data(NumpyProtocolHandler) - # use our custom protocol handler - session.client_protocol_handler = protocol_handler - # session.row_factory = tuple_factory + colnames = self.colnames + datatypes = get_primitive_datatypes() + for colname, datatype in zip(colnames, datatypes): + arr = arrays[colname] + self.match_dtype(datatype, arr.dtype) - # verify data - params = get_all_primitive_params() - [first_result] = session.execute("SELECT * FROM test_table WHERE primkey=0") - self.assertEqual(len(params), len(first_result), - msg="Not the right number of columns?") - for expected, actual in zip(params, first_result): - self.assertEqual(actual, expected) + verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames)) - session.shutdown() + def match_dtype(self, datatype, dtype): + """Match a string cqltype (e.g. 
'int' or 'blob') with a numpy dtype""" + if datatype == 'smallint': + self.match_dtype_props(dtype, 'i', 2) + elif datatype == 'int': + self.match_dtype_props(dtype, 'i', 4) + elif datatype in ('bigint', 'counter'): + self.match_dtype_props(dtype, 'i', 8) + elif datatype == 'float': + self.match_dtype_props(dtype, 'f', 4) + elif datatype == 'double': + self.match_dtype_props(dtype, 'f', 8) + else: + self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype)) + + def match_dtype_props(self, dtype, kind, size, signed=None): + self.assertEqual(dtype.kind, kind, msg=dtype) + self.assertEqual(dtype.itemsize, size, msg=dtype) + + +def arrays_to_list_of_tuples(arrays, colnames): + """Convert a dict of arrays (as given by the numpy protocol handler) to a list of tuples""" + first_array = arrays[colnames[0]] + return [tuple(arrays[colname][i] for colname in colnames) + for i in range(len(first_array))] + + +def get_data(protocol_handler): + """ + Get some data from the test table. + + :param protocol_handler: the protocol handler class to set on the session; every row of test_table is fetched + """ + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + + # use our custom protocol handler + session.client_protocol_handler = protocol_handler + session.row_factory = tuple_factory + + results = session.execute("SELECT * FROM test_table") + session.shutdown() + return results + + +def verify_iterator_data(assertEqual, results): + """ + Check the result of get_data() when this is a list or + iterator of tuples + """ + for result in results: + params = get_all_primitive_params(result[0]) + assertEqual(len(params), len(result), + msg="Not the right number of columns?") + for expected, actual in zip(params, result): + assertEqual(actual, expected) diff --git a/tests/integration/standard/utils.py b/tests/integration/standard/utils.py index bd0c80b5..fe54f04d 100644 --- a/tests/integration/standard/utils.py +++ b/tests/integration/standard/utils.py @@ -4,15 
+4,16 @@ Helper module to populate a dummy Cassandra tables with data. from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample -def create_table_with_all_types(table_name, session): +def create_table_with_all_types(table_name, session, N): """ Method that given a table_name and session construct a table that contains all possible primitive types. :param table_name: Name of table to create :param session: session to use for table creation - :return: a string containing the names of all the columns. - This can be used to query the table. + :param N: the number of items to insert into the table + + :return: a list of column names """ # create table alpha_type_list = ["primkey int PRIMARY KEY"] @@ -26,21 +27,27 @@ def create_table_with_all_types(table_name, session): table_name, ', '.join(alpha_type_list)), timeout=120) # create the input - params = get_all_primitive_params() - # insert into table as a simple statement - columns_string = ', '.join(col_names) - placeholders = ', '.join(["%s"] * len(col_names)) - session.execute("INSERT INTO {0} ({1}) VALUES ({2})".format( - table_name, columns_string, placeholders), params, timeout=120) - return columns_string + for key in range(N): + params = get_all_primitive_params(key) + + # insert into table as a simple statement + columns_string = ', '.join(col_names) + placeholders = ', '.join(["%s"] * len(col_names)) + session.execute("INSERT INTO {0} ({1}) VALUES ({2})".format( + table_name, columns_string, placeholders), params, timeout=120) + return col_names -def get_all_primitive_params(): +def get_all_primitive_params(key): """ Simple utility method used to give back a list of all possible primitive data sample types. 
""" - params = [0] + params = [key] for datatype in PRIMITIVE_DATATYPES: params.append(get_sample(datatype)) return params + + +def get_primitive_datatypes(): + return ['int'] + list(PRIMITIVE_DATATYPES) \ No newline at end of file diff --git a/tests/unit/cython/utils.py b/tests/unit/cython/utils.py index f2598c0e..9f0a5a87 100644 --- a/tests/unit/cython/utils.py +++ b/tests/unit/cython/utils.py @@ -17,6 +17,7 @@ def cyimport(import_path): raise return None + # @cythontest # def test_something(self): ... cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available')