Merge branch 'cython_serdes'

Conflicts:
	.gitignore
Adam Holmberg
2015-08-13 12:35:51 -05:00
42 changed files with 1908 additions and 133 deletions

.gitignore

@@ -4,6 +4,7 @@
 *.so
 *.egg
 *.egg-info
+*.attr
 .tox
 .python-version
 build
@@ -19,6 +20,7 @@ setuptools*.egg
 cassandra/*.c
 !cassandra/cmurmur3.c
+cassandra/*.html
 # OSX
 .DS_Store
@@ -38,3 +40,4 @@ cassandra/*.c
 #iPython
 *.ipynb

MANIFEST.in

@@ -1 +1,4 @@
 include setup.py README.rst MANIFEST.in LICENSE ez_setup.py
+include cassandra/*.pyx
+include cassandra/*.pxd
+include tests/unit/cython/*.pyx

cassandra/buffer.pxd (new file)

@@ -0,0 +1,58 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Simple buffer data structure that provides a view on existing memory
(e.g. from a bytes object). This memory must stay alive while the
buffer is in use.
"""
from cpython.bytes cimport PyBytes_AS_STRING
# char* PyBytes_AS_STRING(object string)
# Macro form of PyBytes_AsString() but without error
# checking. Only string objects are supported; no Unicode objects
# should be passed.
cdef struct Buffer:
    char *ptr
    Py_ssize_t size


cdef inline bytes to_bytes(Buffer *buf):
    return buf.ptr[:buf.size]


cdef inline char *buf_ptr(Buffer *buf):
    return buf.ptr


cdef inline char *buf_read(Buffer *buf, Py_ssize_t size) except NULL:
    if size > buf.size:
        raise IndexError("Requested more than length of buffer")
    return buf.ptr


cdef inline int slice_buffer(Buffer *buf, Buffer *out,
                             Py_ssize_t start, Py_ssize_t size) except -1:
    if size < 0:
        raise ValueError("Length must be positive")
    if start + size > buf.size:
        raise IndexError("Buffer slice out of bounds")

    out.ptr = buf.ptr + start
    out.size = size
    return 0


cdef inline void from_ptr_and_size(char *ptr, Py_ssize_t size, Buffer *out):
    out.ptr = ptr
    out.size = size
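
A minimal usage sketch (hypothetical Cython caller, not part of the driver) of viewing a
bytes payload through Buffer without copying; per the module docstring, the payload object
must stay alive while the views are in use:

    from cassandra.buffer cimport Buffer, slice_buffer, to_bytes
    from cpython.bytes cimport PyBytes_AS_STRING

    cdef bytes payload = b"\x00\x01hello"
    cdef Buffer buf, body
    buf.ptr = PyBytes_AS_STRING(payload)        # borrow the storage of 'payload'
    buf.size = len(payload)
    slice_buffer(&buf, &body, 2, buf.size - 2)  # zero-copy sub-view
    assert to_bytes(&body) == b"hello"          # a copy is made only here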

cassandra/bytesio.pxd (new file)

@@ -0,0 +1,20 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cdef class BytesIOReader:
    cdef bytes buf
    cdef char *buf_ptr
    cdef Py_ssize_t pos
    cdef Py_ssize_t size
    cdef char *read(self, Py_ssize_t n = ?) except NULL

cassandra/bytesio.pyx (new file)

@@ -0,0 +1,44 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cdef class BytesIOReader:
    """
    This class provides efficient support for reading bytes from a 'bytes' buffer,
    by returning char * values directly without allocating intermediate objects.
    """

    def __init__(self, bytes buf):
        self.buf = buf
        self.size = len(buf)
        self.buf_ptr = self.buf

    cdef char *read(self, Py_ssize_t n = -1) except NULL:
        """
        Read at most n bytes from the buffer and return a pointer to the data.
        If the n argument is negative or omitted, consume all remaining data.
        Raises EOFError if more bytes are requested than are available, so a
        caller can never consume past the end of the buffer.
        """
        cdef Py_ssize_t newpos = self.pos + n
        if n < 0:
            newpos = self.size
        elif newpos > self.size:
            # Raise an error here, as we do not want the caller to consume past the
            # end of the buffer
            raise EOFError("Cannot read past the end of the file")

        cdef char *res = self.buf_ptr + self.pos
        self.pos = newpos
        return res
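
A short sketch (hypothetical Cython caller) of the read() semantics defined above:

    from cassandra.bytesio cimport BytesIOReader

    cdef BytesIOReader reader = BytesIOReader(b"\x00\x00\x00\x05hello")
    cdef char *p = reader.read(4)  # points at the 4-byte length prefix
    p = reader.read(5)             # points at "hello"; reader.pos is now 9
    # a further reader.read(1) would raise EOFError instead of overrunning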

cassandra/cqltypes.py

@@ -288,7 +288,7 @@ class _CassandraType(object):
         Given a set of other CassandraTypes, create a new subtype of this type
         using them as parameters. This is how composite types are constructed.

-            >>> MapType.apply_parameters(DateType, BooleanType)
+            >>> MapType.apply_parameters([DateType, BooleanType])
         <class 'cassandra.types.MapType(DateType, BooleanType)'>

         `subtypes` will be a sequence of CassandraTypes. If provided, `names`
@@ -884,27 +884,7 @@ class UserType(TupleType):
     @classmethod
     def deserialize_safe(cls, byts, protocol_version):
-        proto_version = max(3, protocol_version)
-        p = 0
-        values = []
-        for col_type in cls.subtypes:
-            if p == len(byts):
-                break
-            itemlen = int32_unpack(byts[p:p + 4])
-            p += 4
-            if itemlen >= 0:
-                item = byts[p:p + itemlen]
-                p += itemlen
-            else:
-                item = None
-            # collections inside UDTs are always encoded with at least the
-            # version 3 format
-            values.append(col_type.from_binary(item, proto_version))
-        if len(values) < len(cls.subtypes):
-            nones = [None] * (len(cls.subtypes) - len(values))
-            values = values + nones
+        values = super(UserType, cls).deserialize_safe(byts, protocol_version)
         if cls.mapped_class:
             return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
         else:

cassandra/cython_deps.py (new file)

@@ -0,0 +1,11 @@
try:
    from cassandra.row_parser import make_recv_results_rows
    HAVE_CYTHON = True
except ImportError:
    HAVE_CYTHON = False

try:
    import numpy
    HAVE_NUMPY = True
except ImportError:
    HAVE_NUMPY = False

cassandra/cython_marshal.pyx (new file)

@@ -0,0 +1,123 @@
# -- cython: profile=True
#
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t,
uint8_t, uint16_t, uint32_t, uint64_t)
from cassandra.buffer cimport Buffer, buf_read, to_bytes
cdef bint is_little_endian
from cassandra.util import is_little_endian
cdef bint PY3 = six.PY3
cdef inline void swap_order(char *buf, Py_ssize_t size):
    """
    Swap the byteorder of `buf` in-place on little-endian platforms
    (reverse all the bytes).

    There are functions ntohl etc, but these may be POSIX-dependent.
    """
    cdef Py_ssize_t start, end, i
    cdef char c

    if is_little_endian:
        for i in range(div2(size)):
            end = size - i - 1
            c = buf[i]
            buf[i] = buf[end]
            buf[end] = c


cdef inline Py_ssize_t div2(Py_ssize_t x):
    return x >> 1


### Unpacking of signed integers

cdef inline int64_t int64_unpack(Buffer *buf) except ?0xDEAD:
    cdef int64_t x = (<int64_t *> buf_read(buf, 8))[0]
    swap_order(<char *> &x, 8)
    return x


cdef inline int32_t int32_unpack(Buffer *buf) except ?0xDEAD:
    cdef int32_t x = (<int32_t *> buf_read(buf, 4))[0]
    swap_order(<char *> &x, 4)
    return x


cdef inline int16_t int16_unpack(Buffer *buf) except ?0xDED:
    cdef int16_t x = (<int16_t *> buf_read(buf, 2))[0]
    swap_order(<char *> &x, 2)
    return x


cdef inline int8_t int8_unpack(Buffer *buf) except ?80:
    return (<int8_t *> buf_read(buf, 1))[0]


cdef inline uint64_t uint64_unpack(Buffer *buf) except ?0xDEAD:
    cdef uint64_t x = (<uint64_t *> buf_read(buf, 8))[0]
    swap_order(<char *> &x, 8)
    return x


cdef inline uint32_t uint32_unpack(Buffer *buf) except ?0xDEAD:
    cdef uint32_t x = (<uint32_t *> buf_read(buf, 4))[0]
    swap_order(<char *> &x, 4)
    return x


cdef inline uint16_t uint16_unpack(Buffer *buf) except ?0xDEAD:
    cdef uint16_t x = (<uint16_t *> buf_read(buf, 2))[0]
    swap_order(<char *> &x, 2)
    return x


cdef inline uint8_t uint8_unpack(Buffer *buf) except ?0xff:
    return (<uint8_t *> buf_read(buf, 1))[0]


cdef inline double double_unpack(Buffer *buf) except ?1.74:
    cdef double x = (<double *> buf_read(buf, 8))[0]
    swap_order(<char *> &x, 8)
    return x


cdef inline float float_unpack(Buffer *buf) except ?1.74:
    cdef float x = (<float *> buf_read(buf, 4))[0]
    swap_order(<char *> &x, 4)
    return x


cdef varint_unpack(Buffer *term):
    """Unpack a variable-sized integer"""
    if PY3:
        return varint_unpack_py3(to_bytes(term))
    else:
        return varint_unpack_py2(to_bytes(term))


# TODO: Optimize these two functions
cdef varint_unpack_py3(bytes term):
    cdef int64_t one = 1L
    val = int(''.join("%02x" % i for i in term), 16)
    if (term[0] & 128) != 0:
        # There is a bug in Cython (0.20 - 0.22), where if we do
        # '1 << (len(term) * 8)' Cython generates '1' directly into the
        # C code, causing integer overflows
        val -= one << (len(term) * 8)
    return val


cdef varint_unpack_py2(bytes term):  # noqa
    cdef int64_t one = 1L
    val = int(term.encode('hex'), 16)
    if (ord(term[0]) & 128) != 0:
        val = val - (one << (len(term) * 8))
    return val
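
The sign-extension step is easiest to see in plain Python. A mirror of
varint_unpack_py3 for illustration (assumes a non-empty, big-endian
two's-complement byte string, as on the wire):

    def varint_unpack_py(term):
        val = int(''.join('%02x' % b for b in term), 16)
        if term[0] & 0x80:                 # high bit set -> negative value
            val -= 1 << (len(term) * 8)    # two's-complement sign extension
        return val

    assert varint_unpack_py(b'\x7f') == 127
    assert varint_unpack_py(b'\xff') == -1        # 0xff - 0x100
    assert varint_unpack_py(b'\xfe\x00') == -512  # 0xfe00 - 0x10000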

cassandra/cython_utils.pxd (new file)

@@ -0,0 +1,2 @@
from libc.stdint cimport int64_t
cdef datetime_from_timestamp(double timestamp)

cassandra/cython_utils.pyx (new file)

@@ -0,0 +1,44 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Duplicate module of util.py, with some accelerated functions
used for deserialization.
"""
from cpython.datetime cimport (
    timedelta_new,
        # cdef inline object timedelta_new(int days, int seconds, int useconds)
        # Create timedelta object using DateTime CAPI factory function.
        # Note, there are no range checks for any of the arguments.
    import_datetime,
        # Datetime C API initialization function.
        # You have to call it before any usage of DateTime CAPI functions.
)

import datetime
import sys

cdef bint is_little_endian
from cassandra.util import is_little_endian

import_datetime()

DATETIME_EPOC = datetime.datetime(1970, 1, 1)


cdef datetime_from_timestamp(double timestamp):
    cdef int seconds = <int> timestamp
    cdef int microseconds = (<int64_t> (timestamp * 1000000)) % 1000000
    return DATETIME_EPOC + timedelta_new(0, seconds, microseconds)
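
The same arithmetic in plain Python, for illustration (mirrors the Cython version,
minus the timedelta_new C API factory):

    import datetime

    DATETIME_EPOC = datetime.datetime(1970, 1, 1)

    def datetime_from_timestamp_py(timestamp):
        seconds = int(timestamp)
        microseconds = int(timestamp * 1000000) % 1000000
        return DATETIME_EPOC + datetime.timedelta(0, seconds, microseconds)

    assert datetime_from_timestamp_py(1.5) == datetime.datetime(1970, 1, 1, 0, 0, 1, 500000)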

cassandra/deserializers.pxd (new file)

@@ -0,0 +1,43 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.buffer cimport Buffer
cdef class Deserializer:
    # The cqltypes._CassandraType corresponding to this deserializer
    cdef object cqltype

    # String may be empty, whereas other values may not be.
    # Other values may be NULL, in which case the integer length
    # of the binary data is negative. However, non-string types
    # may also return a zero length for legacy reasons
    # (see http://code.metager.de/source/xref/apache/cassandra/doc/native_protocol_v3.spec
    # paragraph 6)
    cdef bint empty_binary_ok

    cdef deserialize(self, Buffer *buf, int protocol_version)


cdef inline object from_binary(Deserializer deserializer,
                               Buffer *buf,
                               int protocol_version):
    if buf.size < 0:
        return None
    elif buf.size == 0 and not deserializer.empty_binary_ok:
        return _ret_empty(deserializer, buf.size)
    else:
        return deserializer.deserialize(buf, protocol_version)


cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size)

cassandra/deserializers.pyx (new file)

@@ -0,0 +1,512 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from libc.stdint cimport int32_t, uint16_t

include 'cython_marshal.pyx'
from cassandra.buffer cimport Buffer, to_bytes, slice_buffer
from cassandra.cython_utils cimport datetime_from_timestamp

from cython.view cimport array as cython_array
from cassandra.tuple cimport tuple_new, tuple_set

import socket
from decimal import Decimal
from uuid import UUID

from cassandra import cqltypes
from cassandra import util

cdef bint PY2 = six.PY2
cdef class Deserializer:
    """Cython-based deserializer class for a cqltype"""

    def __init__(self, cqltype):
        self.cqltype = cqltype
        self.empty_binary_ok = cqltype.empty_binary_ok

    cdef deserialize(self, Buffer *buf, int protocol_version):
        raise NotImplementedError


cdef class DesBytesType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return to_bytes(buf)


# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html
cdef class DesDecimalType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Buffer varint_buf
        slice_buffer(buf, &varint_buf, 4, buf.size - 4)

        scale = int32_unpack(buf)
        unscaled = varint_unpack(&varint_buf)

        return Decimal('%de%d' % (unscaled, -scale))


cdef class DesUUIDType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return UUID(bytes=to_bytes(buf))


cdef class DesBooleanType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if int8_unpack(buf):
            return True
        return False


cdef class DesByteType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return int8_unpack(buf)


cdef class DesAsciiType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if PY2:
            return to_bytes(buf)
        return to_bytes(buf).decode('ascii')
cdef class DesFloatType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return float_unpack(buf)


cdef class DesDoubleType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return double_unpack(buf)


cdef class DesLongType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return int64_unpack(buf)


cdef class DesInt32Type(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return int32_unpack(buf)


cdef class DesIntegerType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return varint_unpack(buf)


cdef class DesInetAddressType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes byts = to_bytes(buf)

        # TODO: optimize inet_ntop, inet_ntoa
        if buf.size == 16:
            return util.inet_ntop(socket.AF_INET6, byts)
        else:
            # util.inet_pton could also handle, but this is faster
            # since we've already determined the AF
            return socket.inet_ntoa(byts)


cdef class DesCounterColumnType(DesLongType):
    pass


cdef class DesDateType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef double timestamp = int64_unpack(buf) / 1000.0
        return datetime_from_timestamp(timestamp)


cdef class TimestampType(DesDateType):
    pass


cdef class TimeUUIDType(DesDateType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return UUID(bytes=to_bytes(buf))


# Values of the 'date' type are encoded as 32-bit unsigned integers
# representing a number of days with epoch (January 1st, 1970) at the center of the
# range (2^31).
EPOCH_OFFSET_DAYS = 2 ** 31


cdef class DesSimpleDateType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        days = uint32_unpack(buf) - EPOCH_OFFSET_DAYS
        return util.Date(days)


cdef class DesShortType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return int16_unpack(buf)


cdef class DesTimeType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.Time(int64_unpack(buf))


cdef class DesUTF8Type(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes val = to_bytes(buf)
        return val.decode('utf8')


cdef class DesVarcharType(DesUTF8Type):
    pass
cdef class _DesParameterizedType(Deserializer):

    cdef object subtypes
    cdef Deserializer[::1] deserializers
    cdef Py_ssize_t subtypes_len

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.subtypes = cqltype.subtypes
        self.deserializers = make_deserializers(cqltype.subtypes)
        self.subtypes_len = len(self.subtypes)


cdef class _DesSingleParamType(_DesParameterizedType):

    cdef Deserializer deserializer

    def __init__(self, cqltype):
        assert cqltype.subtypes and len(cqltype.subtypes) == 1, cqltype.subtypes
        super().__init__(cqltype)
        self.deserializer = self.deserializers[0]
#--------------------------------------------------------------------------
# List and set deserialization
cdef class DesListType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef uint16_t v2_and_below = 2
        cdef int32_t v3_and_above = 3

        if protocol_version >= 3:
            result = _deserialize_list_or_set[int32_t](
                v3_and_above, buf, protocol_version, self.deserializer)
        else:
            result = _deserialize_list_or_set[uint16_t](
                v2_and_below, buf, protocol_version, self.deserializer)

        return result


cdef class DesSetType(DesListType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.sortedset(DesListType.deserialize(self, buf, protocol_version))
ctypedef fused itemlen_t:
    uint16_t  # protocol <= v2
    int32_t   # protocol >= v3


cdef list _deserialize_list_or_set(itemlen_t dummy_version,
                                   Buffer *buf, int protocol_version,
                                   Deserializer deserializer):
    """
    Deserialize a list or set.

    The 'dummy' parameter is needed to make fused types work, so that
    we can specialize on the protocol version.
    """
    cdef Buffer elem_buf
    cdef itemlen_t numelements
    cdef itemlen_t idx
    cdef list result = []

    _unpack_len[itemlen_t](0, &numelements, buf)
    idx = sizeof(itemlen_t)
    for _ in range(numelements):
        subelem(buf, &elem_buf, &idx)
        result.append(from_binary(deserializer, &elem_buf, protocol_version))

    return result


cdef inline int subelem(
        Buffer *buf, Buffer *elem_buf, itemlen_t *idx_p) except -1:
    """
    Read the next element from the buffer: first read the size (in bytes) of the
    element, then fill elem_buf with a newly sliced buffer of this size (and the
    right offset).

    NOTE: The handling of 'idx' is somewhat atrocious, as there is a Cython
          bug with the combination fused types + 'except' clause.
          So instead, we pass in a pointer to 'idx', namely 'idx_p', and write
          to this instead.
    """
    cdef itemlen_t elemlen

    _unpack_len[itemlen_t](idx_p[0], &elemlen, buf)
    idx_p[0] += sizeof(itemlen_t)
    slice_buffer(buf, elem_buf, idx_p[0], elemlen)
    idx_p[0] += elemlen
    return 0


cdef int _unpack_len(itemlen_t idx, itemlen_t *elemlen, Buffer *buf) except -1:
    cdef Buffer itemlen_buf
    slice_buffer(buf, &itemlen_buf, idx, sizeof(itemlen_t))

    if itemlen_t is uint16_t:
        elemlen[0] = uint16_unpack(&itemlen_buf)
    else:
        elemlen[0] = int32_unpack(&itemlen_buf)
    return 0
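
For reference, the v3+ wire layout handled above is an int32 element count followed by
int32-length-prefixed elements. A plain-Python sketch (raw bytes only; the real parser
hands each slice to a Deserializer):

    import struct

    def decode_list_v3(buf):
        (n,) = struct.unpack_from('>i', buf, 0)
        idx, out = 4, []
        for _ in range(n):
            (elemlen,) = struct.unpack_from('>i', buf, idx)
            idx += 4
            out.append(buf[idx:idx + elemlen])
            idx += elemlen
        return out

    # two elements: b"a" and b"bc"
    assert decode_list_v3(b"\x00\x00\x00\x02"
                          b"\x00\x00\x00\x01a"
                          b"\x00\x00\x00\x02bc") == [b"a", b"bc"]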
#--------------------------------------------------------------------------
# Map deserialization
cdef class DesMapType(_DesParameterizedType):

    cdef Deserializer key_deserializer, val_deserializer

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.key_deserializer = self.deserializers[0]
        self.val_deserializer = self.deserializers[1]

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef uint16_t v2_and_below = 0
        cdef int32_t v3_and_above = 0

        key_type, val_type = self.cqltype.subtypes
        if protocol_version >= 3:
            result = _deserialize_map[int32_t](
                v3_and_above, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)
        else:
            result = _deserialize_map[uint16_t](
                v2_and_below, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)

        return result


cdef _deserialize_map(itemlen_t dummy_version,
                      Buffer *buf, int protocol_version,
                      Deserializer key_deserializer, Deserializer val_deserializer,
                      key_type, val_type):
    cdef Buffer key_buf, val_buf
    cdef itemlen_t numelements
    cdef itemlen_t idx

    _unpack_len[itemlen_t](0, &numelements, buf)
    idx = sizeof(itemlen_t)
    themap = util.OrderedMapSerializedKey(key_type, protocol_version)
    for _ in range(numelements):
        subelem(buf, &key_buf, &idx)
        subelem(buf, &val_buf, &idx)
        key = from_binary(key_deserializer, &key_buf, protocol_version)
        val = from_binary(val_deserializer, &val_buf, protocol_version)
        themap._insert_unchecked(key, to_bytes(&key_buf), val)

    return themap
#--------------------------------------------------------------------------
cdef class DesTupleType(_DesParameterizedType):

    # TODO: Use TupleRowParser to parse these tuples

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i, p
        cdef int32_t itemlen
        cdef tuple res = tuple_new(self.subtypes_len)
        cdef Buffer item_buf
        cdef Buffer itemlen_buf
        cdef Deserializer deserializer

        # collections inside UDTs are always encoded with at least the
        # version 3 format
        protocol_version = max(3, protocol_version)

        p = 0
        for i in range(self.subtypes_len):
            item = None
            if p < buf.size:
                slice_buffer(buf, &itemlen_buf, p, 4)
                itemlen = int32_unpack(&itemlen_buf)
                p += 4
                if itemlen >= 0:
                    slice_buffer(buf, &item_buf, p, itemlen)
                    p += itemlen

                    deserializer = self.deserializers[i]
                    item = from_binary(deserializer, &item_buf, protocol_version)
            tuple_set(res, i, item)

        return res


cdef class DesUserType(DesTupleType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        typ = self.cqltype
        values = DesTupleType.deserialize(self, buf, protocol_version)
        if typ.mapped_class:
            return typ.mapped_class(**dict(zip(typ.fieldnames, values)))
        else:
            return typ.tuple_type(*values)


cdef class DesCompositeType(_DesParameterizedType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i, start
        cdef Buffer elem_buf
        cdef int16_t element_length
        cdef Deserializer deserializer
        cdef tuple res = tuple_new(self.subtypes_len)

        for i in range(self.subtypes_len):
            if not buf.size:
                # CompositeType can have missing elements at the end.
                # Fill the tuple with None values and slice it
                #
                # (I'm not sure a tuple needs to be fully initialized before
                # it can be destroyed, so play it safe)
                for j in range(i, self.subtypes_len):
                    tuple_set(res, j, None)
                res = res[:i]
                break

            element_length = uint16_unpack(buf)
            slice_buffer(buf, &elem_buf, 2, element_length)

            deserializer = self.deserializers[i]
            item = from_binary(deserializer, &elem_buf, protocol_version)
            tuple_set(res, i, item)

            # skip element length, element, and the EOC (one byte)
            start = 2 + element_length + 1
            slice_buffer(buf, buf, start, buf.size - start)

        return res


DesDynamicCompositeType = DesCompositeType


cdef class DesReversedType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return from_binary(self.deserializer, buf, protocol_version)


cdef class DesFrozenType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return from_binary(self.deserializer, buf, protocol_version)
#--------------------------------------------------------------------------
cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size):
    """
    Decide whether to return None or EMPTY when a buffer size is
    zero or negative. This is used by from_binary in deserializers.pxd.
    """
    if buf_size < 0:
        return None
    elif deserializer.cqltype.support_empty_values:
        return cqltypes.EMPTY
    else:
        return None


#--------------------------------------------------------------------------
# Generic deserialization

cdef class GenericDeserializer(Deserializer):
    """
    Wrap a generic datatype for deserialization
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return self.cqltype.deserialize(to_bytes(buf), protocol_version)

    def __repr__(self):
        return "GenericDeserializer(%s)" % (self.cqltype,)
#--------------------------------------------------------------------------
# Helper utilities
def make_deserializers(cqltypes):
    """Create an array of Deserializers for each given cqltype in cqltypes"""
    return obj_array([find_deserializer(ct) for ct in cqltypes])


cdef dict classes = globals()

cpdef Deserializer find_deserializer(cqltype):
    """Find a deserializer for a cqltype"""
    name = 'Des' + cqltype.__name__

    if name in classes:
        cls = classes[name]
    elif issubclass(cqltype, cqltypes.ListType):
        cls = DesListType
    elif issubclass(cqltype, cqltypes.SetType):
        cls = DesSetType
    elif issubclass(cqltype, cqltypes.MapType):
        cls = DesMapType
    elif issubclass(cqltype, cqltypes.UserType):
        # UserType is a subclass of TupleType, so it must precede it
        cls = DesUserType
    elif issubclass(cqltype, cqltypes.TupleType):
        cls = DesTupleType
    elif issubclass(cqltype, cqltypes.DynamicCompositeType):
        # DynamicCompositeType is a subclass of CompositeType, so it must precede it
        cls = DesDynamicCompositeType
    elif issubclass(cqltype, cqltypes.CompositeType):
        cls = DesCompositeType
    elif issubclass(cqltype, cqltypes.ReversedType):
        cls = DesReversedType
    elif issubclass(cqltype, cqltypes.FrozenType):
        cls = DesFrozenType
    else:
        cls = GenericDeserializer

    return cls(cqltype)


def obj_array(list objs):
    """Create a (Cython) array of objects given a list of objects"""
    cdef object[:] arr
    cdef Py_ssize_t i
    arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")
    # arr[:] = objs  # This does not work (segmentation faults)
    for i, obj in enumerate(objs):
        arr[i] = obj
    return arr
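
A small sketch of the lookup helpers (assuming the compiled cassandra.deserializers
extension is importable); note that deserialize() itself is a cdef method and is only
callable from Cython:

    from cassandra import cqltypes
    from cassandra.deserializers import find_deserializer, make_deserializers

    des = find_deserializer(cqltypes.Int32Type)  # -> a DesInt32Type instance
    arr = make_deserializers([cqltypes.UTF8Type, cqltypes.BooleanType])
    # 'arr' is an object array suitable for ParseDesc.deserializers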

cassandra/ioutils.pyx (new file)

@@ -0,0 +1,47 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include 'cython_marshal.pyx'
from cassandra.buffer cimport Buffer, from_ptr_and_size
from libc.stdint cimport int32_t
from cassandra.bytesio cimport BytesIOReader
cdef inline int get_buf(BytesIOReader reader, Buffer *buf_out) except -1:
    """
    Get a pointer into the buffer provided by BytesIOReader for the
    next data item in the stream of values.

    BEWARE: if the next item has a zero or negative size, the pointer
    will be set to NULL. A negative size happens when the value is NULL
    in the database, whereas a zero size may happen either for legacy
    reasons, or for data types such as strings (which may be empty).
    """
    cdef Py_ssize_t raw_val_size = read_int(reader)
    cdef char *ptr
    if raw_val_size <= 0:
        ptr = NULL
    else:
        ptr = reader.read(raw_val_size)
    from_ptr_and_size(ptr, raw_val_size, buf_out)
    return 0


cdef inline int32_t read_int(BytesIOReader reader) except ?0xDEAD:
    cdef Buffer buf
    buf.ptr = reader.read(4)
    buf.size = 4
    return int32_unpack(&buf)
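
A plain-Python analogue of get_buf, for illustration of the length-prefix convention
(negative length encodes NULL; zero yields an empty value):

    import io
    import struct

    def read_value(f):
        (size,) = struct.unpack('>i', f.read(4))
        return None if size < 0 else f.read(size)

    stream = io.BytesIO(b"\x00\x00\x00\x03abc" b"\xff\xff\xff\xff")
    assert read_value(stream) == b"abc"
    assert read_value(stream) is None  # length -1 encodes NULL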

cassandra/numpyFlags.h (new file)

@@ -0,0 +1 @@
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

cassandra/numpy_parser.pyx (new file)

@@ -0,0 +1,185 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provider an optional protocol parser that returns
NumPy arrays.
=============================================================================
This module should not be imported by any of the main python-driver modules,
as numpy is an optional dependency.
=============================================================================
"""
include "ioutils.pyx"
cimport cython
from libc.stdint cimport uint64_t
from cpython.ref cimport Py_INCREF, PyObject
from cassandra.bytesio cimport BytesIOReader
from cassandra.deserializers cimport Deserializer, from_binary
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
from cassandra import cqltypes
from cassandra.util import is_little_endian
import numpy as np
# import pandas as pd
cdef extern from "numpyFlags.h":
# Include 'numpyFlags.h' into the generated C code to disable the
# deprecated NumPy API
pass
cdef extern from "Python.h":
# An integer type large enough to hold a pointer
ctypedef uint64_t Py_uintptr_t
# Simple array descriptor, useful to parse rows into a NumPy array
ctypedef struct ArrDesc:
Py_uintptr_t buf_ptr
int stride # should be large enough as we allocate contiguous arrays
int is_object
arrDescDtype = np.dtype(
[ ('buf_ptr', np.uintp)
, ('stride', np.dtype('i'))
, ('is_object', np.dtype('i'))
])
_cqltype_to_numpy = {
cqltypes.LongType: np.dtype('>i8'),
cqltypes.CounterColumnType: np.dtype('>i8'),
cqltypes.Int32Type: np.dtype('>i4'),
cqltypes.ShortType: np.dtype('>i2'),
cqltypes.FloatType: np.dtype('>f4'),
cqltypes.DoubleType: np.dtype('>f8'),
}
obj_dtype = np.dtype('O')
cdef class NumpyParser(ColumnParser):
    """Decode a ResultMessage into a bunch of NumPy arrays"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        cdef Py_ssize_t rowcount
        cdef ArrDesc[::1] array_descs
        cdef ArrDesc *arrs

        rowcount = read_int(reader)
        array_descs, arrays = make_arrays(desc, rowcount)
        arrs = &array_descs[0]

        _parse_rows(reader, desc, arrs, rowcount)

        arrays = [make_native_byteorder(arr) for arr in arrays]
        result = dict(zip(desc.colnames, arrays))
        return result


cdef _parse_rows(BytesIOReader reader, ParseDesc desc,
                 ArrDesc *arrs, Py_ssize_t rowcount):
    cdef Py_ssize_t i
    for i in range(rowcount):
        unpack_row(reader, desc, arrs)
### Helper functions to create NumPy arrays and array descriptors
def make_arrays(ParseDesc desc, array_size):
    """
    Allocate arrays for each result column.

    Returns a tuple of (array_descs, arrays), where
        'array_descs' describe the arrays for NativeRowParser and
        'arrays' is a list of the allocated column arrays
            (e.g. zipped with column names, this can be fed into pandas.DataFrame)
    """
    array_descs = np.empty((desc.rowsize,), arrDescDtype)
    arrays = []

    for i, coltype in enumerate(desc.coltypes):
        arr = make_array(coltype, array_size)
        array_descs[i]['buf_ptr'] = arr.ctypes.data
        array_descs[i]['stride'] = arr.strides[0]
        array_descs[i]['is_object'] = coltype not in _cqltype_to_numpy
        arrays.append(arr)

    return array_descs, arrays


def make_array(coltype, array_size):
    """
    Allocate a new NumPy array of the given column type and size.
    """
    dtype = _cqltype_to_numpy.get(coltype, obj_dtype)
    return np.empty((array_size,), dtype=dtype)
#### Parse rows into NumPy arrays
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline int unpack_row(
        BytesIOReader reader, ParseDesc desc, ArrDesc *arrays) except -1:
    cdef Buffer buf
    cdef Py_ssize_t i, rowsize = desc.rowsize
    cdef ArrDesc arr
    cdef Deserializer deserializer

    for i in range(rowsize):
        get_buf(reader, &buf)
        arr = arrays[i]

        if buf.size == 0:
            raise ValueError("Cannot handle NULL value")

        if arr.is_object:
            deserializer = desc.deserializers[i]
            val = from_binary(deserializer, &buf, desc.protocol_version)
            Py_INCREF(val)
            (<PyObject **> arr.buf_ptr)[0] = <PyObject *> val
        else:
            memcopy(buf.ptr, <char *> arr.buf_ptr, buf.size)

        # Update the pointer into the array for the next time
        arrays[i].buf_ptr += arr.stride

    return 0


cdef inline void memcopy(char *src, char *dst, Py_ssize_t size):
    """
    Our own simple memcopy which can be inlined. This is useful because our data types
    are only a few bytes.
    """
    cdef Py_ssize_t i
    for i in range(size):
        dst[i] = src[i]


def make_native_byteorder(arr):
    """
    Make sure all values have a native endian in the NumPy arrays.
    """
    if is_little_endian and not arr.dtype.kind == 'O':
        # We have arrays in big-endian order. First swap the bytes
        # into little endian order, and then update the numpy dtype
        # accordingly (e.g. from '>i8' to '<i8')
        #
        # Ignore any object arrays of dtype('O')
        return arr.byteswap().newbyteorder()
    return arr
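
A sketch of the byte-order idiom used above (NumPy only; value-preserving):

    import numpy as np

    arr = np.array([1, 2, 3], dtype='>i8')   # big-endian, as on the wire
    native = arr.byteswap().newbyteorder()   # swap bytes, then flip the dtype
    assert list(native) == [1, 2, 3]         # values unchanged, now native order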

cassandra/obj_parser.pyx (new file)

@@ -0,0 +1,75 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include "ioutils.pyx"
from cassandra.bytesio cimport BytesIOReader
from cassandra.deserializers cimport Deserializer, from_binary
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
from cassandra.tuple cimport tuple_new, tuple_set
cdef class ListParser(ColumnParser):
"""Decode a ResultMessage into a list of tuples (or other objects)"""
cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
cdef Py_ssize_t i, rowcount
rowcount = read_int(reader)
cdef RowParser rowparser = TupleRowParser()
return [rowparser.unpack_row(reader, desc) for i in range(rowcount)]
cdef class LazyParser(ColumnParser):
"""Decode a ResultMessage lazily using a generator"""
cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
# Use a little helper function as closures (generators) are not
# supported in cpdef methods
return parse_rows_lazy(reader, desc)
def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
cdef Py_ssize_t i, rowcount
rowcount = read_int(reader)
cdef RowParser rowparser = TupleRowParser()
return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
cdef class TupleRowParser(RowParser):
"""
Parse a single returned row into a tuple of objects:
(obj1, ..., objN)
"""
cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
assert desc.rowsize >= 0
cdef Buffer buf
cdef Py_ssize_t i, rowsize = desc.rowsize
cdef Deserializer deserializer
cdef tuple res = tuple_new(desc.rowsize)
for i in range(rowsize):
# Read the next few bytes
get_buf(reader, &buf)
# Deserialize bytes to python object
deserializer = desc.deserializers[i]
val = from_binary(deserializer, &buf, desc.protocol_version)
# Insert new object into tuple
tuple_set(res, i, val)
return res

cassandra/parsing.pxd (new file)

@@ -0,0 +1,30 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.bytesio cimport BytesIOReader
from cassandra.deserializers cimport Deserializer


cdef class ParseDesc:
    cdef public object colnames
    cdef public object coltypes
    cdef Deserializer[::1] deserializers
    cdef public int protocol_version
    cdef Py_ssize_t rowsize


cdef class ColumnParser:
    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc)


cdef class RowParser:
    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc)

cassandra/parsing.pyx (new file)

@@ -0,0 +1,44 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing the definitions and declarations (parsing.pxd) for parsers.
"""
cdef class ParseDesc:
    """Description of what structure to parse"""

    def __init__(self, colnames, coltypes, deserializers, protocol_version):
        self.colnames = colnames
        self.coltypes = coltypes
        self.deserializers = deserializers
        self.protocol_version = protocol_version
        self.rowsize = len(colnames)


cdef class ColumnParser:
    """Decode a ResultMessage into a set of columns"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        raise NotImplementedError


cdef class RowParser:
    """Parser for a single row"""

    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
        """
        Unpack a single row of data in a ResultMessage.
        """
        raise NotImplementedError

cassandra/protocol.py

@@ -22,6 +22,7 @@ import six
 from six.moves import range
 import io

+from cassandra import type_codes
 from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
                        WriteFailure, ReadFailure, FunctionFailure,
                        AlreadyExists, InvalidRequest, Unauthorized,
@@ -35,10 +36,11 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
                                 DoubleType, FloatType, Int32Type,
                                 InetAddressType, IntegerType, ListType,
                                 LongType, MapType, SetType, TimeUUIDType,
-                                UTF8Type, UUIDType, UserType,
+                                UTF8Type, VarcharType, UUIDType, UserType,
                                 TupleType, lookup_casstype, SimpleDateType,
                                 TimeType, ByteType, ShortType)
 from cassandra.policies import WriteType
+from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
 from cassandra import util

 log = logging.getLogger(__name__)
@@ -68,10 +70,16 @@ _message_types_by_opcode = {}
 _UNSET_VALUE = object()


+def register_class(cls):
+    _message_types_by_opcode[cls.opcode] = cls
+
+
+def get_registered_classes():
+    return _message_types_by_opcode.copy()
+
+
 class _RegisterMessageType(type):
     def __init__(cls, name, bases, dct):
         if not name.startswith('_'):
-            _message_types_by_opcode[cls.opcode] = cls
+            register_class(cls)


 @six.add_metaclass(_RegisterMessageType)
@@ -531,35 +539,6 @@ RESULT_KIND_SET_KEYSPACE = 0x0003
 RESULT_KIND_PREPARED = 0x0004
 RESULT_KIND_SCHEMA_CHANGE = 0x0005


-class CassandraTypeCodes(object):
-    CUSTOM_TYPE = 0x0000
-    AsciiType = 0x0001
-    LongType = 0x0002
-    BytesType = 0x0003
-    BooleanType = 0x0004
-    CounterColumnType = 0x0005
-    DecimalType = 0x0006
-    DoubleType = 0x0007
-    FloatType = 0x0008
-    Int32Type = 0x0009
-    UTF8Type = 0x000A
-    DateType = 0x000B
-    UUIDType = 0x000C
-    UTF8Type = 0x000D
-    IntegerType = 0x000E
-    TimeUUIDType = 0x000F
-    InetAddressType = 0x0010
-    SimpleDateType = 0x0011
-    TimeType = 0x0012
-    ShortType = 0x0013
-    ByteType = 0x0014
-    ListType = 0x0020
-    MapType = 0x0021
-    SetType = 0x0022
-    UserType = 0x0030
-    TupleType = 0x0031
-
-
 class ResultMessage(_MessageType):
     opcode = 0x08
     name = 'RESULT'
@@ -569,7 +548,7 @@ class ResultMessage(_MessageType):
     paging_state = None

     # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
-    type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in CassandraTypeCodes.__dict__.items() if not k.startswith('_'))
+    type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))

     _FLAGS_GLOBAL_TABLES_SPEC = 0x0001
     _HAS_MORE_PAGES_FLAG = 0x0002
@@ -1017,6 +996,63 @@ class ProtocolHandler(object):
         return msg


+def cython_protocol_handler(colparser):
+    """
+    Given a column parser to deserialize ResultMessages, return a suitable
+    Cython-based protocol handler.
+
+    There are three Cython-based protocol handlers (least to most performant):
+
+        1. obj_parser.ListParser
+           this parser decodes result messages into a list of tuples
+
+        2. obj_parser.LazyParser
+           this parser decodes result messages lazily by returning an iterator
+
+        3. numpy_parser.NumpyParser
+           this parser decodes result messages into NumPy arrays
+
+    The default is to use obj_parser.ListParser
+    """
+    from cassandra.row_parser import make_recv_results_rows
+
+    class FastResultMessage(ResultMessage):
+        """
+        Cython version of Result Message that has a faster implementation of
+        recv_results_row.
+        """
+        # type_codes = ResultMessage.type_codes.copy()
+        code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items())
+        recv_results_rows = classmethod(make_recv_results_rows(colparser))
+
+    class CythonProtocolHandler(ProtocolHandler):
+        """
+        Use FastResultMessage to decode query result messages.
+        """
+
+        my_opcodes = ProtocolHandler.message_types_by_opcode.copy()
+        my_opcodes[FastResultMessage.opcode] = FastResultMessage
+        message_types_by_opcode = my_opcodes
+
+    return CythonProtocolHandler
+
+
+if HAVE_CYTHON:
+    from cassandra.obj_parser import ListParser, LazyParser
+    ProtocolHandler = cython_protocol_handler(ListParser())
+    LazyProtocolHandler = cython_protocol_handler(LazyParser())
+else:
+    # Use Python-based ProtocolHandler
+    LazyProtocolHandler = None
+
+
+if HAVE_CYTHON and HAVE_NUMPY:
+    from cassandra.numpy_parser import NumpyParser
+    NumpyProtocolHandler = cython_protocol_handler(NumpyParser())
+else:
+    NumpyProtocolHandler = None
+
+
 def read_byte(f):
     return int8_unpack(f.read(1))

cassandra/row_parser.pyx (new file)

@@ -0,0 +1,38 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.parsing cimport ParseDesc, ColumnParser
from cassandra.deserializers import make_deserializers

include "ioutils.pyx"


def make_recv_results_rows(ColumnParser colparser):
    def recv_results_rows(cls, f, int protocol_version, user_type_map):
        """
        Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples)

        This is used as the recv_results_rows method of (Fast)ResultMessage
        """
        paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)

        colnames = [c[2] for c in column_metadata]
        coltypes = [c[3] for c in column_metadata]

        desc = ParseDesc(colnames, coltypes, make_deserializers(coltypes),
                         protocol_version)
        reader = BytesIOReader(f.read())
        parsed_rows = colparser.parse_rows(reader, desc)

        return (paging_state, (colnames, parsed_rows))

    return recv_results_rows

cassandra/tuple.pxd (new file)

@@ -0,0 +1,41 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cpython.tuple cimport (
    PyTuple_New,
        # Return value: New reference.
        # Return a new tuple object of size len, or NULL on failure.
    PyTuple_SET_ITEM,
        # Like PyTuple_SetItem(), but does no error checking, and should
        # only be used to fill in brand new tuples. Note: This function
        # ``steals'' a reference to o.
)

from cpython.ref cimport (
    Py_INCREF
        # void Py_INCREF(object o)
        # Increment the reference count for object o. The object must not
        # be NULL; if you aren't sure that it isn't NULL, use
        # Py_XINCREF().
)


cdef inline tuple tuple_new(Py_ssize_t n):
    """Allocate a new tuple object"""
    return PyTuple_New(n)


cdef inline void tuple_set(tuple tup, Py_ssize_t idx, object item):
    """Insert new object into tuple. No item must have been set yet."""
    # PyTuple_SET_ITEM steals a reference, so we need to INCREF
    Py_INCREF(item)
    PyTuple_SET_ITEM(tup, idx, item)
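
The intended calling pattern, sketched for a hypothetical Cython caller:

    from cassandra.tuple cimport tuple_new, tuple_set

    cdef tuple row = tuple_new(2)  # slots are uninitialized at this point
    tuple_set(row, 0, u"key")      # tuple_set INCREFs before PyTuple_SET_ITEM
    tuple_set(row, 1, 42)
    # every slot must be filled exactly once before 'row' escapes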

cassandra/type_codes.pxd (new file)

@@ -0,0 +1,42 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cdef enum:
    CUSTOM_TYPE
    AsciiType
    LongType
    BytesType
    BooleanType
    CounterColumnType
    DecimalType
    DoubleType
    FloatType
    Int32Type
    UTF8Type
    DateType
    UUIDType
    VarcharType
    IntegerType
    TimeUUIDType
    InetAddressType
    SimpleDateType
    TimeType
    ShortType
    ByteType
    ListType
    MapType
    SetType
    UserType
    TupleType

cassandra/type_codes.py (new file)

@@ -0,0 +1,62 @@
"""
Module with constants for Cassandra type codes.
These constants are useful for
a) mapping messages to cqltypes (cassandra/cqltypes.py)
b) optimized dispatching for (de)serialization (cassandra/encoding.py)
Type codes are repeated here from the Cassandra binary protocol specification:
0x0000 Custom: the value is a [string], see above.
0x0001 Ascii
0x0002 Bigint
0x0003 Blob
0x0004 Boolean
0x0005 Counter
0x0006 Decimal
0x0007 Double
0x0008 Float
0x0009 Int
0x000A Text
0x000B Timestamp
0x000C Uuid
0x000D Varchar
0x000E Varint
0x000F Timeuuid
0x0010 Inet
0x0020 List: the value is an [option], representing the type
of the elements of the list.
0x0021 Map: the value is two [option], representing the types of the
keys and values of the map
0x0022 Set: the value is an [option], representing the type
of the elements of the set
"""
CUSTOM_TYPE = 0x0000
AsciiType = 0x0001
LongType = 0x0002
BytesType = 0x0003
BooleanType = 0x0004
CounterColumnType = 0x0005
DecimalType = 0x0006
DoubleType = 0x0007
FloatType = 0x0008
Int32Type = 0x0009
UTF8Type = 0x000A
DateType = 0x000B
UUIDType = 0x000C
VarcharType = 0x000D
IntegerType = 0x000E
TimeUUIDType = 0x000F
InetAddressType = 0x0010
SimpleDateType = 0x0011
TimeType = 0x0012
ShortType = 0x0013
ByteType = 0x0014
ListType = 0x0020
MapType = 0x0021
SetType = 0x0022
UserType = 0x0030
TupleType = 0x0031
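
For illustration, a plain-Python sketch of turning these constants into a
code-to-cqltype lookup, mirroring what protocol.py does with globals():

    from cassandra import cqltypes, type_codes

    code_to_type = dict(
        (code, getattr(cqltypes, name))
        for name, code in vars(type_codes).items()
        if not name.startswith('_') and hasattr(cqltypes, name))

    assert code_to_type[0x0009] is cqltypes.Int32Type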

cassandra/util.py

@@ -1,12 +1,29 @@
+# Copyright 2013-2015 DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import with_statement
 import calendar
 import datetime
 import random
 import six
 import uuid
+import sys

 DATETIME_EPOC = datetime.datetime(1970, 1, 1)

+assert sys.byteorder in ('little', 'big')
+is_little_endian = sys.byteorder == 'little'
+

 def datetime_from_timestamp(timestamp):
     """
@@ -490,8 +507,7 @@ except ImportError:
         def __init__(self, iterable=()):
             self._items = []
-            for i in iterable:
-                self.add(i)
+            self.update(iterable)

         def __len__(self):
             return len(self._items)
@@ -564,6 +580,10 @@ except ImportError:
             else:
                 self._items.append(item)

+        def update(self, iterable):
+            for i in iterable:
+                self.add(i)
+
         def clear(self):
             del self._items[:]

docs/api/cassandra/protocol.rst

@@ -24,3 +24,17 @@ See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseF
 .. automethod:: encode_message
 .. automethod:: decode_message

+Faster Deserialization
+----------------------
+When python-driver is compiled with Cython, it uses a Cython-based deserialization path
+to deserialize messages. There are two additional ProtocolHandler classes that can be
+used to deserialize response messages: ``LazyProtocolHandler`` and
+``NumpyProtocolHandler``. They can be used as follows:
+
+.. code:: python
+
+    from cassandra.protocol import NumpyProtocolHandler, LazyProtocolHandler
+
+    s.client_protocol_handler = LazyProtocolHandler   # for a result iterator
+    s.client_protocol_handler = NumpyProtocolHandler  # for a dict of NumPy arrays as result
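
As a follow-on illustration (hypothetical table name; pandas is an assumption here, not
a dependency of this change), the per-page dict of arrays produced by the NumPy parser
maps directly onto a DataFrame:

.. code:: python

    import pandas as pd

    s.client_protocol_handler = NumpyProtocolHandler
    pages = s.execute("SELECT * FROM mykeyspace.mytable")
    df = pd.DataFrame(pages[0])  # each parsed page is a dict of column -> NumPy array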

setup.py

@@ -37,7 +37,6 @@ from distutils.errors import (CCompilerError, DistutilsPlatformError,
                               DistutilsExecError)
 from distutils.cmd import Command

 try:
     import subprocess
     has_subprocess = True
@@ -71,6 +70,7 @@ if __name__ == '__main__' and sys.argv[1] == "install":
 except ImportError:
     pass

+PROFILING = False
+

 class DocCommand(Command):
@@ -263,11 +263,16 @@ if "--no-libev" not in sys.argv and not is_windows:
if "--no-cython" not in sys.argv: if "--no-cython" not in sys.argv:
try: try:
from Cython.Build import cythonize from Cython.Build import cythonize
cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'marshal', 'metadata', 'pool', 'protocol', 'query', 'util'] cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata',
'pool', 'protocol', 'query', 'util']
compile_args = [] if is_windows else ['-Wno-unused-function'] compile_args = [] if is_windows else ['-Wno-unused-function']
extensions.extend(cythonize( extensions.extend(cythonize(
[Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], extra_compile_args=compile_args) for m in cython_candidates], [Extension('cassandra.%s' % m, ['cassandra/%s.py' % m],
extra_compile_args=compile_args)
for m in cython_candidates],
exclude_failures=True)) exclude_failures=True))
extensions.extend(cythonize("cassandra/*.pyx"))
extensions.extend(cythonize("tests/unit/cython/*.pyx"))
except ImportError: except ImportError:
sys.stderr.write("Cython is not installed. Not compiling core driver files as extensions (optional).") sys.stderr.write("Cython is not installed. Not compiling core driver files as extensions (optional).")

tests/integration/__init__.py

@@ -160,7 +160,7 @@ def remove_cluster():
         CCM_CLUSTER.remove()
         CCM_CLUSTER = None
         return
-    except WindowsError:
+    except OSError:
         ex_type, ex, tb = sys.exc_info()
         log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb)))
         del tb

tests/integration/cqlengine/columns/test_container_columns.py

@@ -386,7 +386,8 @@ class TestMapColumn(BaseCassEngTestCase):
         k2 = uuid4()
         now = datetime.now()
         then = now + timedelta(days=1)
-        m1 = TestMapModel.create(int_map={1: k1, 2: k2}, text_map={'now': now, 'then': then})
+        m1 = TestMapModel.create(int_map={1: k1, 2: k2},
+                                 text_map={'now': now, 'then': then})
         m2 = TestMapModel.get(partition=m1.partition)

         self.assertTrue(isinstance(m2.int_map, dict))

tests/integration/cqlengine/query/test_queryset.py

@@ -629,7 +629,7 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
         # test kwarg filtering
         q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
         q = [d for d in q]
-        assert len(q) == 2
+        self.assertEqual(len(q), 2, msg="Got: %s" % q)
         datas = [d.data for d in q]
         assert '1' in datas
         assert '2' in datas

tests/integration/cqlengine/query/test_updates.py

@@ -52,17 +52,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
         # sanity check
         for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
-            assert row.cluster == i
-            assert row.count == i
-            assert row.text == str(i)
+            self.assertEqual(row.cluster, i)
+            self.assertEqual(row.count, i)
+            self.assertEqual(row.text, str(i))

         # perform update
         TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)

         for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
-            assert row.cluster == i
-            assert row.count == (6 if i == 3 else i)
-            assert row.text == str(i)
+            self.assertEqual(row.cluster, i)
+            self.assertEqual(row.count, 6 if i == 3 else i)
+            self.assertEqual(row.text, str(i))

     def test_update_values_validation(self):
         """ tests calling udpate on models with values passed in """
@@ -72,9 +72,9 @@ class QueryUpdateTests(BaseCassEngTestCase):
         # sanity check
         for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
-            assert row.cluster == i
-            assert row.count == i
-            assert row.text == str(i)
+            self.assertEqual(row.cluster, i)
+            self.assertEqual(row.count, i)
+            self.assertEqual(row.text, str(i))

         # perform update
         with self.assertRaises(ValidationError):
@@ -98,17 +98,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
         # sanity check
         for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
-            assert row.cluster == i
-            assert row.count == i
-            assert row.text == str(i)
+            self.assertEqual(row.cluster, i)
+            self.assertEqual(row.count, i)
+            self.assertEqual(row.text, str(i))

         # perform update
         TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None)

         for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
-            assert row.cluster == i
-            assert row.count == i
-            assert row.text == (None if i == 3 else str(i))
+            self.assertEqual(row.cluster, i)
+            self.assertEqual(row.count, i)
+            self.assertEqual(row.text, None if i == 3 else str(i))
     def test_mixed_value_and_null_update(self):
         """ tests that updating a column's value, and removing another, works properly """
@@ -118,17 +118,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
         # sanity check
         for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
-            assert row.cluster == i
-            assert row.count == i
-            assert row.text == str(i)
+            self.assertEqual(row.cluster, i)
+            self.assertEqual(row.count, i)
+            self.assertEqual(row.text, str(i))

         # perform update
         TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None)

         for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
-            assert row.cluster == i
-            assert row.count == (6 if i == 3 else i)
-            assert row.text == (None if i == 3 else str(i))
+            self.assertEqual(row.cluster, i)
+            self.assertEqual(row.count, 6 if i == 3 else i)
+            self.assertEqual(row.text, None if i == 3 else str(i))
     def test_counter_updates(self):
         pass


@@ -85,11 +85,11 @@ class SchemaTests(unittest.TestCase):
         session = self.session

-        for i in xrange(30):
+        for i in range(30):
             execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i))
             execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i))
-            for j in xrange(100):
+            for j in range(100):
                 execute_until_pass(session, "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format(i, j))
             execute_until_pass(session, "DROP KEYSPACE test_{0}".format(i))
@@ -102,7 +102,7 @@ class SchemaTests(unittest.TestCase):
         cluster = Cluster(protocol_version=PROTOCOL_VERSION)
         session = cluster.connect()

-        for i in xrange(30):
+        for i in range(30):
             try:
                 execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}")
             except AlreadyExists:
@@ -111,7 +111,7 @@ class SchemaTests(unittest.TestCase):
execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)") execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)")
for j in xrange(100): for j in range(100):
execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j)) execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j))
execute_until_pass(session, "DROP KEYSPACE test") execute_until_pass(session, "DROP KEYSPACE test")


@@ -11,6 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+try:
+    import unittest2 as unittest
+except ImportError:
+    import unittest  # noqa

 try:
     from ccmlib import common
 except ImportError as e:


@@ -25,6 +25,8 @@ from cassandra.query import tuple_factory, SimpleStatement
 from tests.integration import use_singledc, PROTOCOL_VERSION

+from six import next

 try:
     import unittest2 as unittest
 except ImportError:


@@ -21,7 +21,8 @@ from cassandra.protocol import ProtocolHandler, ResultMessage, UUIDType, read_in
 from cassandra.query import tuple_factory
 from cassandra.cluster import Cluster
 from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass
-from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, get_sample
+from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES
+from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params

 from six import binary_type
 import uuid
@@ -62,7 +63,7 @@ class CustomProtocolHandlerTest(unittest.TestCase):
""" """
# Ensure that we get normal uuid back first # Ensure that we get normal uuid back first
session = Cluster().connect() session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes")
session.row_factory = tuple_factory session.row_factory = tuple_factory
result_set = session.execute("SELECT schema_version FROM system.local") result_set = session.execute("SELECT schema_version FROM system.local")
result = result_set.pop() result = result_set.pop()
@@ -102,15 +103,16 @@ class CustomProtocolHandlerTest(unittest.TestCase):
         @test_category data_types:serialization
         """
         # Connect using a custom protocol handler that tracks the various types the result message is used with.
-        session = Cluster().connect(keyspace="custserdes")
+        session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes")
         session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked
         session.row_factory = tuple_factory

-        columns_string = create_table_with_all_types("test_table", session)
+        colnames = create_table_with_all_types("alltypes", session, 1)
+        columns_string = ", ".join(colnames)

         # verify data
-        params = get_all_primitive_params()
-        results = session.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0]
+        params = get_all_primitive_params(0)
+        results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0]
         for expected, actual in zip(params, results):
             self.assertEqual(actual, expected)

         # Ensure we have covered the various primitive types
@@ -118,43 +120,6 @@ class CustomProtocolHandlerTest(unittest.TestCase):
         session.shutdown()

-def create_table_with_all_types(table_name, session):
-    """
-    Method that, given a table_name and session, constructs a table that contains all possible primitive types
-
-    :param table_name: Name of table to create
-    :param session: session to use for table creation
-    :return: a string containing the column names. This can be used to query the table.
-    """
-    # create table
-    alpha_type_list = ["zz int PRIMARY KEY"]
-    col_names = ["zz"]
-    start_index = ord('a')
-    for i, datatype in enumerate(PRIMITIVE_DATATYPES):
-        alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
-        col_names.append(chr(start_index + i))
-
-    session.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)), timeout=120)
-
-    # create the input
-    params = get_all_primitive_params()
-
-    # insert into table as a simple statement
-    columns_string = ', '.join(col_names)
-    placeholders = ', '.join(["%s"] * len(col_names))
-    session.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params, timeout=120)
-
-    return columns_string
-
-def get_all_primitive_params():
-    """
-    Simple utility method used to give back a list of all possible primitive data sample types.
-    """
-    params = [0]
-    for datatype in PRIMITIVE_DATATYPES:
-        params.append(get_sample(datatype))
-    return params
 class CustomResultMessageRaw(ResultMessage):
     """
     This is a custom Result Message that is used to return raw results, rather than


@@ -0,0 +1,129 @@
"""Test the various Cython-based message deserializers"""
# Based on test_custom_protocol_handler.py
try:
import unittest2 as unittest
except ImportError:
import unittest
from cassandra.query import tuple_factory
from cassandra.cluster import Cluster
from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler
from tests.integration import use_singledc, PROTOCOL_VERSION
from tests.integration.datatype_utils import update_datatypes
from tests.integration.standard.utils import (
create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes)
from tests.unit.cython.utils import cythontest, numpytest
def setup_module():
use_singledc()
update_datatypes()
class CythonProtocolHandlerTest(unittest.TestCase):
N_ITEMS = 10
@classmethod
def setUpClass(cls):
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
cls.session = cls.cluster.connect()
cls.session.execute("CREATE KEYSPACE testspace WITH replication = "
"{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
cls.session.set_keyspace("testspace")
cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS)
@classmethod
def tearDownClass(cls):
cls.session.execute("DROP KEYSPACE testspace")
cls.cluster.shutdown()
@cythontest
def test_cython_parser(self):
"""
Test Cython-based parser that returns a list of tuples
"""
verify_iterator_data(self.assertEqual, get_data(ProtocolHandler))
@cythontest
def test_cython_lazy_parser(self):
"""
Test Cython-based parser that returns an iterator of tuples
"""
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
@numpytest
def test_numpy_parser(self):
"""
Test Numpy-based parser that returns a NumPy array
"""
# arrays = { 'a': arr1, 'b': arr2, ... }
arrays = get_data(NumpyProtocolHandler)
colnames = self.colnames
datatypes = get_primitive_datatypes()
for colname, datatype in zip(colnames, datatypes):
arr = arrays[colname]
self.match_dtype(datatype, arr.dtype)
verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames))
def match_dtype(self, datatype, dtype):
"""Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
if datatype == 'smallint':
self.match_dtype_props(dtype, 'i', 2)
elif datatype == 'int':
self.match_dtype_props(dtype, 'i', 4)
elif datatype in ('bigint', 'counter'):
self.match_dtype_props(dtype, 'i', 8)
elif datatype == 'float':
self.match_dtype_props(dtype, 'f', 4)
elif datatype == 'double':
self.match_dtype_props(dtype, 'f', 8)
else:
self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype))
def match_dtype_props(self, dtype, kind, size, signed=None):
self.assertEqual(dtype.kind, kind, msg=dtype)
self.assertEqual(dtype.itemsize, size, msg=dtype)
def arrays_to_list_of_tuples(arrays, colnames):
"""Convert a dict of arrays (as given by the numpy protocol handler) to a list of tuples"""
first_array = arrays[colnames[0]]
return [tuple(arrays[colname][i] for colname in colnames)
for i in range(len(first_array))]
def get_data(protocol_handler):
"""
    Get all rows from the test table, deserialized with the given protocol handler.

    :param protocol_handler: the ProtocolHandler class to install on the session
"""
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
session = cluster.connect(keyspace="testspace")
# use our custom protocol handler
session.client_protocol_handler = protocol_handler
session.row_factory = tuple_factory
results = session.execute("SELECT * FROM test_table")
session.shutdown()
return results
def verify_iterator_data(assertEqual, results):
"""
Check the result of get_data() when this is a list or
iterator of tuples
"""
for result in results:
params = get_all_primitive_params(result[0])
assertEqual(len(params), len(result),
msg="Not the right number of columns?")
for expected, actual in zip(params, result):
assertEqual(actual, expected)
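
Application code selects a deserializer the same way get_data does above; a minimal sketch (assuming a reachable cluster and the testspace keyspace created by this test):

    # Sketch: opt in to the lazy Cython deserializer from application code.
    from cassandra.cluster import Cluster
    from cassandra.protocol import LazyProtocolHandler
    from cassandra.query import tuple_factory

    cluster = Cluster()
    session = cluster.connect(keyspace="testspace")
    session.client_protocol_handler = LazyProtocolHandler  # rows deserialized on demand
    session.row_factory = tuple_factory
    for row in session.execute("SELECT * FROM test_table"):
        pass  # each tuple is materialized lazily while iterating
    cluster.shutdown()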


@@ -301,8 +301,8 @@ class BatchStatementTests(unittest.TestCase):
             keys.add(result.k)
             values.add(result.v)

-        self.assertEqual(set(range(10)), keys)
-        self.assertEqual(set(range(10)), values)
+        self.assertEqual(set(range(10)), keys, msg=results)
+        self.assertEqual(set(range(10)), values, msg=results)

     def test_string_statements(self):
         batch = BatchStatement(BatchType.LOGGED)
@@ -370,6 +370,11 @@ class BatchStatementTests(unittest.TestCase):
         self.session.execute(batch)
         self.confirm_results()

+    def test_no_parameters_many_times(self):
+        for i in range(1000):
+            self.test_no_parameters()
+            self.session.execute("TRUNCATE test3rf.test")

 class SerialConsistencyTests(unittest.TestCase):
     def setUp(self):


@@ -0,0 +1,53 @@
"""
Helper module to populate dummy Cassandra tables with data.
"""
from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample
def create_table_with_all_types(table_name, session, N):
"""
    Method that, given a table_name and session, constructs a table containing
    all possible primitive types.
:param table_name: Name of table to create
:param session: session to use for table creation
:param N: the number of items to insert into the table
:return: a list of column names
"""
# create table
alpha_type_list = ["primkey int PRIMARY KEY"]
col_names = ["primkey"]
start_index = ord('a')
for i, datatype in enumerate(PRIMITIVE_DATATYPES):
alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
col_names.append(chr(start_index + i))
session.execute("CREATE TABLE {0} ({1})".format(
table_name, ', '.join(alpha_type_list)), timeout=120)
# create the input
for key in range(N):
params = get_all_primitive_params(key)
# insert into table as a simple statement
columns_string = ', '.join(col_names)
placeholders = ', '.join(["%s"] * len(col_names))
session.execute("INSERT INTO {0} ({1}) VALUES ({2})".format(
table_name, columns_string, placeholders), params, timeout=120)
return col_names
def get_all_primitive_params(key):
"""
Simple utility method used to give back a list of all possible primitive data sample types.
"""
params = [key]
for datatype in PRIMITIVE_DATATYPES:
params.append(get_sample(datatype))
return params
def get_primitive_datatypes():
return ['int'] + list(PRIMITIVE_DATATYPES)
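
Taken together, a hedged usage sketch (assuming an open session):

    # Sketch: build the all-types table with 3 rows and read one back.
    colnames = create_table_with_all_types("alltypes", session, 3)
    row = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(
        ", ".join(colnames)))[0]
    for expected, actual in zip(get_all_primitive_params(0), row):
        assert expected == actual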


@@ -75,7 +75,7 @@ class StressInsertsTests(unittest.TestCase):
                 break

         for conn in pool.get_connections():
             if conn.in_flight > 1:
-                print self.session.get_pool_state()
+                print(self.session.get_pool_state())
                 leaking_connections = True
                 break

         i = i + 1


@@ -0,0 +1,14 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,44 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.bytesio cimport BytesIOReader
def test_read1(assert_equal, assert_raises):
cdef BytesIOReader reader = BytesIOReader(b'abcdef')
assert_equal(reader.read(2)[:2], b'ab')
assert_equal(reader.read(2)[:2], b'cd')
assert_equal(reader.read(0)[:0], b'')
assert_equal(reader.read(2)[:2], b'ef')
def test_read2(assert_equal, assert_raises):
cdef BytesIOReader reader = BytesIOReader(b'abcdef')
reader.read(5)
reader.read(1)
def test_read3(assert_equal, assert_raises):
cdef BytesIOReader reader = BytesIOReader(b'abcdef')
reader.read(6)
def test_read_eof(assert_equal, assert_raises):
cdef BytesIOReader reader = BytesIOReader(b'abcdef')
reader.read(5)
    # reader.read is a cdef method, so it cannot be passed to assert_raises
    # as a Python callable; check for EOFError manually instead of
    # assert_raises(EOFError, reader.read, 2)
try:
reader.read(2)
except EOFError:
pass
else:
raise Exception("Expected an EOFError")
reader.read(1) # see that we can still read this


@@ -0,0 +1,35 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.unit.cython.utils import cyimport, cythontest
bytesio_testhelper = cyimport('tests.unit.cython.bytesio_testhelper')
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
class BytesIOTest(unittest.TestCase):
"""Test Cython BytesIO proxy"""
@cythontest
def test_reading(self):
bytesio_testhelper.test_read1(self.assertEqual, self.assertRaises)
bytesio_testhelper.test_read2(self.assertEqual, self.assertRaises)
bytesio_testhelper.test_read3(self.assertEqual, self.assertRaises)
@cythontest
def test_reading_error(self):
bytesio_testhelper.test_read_eof(self.assertEqual, self.assertRaises)


@@ -0,0 +1,38 @@
# Copyright 2013-2015 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
def cyimport(import_path):
"""
Import a Cython module if available, otherwise return None
(and skip any relevant tests).
"""
try:
return __import__(import_path, fromlist=[True])
except ImportError:
if HAVE_CYTHON:
raise
return None
# @cythontest
# def test_something(self): ...
cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available')
numpytest = unittest.skipUnless(HAVE_CYTHON and HAVE_NUMPY, 'NumPy is not available')
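
A hedged sketch of how a new Cython-backed test module would lean on these helpers (the helper module name here is hypothetical):

    try:
        import unittest2 as unittest
    except ImportError:
        import unittest  # noqa

    from tests.unit.cython.utils import cyimport, cythontest
    buffer_testhelper = cyimport('tests.unit.cython.buffer_testhelper')  # hypothetical .pyx module

    class BufferTest(unittest.TestCase):
        @cythontest  # skipped automatically when Cython is not available
        def test_slicing(self):
            # assert callables are injected, mirroring bytesio_testhelper above
            buffer_testhelper.test_slice(self.assertEqual, self.assertRaises)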