diff --git a/.gitignore b/.gitignore index 3bdda58a..a874ebce 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.so *.egg *.egg-info +*.attr .tox .python-version build @@ -19,6 +20,7 @@ setuptools*.egg cassandra/*.c !cassandra/cmurmur3.c +cassandra/*.html # OSX .DS_Store @@ -38,3 +40,4 @@ cassandra/*.c #iPython *.ipynb + diff --git a/MANIFEST.in b/MANIFEST.in index 1825f7bb..e3cb20eb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,4 @@ include setup.py README.rst MANIFEST.in LICENSE ez_setup.py +include cassandra/*.pyx +include cassandra/*.pxd +include tests/unit/cython/*.pyx diff --git a/cassandra/buffer.pxd b/cassandra/buffer.pxd new file mode 100644 index 00000000..2f40ced0 --- /dev/null +++ b/cassandra/buffer.pxd @@ -0,0 +1,58 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple buffer data structure that provides a view on existing memory +(e.g. from a bytes object). This memory must stay alive while the +buffer is in use. +""" + +from cpython.bytes cimport PyBytes_AS_STRING + # char* PyBytes_AS_STRING(object string) + # Macro form of PyBytes_AsString() but without error + # checking. Only string objects are supported; no Unicode objects + # should be passed. 
+ + +cdef struct Buffer: + char *ptr + Py_ssize_t size + + +cdef inline bytes to_bytes(Buffer *buf): + return buf.ptr[:buf.size] + +cdef inline char *buf_ptr(Buffer *buf): + return buf.ptr + +cdef inline char *buf_read(Buffer *buf, Py_ssize_t size) except NULL: + if size > buf.size: + raise IndexError("Requested more than length of buffer") + return buf.ptr + +cdef inline int slice_buffer(Buffer *buf, Buffer *out, + Py_ssize_t start, Py_ssize_t size) except -1: + if size < 0: + raise ValueError("Length must be positive") + + if start + size > buf.size: + raise IndexError("Buffer slice out of bounds") + + out.ptr = buf.ptr + start + out.size = size + return 0 + +cdef inline void from_ptr_and_size(char *ptr, Py_ssize_t size, Buffer *out): + out.ptr = ptr + out.size = size diff --git a/cassandra/bytesio.pxd b/cassandra/bytesio.pxd new file mode 100644 index 00000000..2bcda361 --- /dev/null +++ b/cassandra/bytesio.pxd @@ -0,0 +1,20 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cdef class BytesIOReader: + cdef bytes buf + cdef char *buf_ptr + cdef Py_ssize_t pos + cdef Py_ssize_t size + cdef char *read(self, Py_ssize_t n = ?) except NULL diff --git a/cassandra/bytesio.pyx b/cassandra/bytesio.pyx new file mode 100644 index 00000000..68a15baf --- /dev/null +++ b/cassandra/bytesio.pyx @@ -0,0 +1,44 @@ +# Copyright 2013-2015 DataStax, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cdef class BytesIOReader: + """ + This class provides efficient support for reading bytes from a 'bytes' buffer, + by returning char * values directly without allocating intermediate objects. + """ + + def __init__(self, bytes buf): + self.buf = buf + self.size = len(buf) + self.buf_ptr = self.buf + + cdef char *read(self, Py_ssize_t n = -1) except NULL: + """Return a pointer to the next n bytes of the buffer and advance + the read position past them. + + If n is negative or omitted, consume all data up to the end of the + buffer. Unlike file.read(), asking for more bytes than remain raises + EOFError rather than returning a short read. + """ + cdef Py_ssize_t newpos = self.pos + n + if n < 0: + newpos = self.size + elif newpos > self.size: + # Raise an error here, as we do not want the caller to consume past the + # end of the buffer + raise EOFError("Cannot read past the end of the file") + + cdef char *res = self.buf_ptr + self.pos + self.pos = newpos + return res diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 2d858e52..f39d28b9 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -288,7 +288,7 @@ class _CassandraType(object): Given a set of other CassandraTypes, create a new subtype of this type using them as parameters. This is how composite types are constructed. 
- >>> MapType.apply_parameters(DateType, BooleanType) + >>> MapType.apply_parameters([DateType, BooleanType]) `subtypes` will be a sequence of CassandraTypes. If provided, `names` @@ -884,27 +884,7 @@ class UserType(TupleType): @classmethod def deserialize_safe(cls, byts, protocol_version): - proto_version = max(3, protocol_version) - p = 0 - values = [] - for col_type in cls.subtypes: - if p == len(byts): - break - itemlen = int32_unpack(byts[p:p + 4]) - p += 4 - if itemlen >= 0: - item = byts[p:p + itemlen] - p += itemlen - else: - item = None - # collections inside UDTs are always encoded with at least the - # version 3 format - values.append(col_type.from_binary(item, proto_version)) - - if len(values) < len(cls.subtypes): - nones = [None] * (len(cls.subtypes) - len(values)) - values = values + nones - + values = super(UserType, cls).deserialize_safe(byts, protocol_version) if cls.mapped_class: return cls.mapped_class(**dict(zip(cls.fieldnames, values))) else: diff --git a/cassandra/cython_deps.py b/cassandra/cython_deps.py new file mode 100644 index 00000000..5cc86fe7 --- /dev/null +++ b/cassandra/cython_deps.py @@ -0,0 +1,11 @@ +try: + from cassandra.row_parser import make_recv_results_rows + HAVE_CYTHON = True +except ImportError: + HAVE_CYTHON = False + +try: + import numpy + HAVE_NUMPY = True +except ImportError: + HAVE_NUMPY = False diff --git a/cassandra/cython_marshal.pyx b/cassandra/cython_marshal.pyx new file mode 100644 index 00000000..1ba11435 --- /dev/null +++ b/cassandra/cython_marshal.pyx @@ -0,0 +1,123 @@ +# -- cython: profile=True +# +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six + +from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t) +from cassandra.buffer cimport Buffer, buf_read, to_bytes + +cdef bint is_little_endian +from cassandra.util import is_little_endian + +cdef bint PY3 = six.PY3 + + +cdef inline void swap_order(char *buf, Py_ssize_t size): + """ + Swap the byteorder of `buf` in-place on little-endian platforms + (reverse all the bytes). + There are functions ntohl etc, but these may be POSIX-dependent. + """ + cdef Py_ssize_t start, end, i + cdef char c + + if is_little_endian: + for i in range(div2(size)): + end = size - i - 1 + c = buf[i] + buf[i] = buf[end] + buf[end] = c + +cdef inline Py_ssize_t div2(Py_ssize_t x): + return x >> 1 + +### Unpacking of signed integers + +cdef inline int64_t int64_unpack(Buffer *buf) except ?0xDEAD: + cdef int64_t x = ( buf_read(buf, 8))[0] + cdef char *p = &x + swap_order( &x, 8) + return x + +cdef inline int32_t int32_unpack(Buffer *buf) except ?0xDEAD: + cdef int32_t x = ( buf_read(buf, 4))[0] + cdef char *p = &x + swap_order( &x, 4) + return x + +cdef inline int16_t int16_unpack(Buffer *buf) except ?0xDED: + cdef int16_t x = ( buf_read(buf, 2))[0] + swap_order( &x, 2) + return x + +cdef inline int8_t int8_unpack(Buffer *buf) except ?80: + return ( buf_read(buf, 1))[0] + +cdef inline uint64_t uint64_unpack(Buffer *buf) except ?0xDEAD: + cdef uint64_t x = ( buf_read(buf, 8))[0] + swap_order( &x, 8) + return x + +cdef inline uint32_t uint32_unpack(Buffer *buf) except ?0xDEAD: + cdef uint32_t x = ( buf_read(buf, 4))[0] + 
swap_order( &x, 4) + return x + +cdef inline uint16_t uint16_unpack(Buffer *buf) except ?0xDEAD: + cdef uint16_t x = ( buf_read(buf, 2))[0] + swap_order( &x, 2) + return x + +cdef inline uint8_t uint8_unpack(Buffer *buf) except ?0xff: + return ( buf_read(buf, 1))[0] + +cdef inline double double_unpack(Buffer *buf) except ?1.74: + cdef double x = ( buf_read(buf, 8))[0] + swap_order( &x, 8) + return x + +cdef inline float float_unpack(Buffer *buf) except ?1.74: + cdef float x = ( buf_read(buf, 4))[0] + swap_order( &x, 4) + return x + + +cdef varint_unpack(Buffer *term): + """Unpack a variable-sized integer""" + if PY3: + return varint_unpack_py3(to_bytes(term)) + else: + return varint_unpack_py2(to_bytes(term)) + +# TODO: Optimize these two functions +cdef varint_unpack_py3(bytes term): + cdef int64_t one = 1L + + val = int(''.join("%02x" % i for i in term), 16) + if (term[0] & 128) != 0: + # There is a bug in Cython (0.20 - 0.22), where if we do + # '1 << (len(term) * 8)' Cython generates '1' directly into the + # C code, causing integer overflows + val -= one << (len(term) * 8) + return val + +cdef varint_unpack_py2(bytes term): # noqa + cdef int64_t one = 1L + val = int(term.encode('hex'), 16) + if (ord(term[0]) & 128) != 0: + val = val - (one << (len(term) * 8)) + return val diff --git a/cassandra/cython_utils.pxd b/cassandra/cython_utils.pxd new file mode 100644 index 00000000..d2bf7d20 --- /dev/null +++ b/cassandra/cython_utils.pxd @@ -0,0 +1,2 @@ +from libc.stdint cimport int64_t +cdef datetime_from_timestamp(double timestamp) \ No newline at end of file diff --git a/cassandra/cython_utils.pyx b/cassandra/cython_utils.pyx new file mode 100644 index 00000000..1d16d47d --- /dev/null +++ b/cassandra/cython_utils.pyx @@ -0,0 +1,44 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Duplicate module of util.py, with some accelerated functions +used for deserialization. +""" + +from cpython.datetime cimport ( + timedelta_new, + # cdef inline object timedelta_new(int days, int seconds, int useconds) + # Create timedelta object using DateTime CAPI factory function. + # Note, there are no range checks for any of the arguments. + import_datetime, + # Datetime C API initialization function. + # You have to call it before any usage of DateTime CAPI functions. + ) + +import datetime +import sys + +cdef bint is_little_endian +from cassandra.util import is_little_endian + +import_datetime() + +DATETIME_EPOC = datetime.datetime(1970, 1, 1) + + +cdef datetime_from_timestamp(double timestamp): + cdef int seconds = timestamp + cdef int microseconds = ( (timestamp * 1000000)) % 1000000 + return DATETIME_EPOC + timedelta_new(0, seconds, microseconds) diff --git a/cassandra/deserializers.pxd b/cassandra/deserializers.pxd new file mode 100644 index 00000000..26b4429a --- /dev/null +++ b/cassandra/deserializers.pxd @@ -0,0 +1,43 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.buffer cimport Buffer + +cdef class Deserializer: + # The cqltypes._CassandraType corresponding to this deserializer + cdef object cqltype + + # String may be empty, whereas other values may not be. + # Other values may be NULL, in which case the integer length + # of the binary data is negative. However, non-string types + # may also return a zero length for legacy reasons + # (see http://code.metager.de/source/xref/apache/cassandra/doc/native_protocol_v3.spec + # paragraph 6) + cdef bint empty_binary_ok + + cdef deserialize(self, Buffer *buf, int protocol_version) + # cdef deserialize(self, CString byts, protocol_version) + + +cdef inline object from_binary(Deserializer deserializer, + Buffer *buf, + int protocol_version): + if buf.size < 0: + return None + elif buf.size == 0 and not deserializer.empty_binary_ok: + return _ret_empty(deserializer, buf.size) + else: + return deserializer.deserialize(buf, protocol_version) + +cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size) diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx new file mode 100644 index 00000000..e2f23284 --- /dev/null +++ b/cassandra/deserializers.pyx @@ -0,0 +1,512 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from libc.stdint cimport int32_t, uint16_t + +include 'cython_marshal.pyx' +from cassandra.buffer cimport Buffer, to_bytes, slice_buffer +from cassandra.cython_utils cimport datetime_from_timestamp + +from cython.view cimport array as cython_array +from cassandra.tuple cimport tuple_new, tuple_set + +import socket +from decimal import Decimal +from uuid import UUID + +from cassandra import cqltypes +from cassandra import util + +cdef bint PY2 = six.PY2 + +cdef class Deserializer: + """Cython-based deserializer class for a cqltype""" + + def __init__(self, cqltype): + self.cqltype = cqltype + self.empty_binary_ok = cqltype.empty_binary_ok + + cdef deserialize(self, Buffer *buf, int protocol_version): + raise NotImplementedError + + +cdef class DesBytesType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return to_bytes(buf) + + +# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html +cdef class DesDecimalType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef Buffer varint_buf + slice_buffer(buf, &varint_buf, 4, buf.size - 4) + + scale = int32_unpack(buf) + unscaled = varint_unpack(&varint_buf) + + return Decimal('%de%d' % (unscaled, -scale)) + + +cdef class DesUUIDType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return UUID(bytes=to_bytes(buf)) + + +cdef class DesBooleanType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + if int8_unpack(buf): + return True + return False + + +cdef class DesByteType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return int8_unpack(buf) + + +cdef class DesAsciiType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + if PY2: + return to_bytes(buf) + return to_bytes(buf).decode('ascii') + + +cdef class DesFloatType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return float_unpack(buf) + + +cdef class 
DesDoubleType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return double_unpack(buf) + + +cdef class DesLongType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return int64_unpack(buf) + + +cdef class DesInt32Type(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return int32_unpack(buf) + + +cdef class DesIntegerType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return varint_unpack(buf) + + +cdef class DesInetAddressType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef bytes byts = to_bytes(buf) + + # TODO: optimize inet_ntop, inet_ntoa + if buf.size == 16: + return util.inet_ntop(socket.AF_INET6, byts) + else: + # util.inet_ntop could also handle, but this is faster + # since we've already determined the AF + return socket.inet_ntoa(byts) + + +cdef class DesCounterColumnType(DesLongType): + pass + + +cdef class DesDateType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef double timestamp = int64_unpack(buf) / 1000.0 + return datetime_from_timestamp(timestamp) + + +cdef class TimestampType(DesDateType): + pass + + +cdef class TimeUUIDType(DesDateType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return UUID(bytes=to_bytes(buf)) + + +# Values of the 'date' type are encoded as 32-bit unsigned integers +# representing a number of days with epoch (January 1st, 1970) at the center of the +# range (2^31). 
+EPOCH_OFFSET_DAYS = 2 ** 31 + +cdef class DesSimpleDateType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + days = uint32_unpack(buf) - EPOCH_OFFSET_DAYS + return util.Date(days) + + +cdef class DesShortType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return int16_unpack(buf) + + +cdef class DesTimeType(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + return util.Time(int64_unpack(buf)) + + +cdef class DesUTF8Type(Deserializer): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef val = to_bytes(buf) + return val.decode('utf8') + + +cdef class DesVarcharType(DesUTF8Type): + pass + + +cdef class _DesParameterizedType(Deserializer): + + cdef object subtypes + cdef Deserializer[::1] deserializers + cdef Py_ssize_t subtypes_len + + def __init__(self, cqltype): + super().__init__(cqltype) + self.subtypes = cqltype.subtypes + self.deserializers = make_deserializers(cqltype.subtypes) + self.subtypes_len = len(self.subtypes) + + +cdef class _DesSingleParamType(_DesParameterizedType): + cdef Deserializer deserializer + + def __init__(self, cqltype): + assert cqltype.subtypes and len(cqltype.subtypes) == 1, cqltype.subtypes + super().__init__(cqltype) + self.deserializer = self.deserializers[0] + + +#-------------------------------------------------------------------------- +# List and set deserialization + +cdef class DesListType(_DesSingleParamType): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef uint16_t v2_and_below = 2 + cdef int32_t v3_and_above = 3 + + if protocol_version >= 3: + result = _deserialize_list_or_set[int32_t]( + v3_and_above, buf, protocol_version, self.deserializer) + else: + result = _deserialize_list_or_set[uint16_t]( + v2_and_below, buf, protocol_version, self.deserializer) + + return result + +cdef class DesSetType(DesListType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return 
util.sortedset(DesListType.deserialize(self, buf, protocol_version)) + + +ctypedef fused itemlen_t: + uint16_t # protocol <= v2 + int32_t # protocol >= v3 + +cdef list _deserialize_list_or_set(itemlen_t dummy_version, + Buffer *buf, int protocol_version, + Deserializer deserializer): + """ + Deserialize a list or set. + + The 'dummy' parameter is needed to make fused types work, so that + we can specialize on the protocol version. + """ + cdef itemlen_t itemlen + cdef Buffer itemlen_buf + cdef Buffer elem_buf + + cdef itemlen_t numelements + cdef itemlen_t idx + cdef list result = [] + + _unpack_len[itemlen_t](0, &numelements, buf) + idx = sizeof(itemlen_t) + + for _ in range(numelements): + subelem(buf, &elem_buf, &idx) + result.append(from_binary(deserializer, &elem_buf, protocol_version)) + + return result + + +cdef inline int subelem( + Buffer *buf, Buffer *elem_buf, itemlen_t *idx_p) except -1: + """ + Read the next element from the buffer: first read the size (in bytes) of the + element, then fill elem_buf with a newly sliced buffer of this size (and the + right offset). + + NOTE: The handling of 'idx' is somewhat atrocious, as there is a Cython + bug with the combination fused types + 'except' clause. + So instead, we pass in a pointer to 'idx', namely 'idx_p', and write + to this instead. 
+ """ + cdef itemlen_t elemlen + + _unpack_len[itemlen_t](idx_p[0], &elemlen, buf) + idx_p[0] += sizeof(itemlen_t) + slice_buffer(buf, elem_buf, idx_p[0], elemlen) + idx_p[0] += elemlen + return 0 + + +cdef int _unpack_len(itemlen_t idx, itemlen_t *elemlen, Buffer *buf) except -1: + cdef itemlen_t result + cdef Buffer itemlen_buf + slice_buffer(buf, &itemlen_buf, idx, sizeof(itemlen_t)) + + if itemlen_t is uint16_t: + elemlen[0] = uint16_unpack(&itemlen_buf) + else: + elemlen[0] = int32_unpack(&itemlen_buf) + + return 0 + +#-------------------------------------------------------------------------- +# Map deserialization + +cdef class DesMapType(_DesParameterizedType): + + cdef Deserializer key_deserializer, val_deserializer + + def __init__(self, cqltype): + super().__init__(cqltype) + self.key_deserializer = self.deserializers[0] + self.val_deserializer = self.deserializers[1] + + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef uint16_t v2_and_below = 0 + cdef int32_t v3_and_above = 0 + key_type, val_type = self.cqltype.subtypes + + if protocol_version >= 3: + result = _deserialize_map[int32_t]( + v3_and_above, buf, protocol_version, + self.key_deserializer, self.val_deserializer, + key_type, val_type) + else: + result = _deserialize_map[uint16_t]( + v2_and_below, buf, protocol_version, + self.key_deserializer, self.val_deserializer, + key_type, val_type) + + return result + + +cdef _deserialize_map(itemlen_t dummy_version, + Buffer *buf, int protocol_version, + Deserializer key_deserializer, Deserializer val_deserializer, + key_type, val_type): + cdef itemlen_t itemlen, val_len, key_len + cdef Buffer key_buf, val_buf + cdef Buffer itemlen_buf + + cdef itemlen_t numelements + cdef itemlen_t idx = sizeof(itemlen_t) + cdef list result = [] + + _unpack_len[itemlen_t](0, &numelements, buf) + idx = sizeof(itemlen_t) + themap = util.OrderedMapSerializedKey(key_type, protocol_version) + for _ in range(numelements): + subelem(buf, &key_buf, &idx) + 
subelem(buf, &val_buf, &idx) + key = from_binary(key_deserializer, &key_buf, protocol_version) + val = from_binary(val_deserializer, &val_buf, protocol_version) + themap._insert_unchecked(key, to_bytes(&key_buf), val) + + return themap + +#-------------------------------------------------------------------------- + +cdef class DesTupleType(_DesParameterizedType): + + # TODO: Use TupleRowParser to parse these tuples + + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef Py_ssize_t i, p + cdef int32_t itemlen + cdef tuple res = tuple_new(self.subtypes_len) + cdef Buffer item_buf + cdef Buffer itemlen_buf + cdef Deserializer deserializer + + # collections inside UDTs are always encoded with at least the + # version 3 format + protocol_version = max(3, protocol_version) + + p = 0 + values = [] + for i in range(self.subtypes_len): + item = None + if p < buf.size: + slice_buffer(buf, &itemlen_buf, p, 4) + itemlen = int32_unpack(&itemlen_buf) + p += 4 + if itemlen >= 0: + slice_buffer(buf, &item_buf, p, itemlen) + p += itemlen + + deserializer = self.deserializers[i] + item = from_binary(deserializer, &item_buf, protocol_version) + + tuple_set(res, i, item) + + return res + + +cdef class DesUserType(DesTupleType): + cdef deserialize(self, Buffer *buf, int protocol_version): + typ = self.cqltype + values = DesTupleType.deserialize(self, buf, protocol_version) + if typ.mapped_class: + return typ.mapped_class(**dict(zip(typ.fieldnames, values))) + else: + return typ.tuple_type(*values) + + +cdef class DesCompositeType(_DesParameterizedType): + cdef deserialize(self, Buffer *buf, int protocol_version): + cdef Py_ssize_t i, idx, start + cdef Buffer elem_buf + cdef int16_t element_length + cdef Deserializer deserializer + cdef tuple res = tuple_new(self.subtypes_len) + + idx = 0 + for i in range(self.subtypes_len): + if not buf.size: + # CompositeType can have missing elements at the end + + # Fill the tuple with None values and slice it + # + # (I'm not sure a 
tuple needs to be fully initialized before + # it can be destroyed, so play it safe) + for j in range(i, self.subtypes_len): + tuple_set(res, j, None) + res = res[:i] + break + + element_length = uint16_unpack(buf) + slice_buffer(buf, &elem_buf, 2, element_length) + + deserializer = self.deserializers[i] + item = from_binary(deserializer, &elem_buf, protocol_version) + tuple_set(res, i, item) + + # skip element length, element, and the EOC (one byte) + start = 2 + element_length + 1 + slice_buffer(buf, buf, start, buf.size - start) + + return res + + +DesDynamicCompositeType = DesCompositeType + + +cdef class DesReversedType(_DesSingleParamType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return from_binary(self.deserializer, buf, protocol_version) + + +cdef class DesFrozenType(_DesSingleParamType): + cdef deserialize(self, Buffer *buf, int protocol_version): + return from_binary(self.deserializer, buf, protocol_version) + +#-------------------------------------------------------------------------- + +cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size): + """ + Decide whether to return None or EMPTY when a buffer size is + zero or negative. This is used by from_binary in deserializers.pxd. 
+ """ + if buf_size < 0: + return None + elif deserializer.cqltype.support_empty_values: + return cqltypes.EMPTY + else: + return None + +#-------------------------------------------------------------------------- +# Generic deserialization + +cdef class GenericDeserializer(Deserializer): + """ + Wrap a generic datatype for deserialization + """ + + cdef deserialize(self, Buffer *buf, int protocol_version): + return self.cqltype.deserialize(to_bytes(buf), protocol_version) + + def __repr__(self): + return "GenericDeserializer(%s)" % (self.cqltype,) + +#-------------------------------------------------------------------------- +# Helper utilities + +def make_deserializers(cqltypes): + """Create an array of Deserializers for each given cqltype in cqltypes""" + cdef Deserializer[::1] deserializers + return obj_array([find_deserializer(ct) for ct in cqltypes]) + + +cdef dict classes = globals() + +cpdef Deserializer find_deserializer(cqltype): + """Find a deserializer for a cqltype""" + name = 'Des' + cqltype.__name__ + + if name in globals(): + cls = classes[name] + elif issubclass(cqltype, cqltypes.ListType): + cls = DesListType + elif issubclass(cqltype, cqltypes.SetType): + cls = DesSetType + elif issubclass(cqltype, cqltypes.MapType): + cls = DesMapType + elif issubclass(cqltype, cqltypes.UserType): + # UserType is a subclass of TupleType, so should precede it + cls = DesUserType + elif issubclass(cqltype, cqltypes.TupleType): + cls = DesTupleType + elif issubclass(cqltype, cqltypes.DynamicCompositeType): + # DynamicCompositeType is a subclass of CompositeType, so should precede it + cls = DesDynamicCompositeType + elif issubclass(cqltype, cqltypes.CompositeType): + cls = DesCompositeType + elif issubclass(cqltype, cqltypes.ReversedType): + cls = DesReversedType + elif issubclass(cqltype, cqltypes.FrozenType): + cls = DesFrozenType + else: + cls = GenericDeserializer + + return cls(cqltype) + + +def obj_array(list objs): + """Create a (Cython) array of objects 
given a list of objects""" + cdef object[:] arr + cdef Py_ssize_t i + arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O") + # arr[:] = objs # This does not work (segmentation faults) + for i, obj in enumerate(objs): + arr[i] = obj + return arr diff --git a/cassandra/ioutils.pyx b/cassandra/ioutils.pyx new file mode 100644 index 00000000..c38b311a --- /dev/null +++ b/cassandra/ioutils.pyx @@ -0,0 +1,47 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include 'cython_marshal.pyx' +from cassandra.buffer cimport Buffer, from_ptr_and_size + +from libc.stdint cimport int32_t +from cassandra.bytesio cimport BytesIOReader + + +cdef inline int get_buf(BytesIOReader reader, Buffer *buf_out) except -1: + """ + Get a pointer into the buffer provided by BytesIOReader for the + next data item in the stream of values. + + BEWARE: + If the next item has a zero or negative size, the pointer will be set to NULL. + A negative size happens when the value is NULL in the database, whereas a + zero size may happen either for legacy reasons, or for data types such as + strings (which may be empty). 
+ """ + cdef Py_ssize_t raw_val_size = read_int(reader) + cdef char *ptr + if raw_val_size <= 0: + ptr = NULL + else: + ptr = reader.read(raw_val_size) + + from_ptr_and_size(ptr, raw_val_size, buf_out) + return 0 + +cdef inline int32_t read_int(BytesIOReader reader) except ?0xDEAD: + cdef Buffer buf + buf.ptr = reader.read(4) + buf.size = 4 + return int32_unpack(&buf) diff --git a/cassandra/numpyFlags.h b/cassandra/numpyFlags.h new file mode 100644 index 00000000..6793b7a8 --- /dev/null +++ b/cassandra/numpyFlags.h @@ -0,0 +1 @@ +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION \ No newline at end of file diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx new file mode 100644 index 00000000..6702cfcc --- /dev/null +++ b/cassandra/numpy_parser.pyx @@ -0,0 +1,185 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides an optional protocol parser that returns +NumPy arrays. + +============================================================================= +This module should not be imported by any of the main python-driver modules, +as numpy is an optional dependency. 
+============================================================================= +""" + +include "ioutils.pyx" + +cimport cython +from libc.stdint cimport uint64_t +from cpython.ref cimport Py_INCREF, PyObject + +from cassandra.bytesio cimport BytesIOReader +from cassandra.deserializers cimport Deserializer, from_binary +from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser +from cassandra import cqltypes +from cassandra.util import is_little_endian + +import numpy as np +# import pandas as pd + +cdef extern from "numpyFlags.h": + # Include 'numpyFlags.h' into the generated C code to disable the + # deprecated NumPy API + pass + +cdef extern from "Python.h": + # An integer type large enough to hold a pointer + ctypedef uint64_t Py_uintptr_t + + +# Simple array descriptor, useful to parse rows into a NumPy array +ctypedef struct ArrDesc: + Py_uintptr_t buf_ptr + int stride # should be large enough as we allocate contiguous arrays + int is_object + +arrDescDtype = np.dtype( + [ ('buf_ptr', np.uintp) + , ('stride', np.dtype('i')) + , ('is_object', np.dtype('i')) + ]) + +_cqltype_to_numpy = { + cqltypes.LongType: np.dtype('>i8'), + cqltypes.CounterColumnType: np.dtype('>i8'), + cqltypes.Int32Type: np.dtype('>i4'), + cqltypes.ShortType: np.dtype('>i2'), + cqltypes.FloatType: np.dtype('>f4'), + cqltypes.DoubleType: np.dtype('>f8'), +} + +obj_dtype = np.dtype('O') + + +cdef class NumpyParser(ColumnParser): + """Decode a ResultMessage into a bunch of NumPy arrays""" + + cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc): + cdef Py_ssize_t rowcount + cdef ArrDesc[::1] array_descs + cdef ArrDesc *arrs + + rowcount = read_int(reader) + array_descs, arrays = make_arrays(desc, rowcount) + arrs = &array_descs[0] + + _parse_rows(reader, desc, arrs, rowcount) + + arrays = [make_native_byteorder(arr) for arr in arrays] + result = dict(zip(desc.colnames, arrays)) + return result + + +cdef _parse_rows(BytesIOReader reader, ParseDesc desc, + ArrDesc *arrs, 
Py_ssize_t rowcount): + cdef Py_ssize_t i + + for i in range(rowcount): + unpack_row(reader, desc, arrs) + + +### Helper functions to create NumPy arrays and array descriptors + +def make_arrays(ParseDesc desc, array_size): + """ + Allocate arrays for each result column. + + returns a tuple of (array_descs, arrays), where + 'array_descs' describe the arrays for NativeRowParser and + 'arrays' is a list holding one array per result column + (e.g. zip it with the column names to feed pandas.DataFrame) + """ + array_descs = np.empty((desc.rowsize,), arrDescDtype) + arrays = [] + + for i, coltype in enumerate(desc.coltypes): + arr = make_array(coltype, array_size) + array_descs[i]['buf_ptr'] = arr.ctypes.data + array_descs[i]['stride'] = arr.strides[0] + array_descs[i]['is_object'] = coltype not in _cqltype_to_numpy + arrays.append(arr) + + return array_descs, arrays + + +def make_array(coltype, array_size): + """ + Allocate a new NumPy array of the given column type and size. + """ + dtype = _cqltype_to_numpy.get(coltype, obj_dtype) + return np.empty((array_size,), dtype=dtype) + + +#### Parse rows into NumPy arrays + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef inline int unpack_row( + BytesIOReader reader, ParseDesc desc, ArrDesc *arrays) except -1: + cdef Buffer buf + cdef Py_ssize_t i, rowsize = desc.rowsize + cdef ArrDesc arr + cdef Deserializer deserializer + + for i in range(rowsize): + get_buf(reader, &buf) + arr = arrays[i] + + if buf.size == 0: + raise ValueError("Cannot handle NULL value") + if arr.is_object: + deserializer = desc.deserializers[i] + val = from_binary(deserializer, &buf, desc.protocol_version) + Py_INCREF(val) + ( arr.buf_ptr)[0] = val + else: + memcopy(buf.ptr, arr.buf_ptr, buf.size) + + # Update the pointer into the array for the next time + arrays[i].buf_ptr += arr.stride + + return 0 + + +cdef inline void memcopy(char *src, char *dst, Py_ssize_t size): + """ + Our own simple memcopy which can be inlined.
This is useful because our data types + are only a few bytes. + """ + cdef Py_ssize_t i + for i in range(size): + dst[i] = src[i] + + +def make_native_byteorder(arr): + """ + Make sure all values have a native endian in the NumPy arrays. + """ + if is_little_endian and not arr.dtype.kind == 'O': + # We have arrays in big-endian order. First swap the bytes + # into little endian order, and then update the numpy dtype + # accordingly (e.g. from '>i8' to '= 0 + + cdef Buffer buf + cdef Py_ssize_t i, rowsize = desc.rowsize + cdef Deserializer deserializer + cdef tuple res = tuple_new(desc.rowsize) + + for i in range(rowsize): + # Read the next few bytes + get_buf(reader, &buf) + + # Deserialize bytes to python object + deserializer = desc.deserializers[i] + val = from_binary(deserializer, &buf, desc.protocol_version) + + # Insert new object into tuple + tuple_set(res, i, val) + + return res diff --git a/cassandra/parsing.pxd b/cassandra/parsing.pxd new file mode 100644 index 00000000..278c6e71 --- /dev/null +++ b/cassandra/parsing.pxd @@ -0,0 +1,30 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cassandra.bytesio cimport BytesIOReader +from cassandra.deserializers cimport Deserializer + +cdef class ParseDesc: + cdef public object colnames + cdef public object coltypes + cdef Deserializer[::1] deserializers + cdef public int protocol_version + cdef Py_ssize_t rowsize + +cdef class ColumnParser: + cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc) + +cdef class RowParser: + cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc) + diff --git a/cassandra/parsing.pyx b/cassandra/parsing.pyx new file mode 100644 index 00000000..c44d7f5a --- /dev/null +++ b/cassandra/parsing.pyx @@ -0,0 +1,44 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module containing the definitions and declarations (parsing.pxd) for parsers. +""" + +cdef class ParseDesc: + """Description of what structure to parse""" + + def __init__(self, colnames, coltypes, deserializers, protocol_version): + self.colnames = colnames + self.coltypes = coltypes + self.deserializers = deserializers + self.protocol_version = protocol_version + self.rowsize = len(colnames) + + +cdef class ColumnParser: + """Decode a ResultMessage into a set of columns""" + + cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc): + raise NotImplementedError + + +cdef class RowParser: + """Parser for a single row""" + + cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + """ + Unpack a single row of data in a ResultMessage. 
+ """ + raise NotImplementedError diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 29f948d4..9abc5095 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -22,6 +22,7 @@ import six from six.moves import range import io +from cassandra import type_codes from cassandra import (Unavailable, WriteTimeout, ReadTimeout, WriteFailure, ReadFailure, FunctionFailure, AlreadyExists, InvalidRequest, Unauthorized, @@ -35,10 +36,11 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, DoubleType, FloatType, Int32Type, InetAddressType, IntegerType, ListType, LongType, MapType, SetType, TimeUUIDType, - UTF8Type, UUIDType, UserType, + UTF8Type, VarcharType, UUIDType, UserType, TupleType, lookup_casstype, SimpleDateType, TimeType, ByteType, ShortType) from cassandra.policies import WriteType +from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY from cassandra import util log = logging.getLogger(__name__) @@ -68,10 +70,16 @@ _message_types_by_opcode = {} _UNSET_VALUE = object() +def register_class(cls): + _message_types_by_opcode[cls.opcode] = cls + +def get_registered_classes(): + return _message_types_by_opcode.copy() + class _RegisterMessageType(type): def __init__(cls, name, bases, dct): if not name.startswith('_'): - _message_types_by_opcode[cls.opcode] = cls + register_class(cls) @six.add_metaclass(_RegisterMessageType) @@ -531,35 +539,6 @@ RESULT_KIND_SET_KEYSPACE = 0x0003 RESULT_KIND_PREPARED = 0x0004 RESULT_KIND_SCHEMA_CHANGE = 0x0005 -class CassandraTypeCodes(object): - CUSTOM_TYPE = 0x0000 - AsciiType = 0x0001 - LongType = 0x0002 - BytesType = 0x0003 - BooleanType = 0x0004 - CounterColumnType = 0x0005 - DecimalType = 0x0006 - DoubleType = 0x0007 - FloatType = 0x0008 - Int32Type = 0x0009 - UTF8Type = 0x000A - DateType = 0x000B - UUIDType = 0x000C - UTF8Type = 0x000D - IntegerType = 0x000E - TimeUUIDType = 0x000F - InetAddressType = 0x0010 - SimpleDateType = 0x0011 - TimeType = 0x0012 - ShortType = 0x0013 - ByteType = 0x0014 
- ListType = 0x0020 - MapType = 0x0021 - SetType = 0x0022 - UserType = 0x0030 - TupleType = 0x0031 - - class ResultMessage(_MessageType): opcode = 0x08 name = 'RESULT' @@ -569,7 +548,7 @@ class ResultMessage(_MessageType): paging_state = None # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE) - type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in CassandraTypeCodes.__dict__.items() if not k.startswith('_')) + type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_')) _FLAGS_GLOBAL_TABLES_SPEC = 0x0001 _HAS_MORE_PAGES_FLAG = 0x0002 @@ -1017,6 +996,63 @@ class ProtocolHandler(object): return msg +def cython_protocol_handler(colparser): + """ + Given a column parser to deserialize ResultMessages, return a suitable + Cython-based protocol handler. + + There are three Cython-based protocol handlers (least to most performant): + + 1. obj_parser.ListParser + this parser decodes result messages into a list of tuples + + 2. obj_parser.LazyParser + this parser decodes result messages lazily by returning an iterator + + 3. numpy_parser.NumpyParser + this parser decodes result messages into NumPy arrays + + The default is to use obj_parser.ListParser + """ + from cassandra.row_parser import make_recv_results_rows + + class FastResultMessage(ResultMessage): + """ + Cython version of Result Message that has a faster implementation of + recv_results_rows. + """ + # type_codes = ResultMessage.type_codes.copy() + code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items()) + recv_results_rows = classmethod(make_recv_results_rows(colparser)) + + class CythonProtocolHandler(ProtocolHandler): + """ + Use FastResultMessage to decode query result messages.
+ """ + + my_opcodes = ProtocolHandler.message_types_by_opcode.copy() + my_opcodes[FastResultMessage.opcode] = FastResultMessage + message_types_by_opcode = my_opcodes + + return CythonProtocolHandler + + +if HAVE_CYTHON: + from cassandra.obj_parser import ListParser, LazyParser + ProtocolHandler = cython_protocol_handler(ListParser()) + LazyProtocolHandler = cython_protocol_handler(LazyParser()) +else: + # Use Python-based ProtocolHandler + LazyProtocolHandler = None + + +if HAVE_CYTHON and HAVE_NUMPY: + from cassandra.numpy_parser import NumpyParser + NumpyProtocolHandler = cython_protocol_handler(NumpyParser()) +else: + NumpyProtocolHandler = None + + def read_byte(f): return int8_unpack(f.read(1)) diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx new file mode 100644 index 00000000..fc7bce15 --- /dev/null +++ b/cassandra/row_parser.pyx @@ -0,0 +1,38 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.parsing cimport ParseDesc, ColumnParser +from cassandra.deserializers import make_deserializers + +include "ioutils.pyx" + +def make_recv_results_rows(ColumnParser colparser): + def recv_results_rows(cls, f, int protocol_version, user_type_map): + """ + Parse protocol data given as a BytesIO f into a set of columns (e.g. 
list of tuples) + This is used as the recv_results_rows method of (Fast)ResultMessage + """ + paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) + + colnames = [c[2] for c in column_metadata] + coltypes = [c[3] for c in column_metadata] + + desc = ParseDesc(colnames, coltypes, make_deserializers(coltypes), + protocol_version) + reader = BytesIOReader(f.read()) + parsed_rows = colparser.parse_rows(reader, desc) + + return (paging_state, (colnames, parsed_rows)) + + return recv_results_rows diff --git a/cassandra/tuple.pxd b/cassandra/tuple.pxd new file mode 100644 index 00000000..746205e2 --- /dev/null +++ b/cassandra/tuple.pxd @@ -0,0 +1,41 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cpython.tuple cimport ( + PyTuple_New, + # Return value: New reference. + # Return a new tuple object of size len, or NULL on failure. + PyTuple_SET_ITEM, + # Like PyTuple_SetItem(), but does no error checking, and should + # only be used to fill in brand new tuples. Note: This function + # ``steals'' a reference to o. + ) + +from cpython.ref cimport ( + Py_INCREF + # void Py_INCREF(object o) + # Increment the reference count for object o. The object must not + # be NULL; if you aren't sure that it isn't NULL, use + # Py_XINCREF(). 
+ ) + +cdef inline tuple tuple_new(Py_ssize_t n): + """Allocate a new tuple object""" + return PyTuple_New(n) + +cdef inline void tuple_set(tuple tup, Py_ssize_t idx, object item): + """Insert new object into tuple. No item must have been set yet.""" + # PyTuple_SET_ITEM steals a reference, so we need to INCREF + Py_INCREF(item) + PyTuple_SET_ITEM(tup, idx, item) diff --git a/cassandra/type_codes.pxd b/cassandra/type_codes.pxd new file mode 100644 index 00000000..90f29bc9 --- /dev/null +++ b/cassandra/type_codes.pxd @@ -0,0 +1,42 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cdef enum: + CUSTOM_TYPE + AsciiType + LongType + BytesType + BooleanType + CounterColumnType + DecimalType + DoubleType + FloatType + Int32Type + UTF8Type + DateType + UUIDType + VarcharType + IntegerType + TimeUUIDType + InetAddressType + SimpleDateType + TimeType + ShortType + ByteType + ListType + MapType + SetType + UserType + TupleType + diff --git a/cassandra/type_codes.py b/cassandra/type_codes.py new file mode 100644 index 00000000..2f0ce8f5 --- /dev/null +++ b/cassandra/type_codes.py @@ -0,0 +1,62 @@ +""" +Module with constants for Cassandra type codes. 
+ +These constants are useful for + + a) mapping messages to cqltypes (cassandra/cqltypes.py) + b) optimized dispatching for (de)serialization (cassandra/encoding.py) + +Type codes are repeated here from the Cassandra binary protocol specification: + + 0x0000 Custom: the value is a [string], see above. + 0x0001 Ascii + 0x0002 Bigint + 0x0003 Blob + 0x0004 Boolean + 0x0005 Counter + 0x0006 Decimal + 0x0007 Double + 0x0008 Float + 0x0009 Int + 0x000A Text + 0x000B Timestamp + 0x000C Uuid + 0x000D Varchar + 0x000E Varint + 0x000F Timeuuid + 0x0010 Inet + 0x0020 List: the value is an [option], representing the type + of the elements of the list. + 0x0021 Map: the value is two [option], representing the types of the + keys and values of the map + 0x0022 Set: the value is an [option], representing the type + of the elements of the set +""" + +CUSTOM_TYPE = 0x0000 +AsciiType = 0x0001 +LongType = 0x0002 +BytesType = 0x0003 +BooleanType = 0x0004 +CounterColumnType = 0x0005 +DecimalType = 0x0006 +DoubleType = 0x0007 +FloatType = 0x0008 +Int32Type = 0x0009 +UTF8Type = 0x000A +DateType = 0x000B +UUIDType = 0x000C +VarcharType = 0x000D +IntegerType = 0x000E +TimeUUIDType = 0x000F +InetAddressType = 0x0010 +SimpleDateType = 0x0011 +TimeType = 0x0012 +ShortType = 0x0013 +ByteType = 0x0014 +ListType = 0x0020 +MapType = 0x0021 +SetType = 0x0022 +UserType = 0x0030 +TupleType = 0x0031 + diff --git a/cassandra/util.py b/cassandra/util.py index 83260577..c71822c2 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -1,12 +1,29 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import with_statement import calendar import datetime import random import six import uuid +import sys DATETIME_EPOC = datetime.datetime(1970, 1, 1) +assert sys.byteorder in ('little', 'big') +is_little_endian = sys.byteorder == 'little' def datetime_from_timestamp(timestamp): """ @@ -490,8 +507,7 @@ except ImportError: def __init__(self, iterable=()): self._items = [] - for i in iterable: - self.add(i) + self.update(iterable) def __len__(self): return len(self._items) @@ -564,6 +580,10 @@ except ImportError: else: self._items.append(item) + def update(self, iterable): + for i in iterable: + self.add(i) + def clear(self): del self._items[:] diff --git a/docs/api/cassandra/protocol.rst b/docs/api/cassandra/protocol.rst index cabf2b59..0d4df101 100644 --- a/docs/api/cassandra/protocol.rst +++ b/docs/api/cassandra/protocol.rst @@ -24,3 +24,17 @@ See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseF .. automethod:: encode_message .. automethod:: decode_message + +Faster Deserialization +---------------------- +When python-driver is compiled with Cython, it uses a Cython-based deserialization path +to deserialize messages. There are two additional ProtocolHandler classes that can be +used to deserialize response messages: the first is ``LazyProtocolHandler`` and the +second is ``NumpyProtocolHandler``. They can be used as follows: + +..
code:: python + + from cassandra.protocol import NumpyProtocolHandler, LazyProtocolHandler + s.client_protocol_handler = LazyProtocolHandler # for a result iterator + s.client_protocol_handler = NumpyProtocolHandler # for a dict of NumPy arrays as result + diff --git a/setup.py b/setup.py index d27cf165..bf005bc6 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,6 @@ from distutils.errors import (CCompilerError, DistutilsPlatformError, DistutilsExecError) from distutils.cmd import Command - try: import subprocess has_subprocess = True @@ -71,6 +70,7 @@ if __name__ == '__main__' and sys.argv[1] == "install": except ImportError: pass +PROFILING = False class DocCommand(Command): @@ -263,11 +263,16 @@ if "--no-libev" not in sys.argv and not is_windows: if "--no-cython" not in sys.argv: try: from Cython.Build import cythonize - cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'marshal', 'metadata', 'pool', 'protocol', 'query', 'util'] + cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata', + 'pool', 'protocol', 'query', 'util'] compile_args = [] if is_windows else ['-Wno-unused-function'] extensions.extend(cythonize( - [Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], extra_compile_args=compile_args) for m in cython_candidates], + [Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], + extra_compile_args=compile_args) + for m in cython_candidates], exclude_failures=True)) + extensions.extend(cythonize("cassandra/*.pyx")) + extensions.extend(cythonize("tests/unit/cython/*.pyx")) except ImportError: sys.stderr.write("Cython is not installed. 
Not compiling core driver files as extensions (optional).") diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 057da5c8..f609492f 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -160,7 +160,7 @@ def remove_cluster(): CCM_CLUSTER.remove() CCM_CLUSTER = None return - except WindowsError: + except OSError: ex_type, ex, tb = sys.exc_info() log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb diff --git a/tests/integration/cqlengine/columns/test_container_columns.py b/tests/integration/cqlengine/columns/test_container_columns.py index 213c625c..ad67419c 100644 --- a/tests/integration/cqlengine/columns/test_container_columns.py +++ b/tests/integration/cqlengine/columns/test_container_columns.py @@ -386,7 +386,8 @@ class TestMapColumn(BaseCassEngTestCase): k2 = uuid4() now = datetime.now() then = now + timedelta(days=1) - m1 = TestMapModel.create(int_map={1: k1, 2: k2}, text_map={'now': now, 'then': then}) + m1 = TestMapModel.create(int_map={1: k1, 2: k2}, + text_map={'now': now, 'then': then}) m2 = TestMapModel.get(partition=m1.partition) self.assertTrue(isinstance(m2.int_map, dict)) diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index 7bb101b9..45277520 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -629,7 +629,7 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase): # test kwarg filtering q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) q = [d for d in q] - assert len(q) == 2 + self.assertEqual(len(q), 2, msg="Got: %s" % q) datas = [d.data for d in q] assert '1' in datas assert '2' in datas diff --git a/tests/integration/cqlengine/query/test_updates.py b/tests/integration/cqlengine/query/test_updates.py index a3b80f15..6c539012 100644 --- 
a/tests/integration/cqlengine/query/test_updates.py +++ b/tests/integration/cqlengine/query/test_updates.py @@ -52,17 +52,17 @@ class QueryUpdateTests(BaseCassEngTestCase): # sanity check for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - assert row.cluster == i - assert row.count == i - assert row.text == str(i) + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) # perform update TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6) for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - assert row.cluster == i - assert row.count == (6 if i == 3 else i) - assert row.text == str(i) + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, 6 if i == 3 else i) + self.assertEqual(row.text, str(i)) def test_update_values_validation(self): """ tests calling udpate on models with values passed in """ @@ -72,9 +72,9 @@ class QueryUpdateTests(BaseCassEngTestCase): # sanity check for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - assert row.cluster == i - assert row.count == i - assert row.text == str(i) + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) # perform update with self.assertRaises(ValidationError): @@ -98,17 +98,17 @@ class QueryUpdateTests(BaseCassEngTestCase): # sanity check for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - assert row.cluster == i - assert row.count == i - assert row.text == str(i) + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) # perform update TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None) for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - assert row.cluster == i - assert row.count == i - assert row.text == (None if i == 3 else str(i)) + self.assertEqual(row.cluster, i) + 
self.assertEqual(row.count, i) + self.assertEqual(row.text, None if i == 3 else str(i)) def test_mixed_value_and_null_update(self): """ tests that updating a columns value, and removing another works properly """ @@ -118,17 +118,17 @@ class QueryUpdateTests(BaseCassEngTestCase): # sanity check for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - assert row.cluster == i - assert row.count == i - assert row.text == str(i) + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, i) + self.assertEqual(row.text, str(i)) # perform update TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None) for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - assert row.cluster == i - assert row.count == (6 if i == 3 else i) - assert row.text == (None if i == 3 else str(i)) + self.assertEqual(row.cluster, i) + self.assertEqual(row.count, 6 if i == 3 else i) + self.assertEqual(row.text, None if i == 3 else str(i)) def test_counter_updates(self): pass diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py index 7da5203f..6f165cfd 100644 --- a/tests/integration/long/test_schema.py +++ b/tests/integration/long/test_schema.py @@ -85,11 +85,11 @@ class SchemaTests(unittest.TestCase): session = self.session - for i in xrange(30): + for i in range(30): execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i)) execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i)) - for j in xrange(100): + for j in range(100): execute_until_pass(session, "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format(i, j)) execute_until_pass(session, "DROP KEYSPACE test_{0}".format(i)) @@ -102,7 +102,7 @@ class SchemaTests(unittest.TestCase): cluster = Cluster(protocol_version=PROTOCOL_VERSION) session = cluster.connect() - for i in xrange(30): + for i in 
range(30): try: execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") except AlreadyExists: @@ -111,7 +111,7 @@ class SchemaTests(unittest.TestCase): execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)") - for j in xrange(100): + for j in range(100): execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j)) execute_until_pass(session, "DROP KEYSPACE test") diff --git a/tests/integration/standard/__init__.py b/tests/integration/standard/__init__.py index 794d75bf..484ed237 100644 --- a/tests/integration/standard/__init__.py +++ b/tests/integration/standard/__init__.py @@ -11,6 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + try: from ccmlib import common except ImportError as e: diff --git a/tests/integration/standard/test_concurrent.py b/tests/integration/standard/test_concurrent.py index 9edcbcbf..4d9ce3ae 100644 --- a/tests/integration/standard/test_concurrent.py +++ b/tests/integration/standard/test_concurrent.py @@ -25,6 +25,8 @@ from cassandra.query import tuple_factory, SimpleStatement from tests.integration import use_singledc, PROTOCOL_VERSION +from six import next + try: import unittest2 as unittest except ImportError: diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index 61a23831..36965a36 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -21,7 +21,8 @@ from cassandra.protocol import ProtocolHandler, ResultMessage, UUIDType, read_in from cassandra.query import tuple_factory from cassandra.cluster import Cluster from 
tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass -from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, get_sample +from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES +from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params from six import binary_type import uuid @@ -62,7 +63,7 @@ class CustomProtocolHandlerTest(unittest.TestCase): """ # Ensure that we get normal uuid back first - session = Cluster().connect() + session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes") session.row_factory = tuple_factory result_set = session.execute("SELECT schema_version FROM system.local") result = result_set.pop() @@ -102,15 +103,16 @@ class CustomProtocolHandlerTest(unittest.TestCase): @test_category data_types:serialization """ # Connect using a custom protocol handler that tracks the various types the result message is used with. - session = Cluster().connect(keyspace="custserdes") + session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes") session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked session.row_factory = tuple_factory - columns_string = create_table_with_all_types("test_table", session) + colnames = create_table_with_all_types("alltypes", session, 1) + columns_string = ", ".join(colnames) # verify data - params = get_all_primitive_params() - results = session.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0] + params = get_all_primitive_params(0) + results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0] for expected, actual in zip(params, results): self.assertEqual(actual, expected) # Ensure we have covered the various primitive types @@ -118,43 +120,6 @@ class CustomProtocolHandlerTest(unittest.TestCase): session.shutdown() -def create_table_with_all_types(table_name, session): - """ - Method that 
given a table_name and session construct a table that contains all possible primitive types - :param table_name: Name of table to create - :param session: session to use for table creation - :return: a string containing and columns. This can be used to query the table. - """ - # create table - alpha_type_list = ["zz int PRIMARY KEY"] - col_names = ["zz"] - start_index = ord('a') - for i, datatype in enumerate(PRIMITIVE_DATATYPES): - alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) - col_names.append(chr(start_index + i)) - - session.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)), timeout=120) - - # create the input - params = get_all_primitive_params() - - # insert into table as a simple statement - columns_string = ', '.join(col_names) - placeholders = ', '.join(["%s"] * len(col_names)) - session.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params, timeout=120) - return columns_string - - -def get_all_primitive_params(): - """ - Simple utility method used to give back a list of all possible primitive data sample types. 
- """ - params = [0] - for datatype in PRIMITIVE_DATATYPES: - params.append((get_sample(datatype))) - return params - - class CustomResultMessageRaw(ResultMessage): """ This is a custom Result Message that is used to return raw results, rather then diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py new file mode 100644 index 00000000..985b7953 --- /dev/null +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -0,0 +1,129 @@ +"""Test the various Cython-based message deserializers""" + +# Based on test_custom_protocol_handler.py + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from cassandra.query import tuple_factory +from cassandra.cluster import Cluster +from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler + +from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration.datatype_utils import update_datatypes +from tests.integration.standard.utils import ( + create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes) + +from tests.unit.cython.utils import cythontest, numpytest + +def setup_module(): + use_singledc() + update_datatypes() + + +class CythonProtocolHandlerTest(unittest.TestCase): + + N_ITEMS = 10 + + @classmethod + def setUpClass(cls): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.session = cls.cluster.connect() + cls.session.execute("CREATE KEYSPACE testspace WITH replication = " + "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") + cls.session.set_keyspace("testspace") + cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS) + + @classmethod + def tearDownClass(cls): + cls.session.execute("DROP KEYSPACE testspace") + cls.cluster.shutdown() + + @cythontest + def test_cython_parser(self): + """ + Test Cython-based parser that returns a list of tuples + """ + 
verify_iterator_data(self.assertEqual, get_data(ProtocolHandler)) + + @cythontest + def test_cython_lazy_parser(self): + """ + Test Cython-based parser that returns an iterator of tuples + """ + verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) + + @numpytest + def test_numpy_parser(self): + """ + Test Numpy-based parser that returns a NumPy array + """ + # arrays = { 'a': arr1, 'b': arr2, ... } + arrays = get_data(NumpyProtocolHandler) + + colnames = self.colnames + datatypes = get_primitive_datatypes() + for colname, datatype in zip(colnames, datatypes): + arr = arrays[colname] + self.match_dtype(datatype, arr.dtype) + + verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames)) + + def match_dtype(self, datatype, dtype): + """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype""" + if datatype == 'smallint': + self.match_dtype_props(dtype, 'i', 2) + elif datatype == 'int': + self.match_dtype_props(dtype, 'i', 4) + elif datatype in ('bigint', 'counter'): + self.match_dtype_props(dtype, 'i', 8) + elif datatype == 'float': + self.match_dtype_props(dtype, 'f', 4) + elif datatype == 'double': + self.match_dtype_props(dtype, 'f', 8) + else: + self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype)) + + def match_dtype_props(self, dtype, kind, size, signed=None): + self.assertEqual(dtype.kind, kind, msg=dtype) + self.assertEqual(dtype.itemsize, size, msg=dtype) + + +def arrays_to_list_of_tuples(arrays, colnames): + """Convert a dict of arrays (as given by the numpy protocol handler) to a list of tuples""" + first_array = arrays[colnames[0]] + return [tuple(arrays[colname][i] for colname in colnames) + for i in range(len(first_array))] + + +def get_data(protocol_handler): + """ + Get some data from the test table. 
+ + :param key: if None, get all results (100.000 results), otherwise get only one result + """ + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + + # use our custom protocol handler + session.client_protocol_handler = protocol_handler + session.row_factory = tuple_factory + + results = session.execute("SELECT * FROM test_table") + session.shutdown() + return results + + +def verify_iterator_data(assertEqual, results): + """ + Check the result of get_data() when this is a list or + iterator of tuples + """ + for result in results: + params = get_all_primitive_params(result[0]) + assertEqual(len(params), len(result), + msg="Not the right number of columns?") + for expected, actual in zip(params, result): + assertEqual(actual, expected) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index df49c148..92ffbf68 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -301,8 +301,8 @@ class BatchStatementTests(unittest.TestCase): keys.add(result.k) values.add(result.v) - self.assertEqual(set(range(10)), keys) - self.assertEqual(set(range(10)), values) + self.assertEqual(set(range(10)), keys, msg=results) + self.assertEqual(set(range(10)), values, msg=results) def test_string_statements(self): batch = BatchStatement(BatchType.LOGGED) @@ -370,6 +370,11 @@ class BatchStatementTests(unittest.TestCase): self.session.execute(batch) self.confirm_results() + def test_no_parameters_many_times(self): + for i in range(1000): + self.test_no_parameters() + self.session.execute("TRUNCATE test3rf.test") + class SerialConsistencyTests(unittest.TestCase): def setUp(self): diff --git a/tests/integration/standard/utils.py b/tests/integration/standard/utils.py new file mode 100644 index 00000000..fe54f04d --- /dev/null +++ b/tests/integration/standard/utils.py @@ -0,0 +1,53 @@ +""" +Helper module to populate a dummy Cassandra tables with 
data. +""" + +from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample + +def create_table_with_all_types(table_name, session, N): + """ + Method that given a table_name and session construct a table that contains + all possible primitive types. + + :param table_name: Name of table to create + :param session: session to use for table creation + :param N: the number of items to insert into the table + + :return: a list of column names + """ + # create table + alpha_type_list = ["primkey int PRIMARY KEY"] + col_names = ["primkey"] + start_index = ord('a') + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) + col_names.append(chr(start_index + i)) + + session.execute("CREATE TABLE {0} ({1})".format( + table_name, ', '.join(alpha_type_list)), timeout=120) + + # create the input + + for key in range(N): + params = get_all_primitive_params(key) + + # insert into table as a simple statement + columns_string = ', '.join(col_names) + placeholders = ', '.join(["%s"] * len(col_names)) + session.execute("INSERT INTO {0} ({1}) VALUES ({2})".format( + table_name, columns_string, placeholders), params, timeout=120) + return col_names + + +def get_all_primitive_params(key): + """ + Simple utility method used to give back a list of all possible primitive data sample types. 
+ """ + params = [key] + for datatype in PRIMITIVE_DATATYPES: + params.append(get_sample(datatype)) + return params + + +def get_primitive_datatypes(): + return ['int'] + list(PRIMITIVE_DATATYPES) \ No newline at end of file diff --git a/tests/stress_tests/test_multi_inserts.py b/tests/stress_tests/test_multi_inserts.py index b23a29dd..12b5b70e 100644 --- a/tests/stress_tests/test_multi_inserts.py +++ b/tests/stress_tests/test_multi_inserts.py @@ -75,7 +75,7 @@ class StressInsertsTests(unittest.TestCase): break for conn in pool.get_connections(): if conn.in_flight > 1: - print self.session.get_pool_state() + print(self.session.get_pool_state()) leaking_connections = True break i = i + 1 diff --git a/tests/unit/cython/__init__.py b/tests/unit/cython/__init__.py new file mode 100644 index 00000000..e4b89e5f --- /dev/null +++ b/tests/unit/cython/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/cython/bytesio_testhelper.pyx b/tests/unit/cython/bytesio_testhelper.pyx new file mode 100644 index 00000000..d557c037 --- /dev/null +++ b/tests/unit/cython/bytesio_testhelper.pyx @@ -0,0 +1,44 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.bytesio cimport BytesIOReader + +def test_read1(assert_equal, assert_raises): + cdef BytesIOReader reader = BytesIOReader(b'abcdef') + assert_equal(reader.read(2)[:2], b'ab') + assert_equal(reader.read(2)[:2], b'cd') + assert_equal(reader.read(0)[:0], b'') + assert_equal(reader.read(2)[:2], b'ef') + +def test_read2(assert_equal, assert_raises): + cdef BytesIOReader reader = BytesIOReader(b'abcdef') + reader.read(5) + reader.read(1) + +def test_read3(assert_equal, assert_raises): + cdef BytesIOReader reader = BytesIOReader(b'abcdef') + reader.read(6) + +def test_read_eof(assert_equal, assert_raises): + cdef BytesIOReader reader = BytesIOReader(b'abcdef') + reader.read(5) + # cannot convert reader.read to an object, do it manually + # assert_raises(EOFError, reader.read, 2) + try: + reader.read(2) + except EOFError: + pass + else: + raise Exception("Expected an EOFError") + reader.read(1) # see that we can still read this diff --git a/tests/unit/cython/test_bytesio.py b/tests/unit/cython/test_bytesio.py new file mode 100644 index 00000000..2dbf1311 --- /dev/null +++ b/tests/unit/cython/test_bytesio.py @@ -0,0 +1,35 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.unit.cython.utils import cyimport, cythontest +bytesio_testhelper = cyimport('tests.unit.cython.bytesio_testhelper') + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + + +class BytesIOTest(unittest.TestCase): + """Test Cython BytesIO proxy""" + + @cythontest + def test_reading(self): + bytesio_testhelper.test_read1(self.assertEqual, self.assertRaises) + bytesio_testhelper.test_read2(self.assertEqual, self.assertRaises) + bytesio_testhelper.test_read3(self.assertEqual, self.assertRaises) + + @cythontest + def test_reading_error(self): + bytesio_testhelper.test_read_eof(self.assertEqual, self.assertRaises) diff --git a/tests/unit/cython/utils.py b/tests/unit/cython/utils.py new file mode 100644 index 00000000..788212ac --- /dev/null +++ b/tests/unit/cython/utils.py @@ -0,0 +1,38 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +def cyimport(import_path): + """ + Import a Cython module if available, otherwise return None + (and skip any relevant tests). + """ + try: + return __import__(import_path, fromlist=[True]) + except ImportError: + if HAVE_CYTHON: + raise + return None + + +# @cythontest +# def test_something(self): ... +cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available') +numpytest = unittest.skipUnless(HAVE_CYTHON and HAVE_NUMPY, 'NumPy is not available')