Handle NULL values in the NumPy protocol parser with masked arrays.

Columns whose CQL type has a native NumPy dtype are allocated as numpy.ma
masked arrays with an all-False mask. In unpack_row a negative buffer size
now sets that row's mask byte instead of raising
ValueError("Cannot handle NULL value"). ArrDesc gains a mask_ptr field;
make_arrays sets it to 0 for arrays that have no .mask attribute.

The mask dtype is spelled np.bool_ because the np.bool alias was deprecated
in NumPy 1.20 and removed in 1.24; the resulting AttributeError would not
be caught by the "except KeyError" in make_array.

NOTE(review): line breaks, indentation and blank context lines were rebuilt
from the hunk headers -- confirm against the target tree before applying.
NOTE(review): the unpack_row lines "( arr.buf_ptr)[0] = val",
"memcpy( arr.buf_ptr, ...)" and "memcpy(arr.mask_ptr, ...)" look like they
lost a Cython <...> pointer cast during extraction -- TODO confirm.

diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx
index 1334e747..ed755d00 100644
--- a/cassandra/numpy_parser.pyx
+++ b/cassandra/numpy_parser.pyx
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 """
-This module provider an optional protocol parser that returns
+This module provides an optional protocol parser that returns
 NumPy arrays.
 
 =============================================================================
@@ -25,7 +25,7 @@ as numpy is an optional dependency.
 include "ioutils.pyx"
 
 cimport cython
-from libc.stdint cimport uint64_t
+from libc.stdint cimport uint64_t, uint8_t
 from cpython.ref cimport Py_INCREF, PyObject
 
 from cassandra.bytesio cimport BytesIOReader
@@ -35,7 +35,6 @@ from cassandra import cqltypes
 from cassandra.util import is_little_endian
 
 import numpy as np
-# import pandas as pd
 
 cdef extern from "numpyFlags.h":
     # Include 'numpyFlags.h' into the generated C code to disable the
@@ -52,11 +51,13 @@ ctypedef struct ArrDesc:
     Py_uintptr_t buf_ptr
     int stride # should be large enough as we allocate contiguous arrays
     int is_object
+    Py_uintptr_t mask_ptr
 
 arrDescDtype = np.dtype(
     [ ('buf_ptr', np.uintp)
     , ('stride', np.dtype('i'))
     , ('is_object', np.dtype('i'))
+    , ('mask_ptr', np.uintp)
     ], align=True)
 
 _cqltype_to_numpy = {
@@ -70,6 +71,7 @@ _cqltype_to_numpy = {
 
 obj_dtype = np.dtype('O')
 
+cdef uint8_t mask_true = 0x01
 
 cdef class NumpyParser(ColumnParser):
     """Decode a ResultMessage into a bunch of NumPy arrays"""
@@ -116,7 +118,11 @@ def make_arrays(ParseDesc desc, array_size):
         arr = make_array(coltype, array_size)
         array_descs[i]['buf_ptr'] = arr.ctypes.data
         array_descs[i]['stride'] = arr.strides[0]
-        array_descs[i]['is_object'] = coltype not in _cqltype_to_numpy
+        array_descs[i]['is_object'] = arr.dtype is obj_dtype
+        try:
+            array_descs[i]['mask_ptr'] = arr.mask.ctypes.data
+        except AttributeError:
+            array_descs[i]['mask_ptr'] = 0
         arrays.append(arr)
 
     return array_descs, arrays
@@ -126,8 +132,12 @@ def make_array(coltype, array_size):
     """
     Allocate a new NumPy array of the given column type and size.
     """
-    dtype = _cqltype_to_numpy.get(coltype, obj_dtype)
-    return np.empty((array_size,), dtype=dtype)
+    try:
+        a = np.ma.empty((array_size,), dtype=_cqltype_to_numpy[coltype])
+        a.mask = np.zeros((array_size,), dtype=np.bool_)
+    except KeyError:
+        a = np.empty((array_size,), dtype=obj_dtype)
+    return a
 
 
 #### Parse rows into NumPy arrays
@@ -140,7 +150,6 @@ cdef inline int unpack_row(
     cdef Py_ssize_t i, rowsize = desc.rowsize
     cdef ArrDesc arr
     cdef Deserializer deserializer
-
     for i in range(rowsize):
         get_buf(reader, &buf)
         arr = arrays[i]
@@ -150,13 +159,14 @@ cdef inline int unpack_row(
             val = from_binary(deserializer, &buf, desc.protocol_version)
             Py_INCREF(val)
             ( arr.buf_ptr)[0] = val
-        elif buf.size < 0:
-            raise ValueError("Cannot handle NULL value")
-        else:
+        elif buf.size >= 0:
             memcpy( arr.buf_ptr, buf.ptr, buf.size)
+        else:
+            memcpy(arr.mask_ptr, &mask_true, 1)
 
         # Update the pointer into the array for the next time
         arrays[i].buf_ptr += arr.stride
+        arrays[i].mask_ptr += 1
 
     return 0
 