Merge branch 'cython_serdes'
Conflicts: .gitignore
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,6 +4,7 @@
|
||||
*.so
|
||||
*.egg
|
||||
*.egg-info
|
||||
*.attr
|
||||
.tox
|
||||
.python-version
|
||||
build
|
||||
@@ -19,6 +20,7 @@ setuptools*.egg
|
||||
|
||||
cassandra/*.c
|
||||
!cassandra/cmurmur3.c
|
||||
cassandra/*.html
|
||||
|
||||
# OSX
|
||||
.DS_Store
|
||||
@@ -38,3 +40,4 @@ cassandra/*.c
|
||||
|
||||
#iPython
|
||||
*.ipynb
|
||||
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
include setup.py README.rst MANIFEST.in LICENSE ez_setup.py
|
||||
include cassandra/*.pyx
|
||||
include cassandra/*.pxd
|
||||
include tests/unit/cython/*.pyx
|
||||
|
||||
58
cassandra/buffer.pxd
Normal file
58
cassandra/buffer.pxd
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple buffer data structure that provides a view on existing memory
|
||||
(e.g. from a bytes object). This memory must stay alive while the
|
||||
buffer is in use.
|
||||
"""
|
||||
|
||||
from cpython.bytes cimport PyBytes_AS_STRING
|
||||
# char* PyBytes_AS_STRING(object string)
|
||||
# Macro form of PyBytes_AsString() but without error
|
||||
# checking. Only string objects are supported; no Unicode objects
|
||||
# should be passed.
|
||||
|
||||
|
||||
cdef struct Buffer:
    # Start of the viewed memory. The struct does not own this memory; it
    # must stay alive for as long as the Buffer is in use.
    char *ptr
    # Number of bytes visible through `ptr`.
    Py_ssize_t size
|
||||
|
||||
|
||||
cdef inline bytes to_bytes(Buffer *buf):
    # Copy the `buf.size` bytes starting at `buf.ptr` into a new bytes object.
    cdef char *start = buf.ptr
    cdef Py_ssize_t length = buf.size
    return start[:length]
|
||||
|
||||
cdef inline char *buf_ptr(Buffer *buf):
    # Return the raw start pointer of the view, without any bounds check.
    cdef char *start = buf.ptr
    return start
|
||||
|
||||
cdef inline char *buf_read(Buffer *buf, Py_ssize_t size) except NULL:
    # Return the start pointer of `buf` after checking that at least `size`
    # bytes are available; raise IndexError otherwise.
    if buf.size >= size:
        return buf.ptr
    raise IndexError("Requested more than length of buffer")
|
||||
|
||||
cdef inline int slice_buffer(Buffer *buf, Buffer *out,
                             Py_ssize_t start, Py_ssize_t size) except -1:
    # Make `out` a view of `size` bytes of `buf`, beginning at offset `start`.
    # No data is copied; `out` may alias `buf` because both fields of `buf`
    # are read before `out` is written.
    #
    # Raises ValueError when `size` is negative and IndexError when the
    # requested range does not lie inside `buf`.
    if size < 0:
        raise ValueError("Length must be positive")

    # A negative offset would point before the viewed memory. The bound is
    # written as `size > buf.size - start` so that it cannot overflow the way
    # `start + size` can for very large arguments.
    if start < 0 or size > buf.size - start:
        raise IndexError("Buffer slice out of bounds")

    out.ptr = buf.ptr + start
    out.size = size
    return 0
|
||||
|
||||
cdef inline void from_ptr_and_size(char *ptr, Py_ssize_t size, Buffer *out):
    # Initialise `out` as a view over `size` bytes starting at `ptr`.
    out.size = size
    out.ptr = ptr
|
||||
20
cassandra/bytesio.pxd
Normal file
20
cassandra/bytesio.pxd
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
cdef class BytesIOReader:
    # The bytes object being read.
    cdef bytes buf
    # Raw pointer to the contents of `buf`.
    cdef char *buf_ptr
    # Current read offset into `buf`.
    cdef Py_ssize_t pos
    # Length of `buf`.
    cdef Py_ssize_t size
    # Return a pointer to the next `n` bytes and advance `pos`.
    cdef char *read(self, Py_ssize_t n = ?) except NULL
|
||||
44
cassandra/bytesio.pyx
Normal file
44
cassandra/bytesio.pyx
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
cdef class BytesIOReader:
    """
    This class provides efficient support for reading bytes from a 'bytes' buffer,
    by returning char * values directly without allocating intermediate objects.
    """

    def __init__(self, bytes buf not None):
        # Keep a reference to `buf` so the memory behind `buf_ptr` stays alive.
        self.buf = buf
        self.size = len(buf)
        self.buf_ptr = self.buf
        # Start from the beginning, also when __init__ is called again.
        self.pos = 0

    cdef char *read(self, Py_ssize_t n = -1) except NULL:
        """
        Return a pointer to the next `n` bytes and advance the read position.

        If `n` is negative or omitted, the pointer refers to all remaining
        data and the position moves to the end of the buffer.

        Unlike a file-like `read`, this never returns fewer bytes than
        requested: asking for more than what is left raises EOFError and
        leaves the position unchanged. The returned pointer refers directly
        into the wrapped bytes object; nothing is copied.
        """
        cdef Py_ssize_t newpos
        cdef char *res

        if n < 0:
            newpos = self.size
        elif n > self.size - self.pos:
            # Raise an error here, as we do not want the caller to consume
            # past the end of the buffer. Comparing against the remaining
            # length avoids overflowing `self.pos + n` for a huge `n`.
            raise EOFError("Cannot read past the end of the file")
        else:
            newpos = self.pos + n

        res = self.buf_ptr + self.pos
        self.pos = newpos
        return res
|
||||
@@ -288,7 +288,7 @@ class _CassandraType(object):
|
||||
Given a set of other CassandraTypes, create a new subtype of this type
|
||||
using them as parameters. This is how composite types are constructed.
|
||||
|
||||
>>> MapType.apply_parameters(DateType, BooleanType)
|
||||
>>> MapType.apply_parameters([DateType, BooleanType])
|
||||
<class 'cassandra.types.MapType(DateType, BooleanType)'>
|
||||
|
||||
`subtypes` will be a sequence of CassandraTypes. If provided, `names`
|
||||
@@ -884,27 +884,7 @@ class UserType(TupleType):
|
||||
|
||||
@classmethod
|
||||
def deserialize_safe(cls, byts, protocol_version):
|
||||
proto_version = max(3, protocol_version)
|
||||
p = 0
|
||||
values = []
|
||||
for col_type in cls.subtypes:
|
||||
if p == len(byts):
|
||||
break
|
||||
itemlen = int32_unpack(byts[p:p + 4])
|
||||
p += 4
|
||||
if itemlen >= 0:
|
||||
item = byts[p:p + itemlen]
|
||||
p += itemlen
|
||||
else:
|
||||
item = None
|
||||
# collections inside UDTs are always encoded with at least the
|
||||
# version 3 format
|
||||
values.append(col_type.from_binary(item, proto_version))
|
||||
|
||||
if len(values) < len(cls.subtypes):
|
||||
nones = [None] * (len(cls.subtypes) - len(values))
|
||||
values = values + nones
|
||||
|
||||
values = super(UserType, cls).deserialize_safe(byts, protocol_version)
|
||||
if cls.mapped_class:
|
||||
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
||||
else:
|
||||
|
||||
11
cassandra/cython_deps.py
Normal file
11
cassandra/cython_deps.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Feature flags: record which optional accelerators can be imported.

try:
    from cassandra.row_parser import make_recv_results_rows
except ImportError:
    HAVE_CYTHON = False
else:
    HAVE_CYTHON = True

try:
    import numpy
except ImportError:
    HAVE_NUMPY = False
else:
    HAVE_NUMPY = True
|
||||
123
cassandra/cython_marshal.pyx
Normal file
123
cassandra/cython_marshal.pyx
Normal file
@@ -0,0 +1,123 @@
|
||||
# -- cython: profile=True
|
||||
#
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import six
|
||||
|
||||
from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t,
|
||||
uint8_t, uint16_t, uint32_t, uint64_t)
|
||||
from cassandra.buffer cimport Buffer, buf_read, to_bytes
|
||||
|
||||
cdef bint is_little_endian
|
||||
from cassandra.util import is_little_endian
|
||||
|
||||
cdef bint PY3 = six.PY3
|
||||
|
||||
|
||||
cdef inline void swap_order(char *buf, Py_ssize_t size):
    """
    Reverse the `size` bytes at `buf` in place when running on a
    little-endian platform; do nothing otherwise.

    Functions such as ntohl exist, but these may be POSIX-dependent.
    """
    cdef Py_ssize_t lo = 0
    cdef Py_ssize_t hi = size - 1
    cdef char tmp

    if not is_little_endian:
        return

    # Walk inwards from both ends, exchanging one pair of bytes per step.
    while lo < hi:
        tmp = buf[lo]
        buf[lo] = buf[hi]
        buf[hi] = tmp
        lo += 1
        hi -= 1
|
||||
|
||||
cdef inline Py_ssize_t div2(Py_ssize_t x):
    """Halve `x` with an arithmetic right shift by one bit."""
    cdef Py_ssize_t half = x >> 1
    return half
|
||||
|
||||
### Unpacking of signed integers
|
||||
|
||||
cdef inline int64_t int64_unpack(Buffer *buf) except ?0xDEAD:
    """
    Read a signed 64-bit integer from the first 8 bytes of `buf`.

    The bytes are reversed on little-endian platforms (see swap_order).
    Raises IndexError (via buf_read) if `buf` holds fewer than 8 bytes.
    """
    cdef char *raw = buf_read(buf, 8)
    cdef int64_t value = (<int64_t *> raw)[0]
    swap_order(<char *> &value, 8)
    return value
|
||||
|
||||
cdef inline int32_t int32_unpack(Buffer *buf) except ?0xDEAD:
    """
    Read a signed 32-bit integer from the first 4 bytes of `buf`.

    The bytes are reversed on little-endian platforms (see swap_order).
    Raises IndexError (via buf_read) if `buf` holds fewer than 4 bytes.
    """
    cdef char *raw = buf_read(buf, 4)
    cdef int32_t value = (<int32_t *> raw)[0]
    swap_order(<char *> &value, 4)
    return value
|
||||
|
||||
cdef inline int16_t int16_unpack(Buffer *buf) except ?0xDED:
    """Read a signed 16-bit integer from the first 2 bytes of `buf`."""
    cdef char *raw = buf_read(buf, 2)
    cdef int16_t value = (<int16_t *> raw)[0]
    swap_order(<char *> &value, 2)
    return value
|
||||
|
||||
cdef inline int8_t int8_unpack(Buffer *buf) except ?80:
    """Read a signed 8-bit integer from the first byte of `buf`."""
    cdef char *raw = buf_read(buf, 1)
    return (<int8_t *> raw)[0]
|
||||
|
||||
cdef inline uint64_t uint64_unpack(Buffer *buf) except ?0xDEAD:
    """Read an unsigned 64-bit integer from the first 8 bytes of `buf`."""
    cdef char *raw = buf_read(buf, 8)
    cdef uint64_t value = (<uint64_t *> raw)[0]
    swap_order(<char *> &value, 8)
    return value
|
||||
|
||||
cdef inline uint32_t uint32_unpack(Buffer *buf) except ?0xDEAD:
    """Read an unsigned 32-bit integer from the first 4 bytes of `buf`."""
    cdef char *raw = buf_read(buf, 4)
    cdef uint32_t value = (<uint32_t *> raw)[0]
    swap_order(<char *> &value, 4)
    return value
|
||||
|
||||
cdef inline uint16_t uint16_unpack(Buffer *buf) except ?0xDEAD:
    """Read an unsigned 16-bit integer from the first 2 bytes of `buf`."""
    cdef char *raw = buf_read(buf, 2)
    cdef uint16_t value = (<uint16_t *> raw)[0]
    swap_order(<char *> &value, 2)
    return value
|
||||
|
||||
cdef inline uint8_t uint8_unpack(Buffer *buf) except ?0xff:
    """Read an unsigned 8-bit integer from the first byte of `buf`."""
    cdef char *raw = buf_read(buf, 1)
    return (<uint8_t *> raw)[0]
|
||||
|
||||
cdef inline double double_unpack(Buffer *buf) except ?1.74:
    """Read a C double from the first 8 bytes of `buf`."""
    cdef char *raw = buf_read(buf, 8)
    cdef double value = (<double *> raw)[0]
    swap_order(<char *> &value, 8)
    return value
|
||||
|
||||
cdef inline float float_unpack(Buffer *buf) except ?1.74:
    """Read a C float from the first 4 bytes of `buf`."""
    cdef char *raw = buf_read(buf, 4)
    cdef float value = (<float *> raw)[0]
    swap_order(<char *> &value, 4)
    return value
|
||||
|
||||
|
||||
cdef varint_unpack(Buffer *term):
    """
    Unpack a variable-sized integer.

    The buffer is copied into a bytes object and handed to the
    implementation matching the running Python major version.
    """
    cdef bytes raw
    if not PY3:
        raw = to_bytes(term)
        return varint_unpack_py2(raw)
    raw = to_bytes(term)
    return varint_unpack_py3(raw)
|
||||
|
||||
# TODO: Optimize these two functions
|
||||
cdef varint_unpack_py3(bytes term):
    """
    Decode `term` as a big-endian two's-complement integer of arbitrary size.

    The bytes are rendered as hex and parsed into a Python int; if the top
    bit of the first byte is set, 2 ** (8 * len(term)) is subtracted to
    obtain the negative value.
    """
    # Keep the shift amount as a Python object so that `1 << shift` is
    # evaluated with Python's unbounded integers. Shifting a C int64_t
    # overflows once `term` is 8 bytes or longer, and Cython (0.20 - 0.22)
    # also folds a literal '1 << (len(term) * 8)' into a native shift.
    cdef object shift

    val = int(''.join("%02x" % i for i in term), 16)
    if (term[0] & 128) != 0:
        shift = len(term) * 8
        val -= 1 << shift
    return val
|
||||
|
||||
cdef varint_unpack_py2(bytes term):  # noqa
    """
    Python 2 variant of the varint decoder: decode `term` as a big-endian
    two's-complement integer of arbitrary size.
    """
    # Keep the shift amount as a Python object so the shift is done with
    # Python's unbounded integers; a C int64_t shift overflows once `term`
    # is 8 bytes or longer.
    cdef object shift

    val = int(term.encode('hex'), 16)
    if (ord(term[0]) & 128) != 0:
        shift = len(term) * 8
        val = val - (1 << shift)
    return val
|
||||
2
cassandra/cython_utils.pxd
Normal file
2
cassandra/cython_utils.pxd
Normal file
@@ -0,0 +1,2 @@
|
||||
from libc.stdint cimport int64_t
|
||||
cdef datetime_from_timestamp(double timestamp)
|
||||
44
cassandra/cython_utils.pyx
Normal file
44
cassandra/cython_utils.pyx
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Duplicate module of util.py, with some accelerated functions
|
||||
used for deserialization.
|
||||
"""
|
||||
|
||||
from cpython.datetime cimport (
|
||||
timedelta_new,
|
||||
# cdef inline object timedelta_new(int days, int seconds, int useconds)
|
||||
# Create timedelta object using DateTime CAPI factory function.
|
||||
# Note, there are no range checks for any of the arguments.
|
||||
import_datetime,
|
||||
# Datetime C API initialization function.
|
||||
# You have to call it before any usage of DateTime CAPI functions.
|
||||
)
|
||||
|
||||
import datetime
|
||||
import sys
|
||||
|
||||
cdef bint is_little_endian
|
||||
from cassandra.util import is_little_endian
|
||||
|
||||
import_datetime()
|
||||
|
||||
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
|
||||
|
||||
|
||||
cdef datetime_from_timestamp(double timestamp):
    """
    Convert `timestamp` (seconds relative to 1970-01-01, possibly negative
    or fractional) into a naive datetime.

    The value is split into whole days, seconds and microseconds before it
    is passed to timedelta_new, which does no range checking. A C int holds
    the day count comfortably, whereas the total number of seconds overflows
    a 32-bit int for dates after 2038.
    """
    cdef int64_t micros_per_day = (<int64_t> 86400) * 1000000
    cdef int64_t total_micros = <int64_t> (timestamp * 1000000)
    cdef int64_t days = total_micros // micros_per_day
    cdef int64_t remainder = total_micros - days * micros_per_day

    # Force floor semantics regardless of the `cdivision` directive, so that
    # the remainder is never negative for timestamps before the epoch.
    if remainder < 0:
        remainder += micros_per_day
        days -= 1

    return DATETIME_EPOC + timedelta_new(<int> days,
                                         <int> (remainder // 1000000),
                                         <int> (remainder % 1000000))
|
||||
43
cassandra/deserializers.pxd
Normal file
43
cassandra/deserializers.pxd
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from cassandra.buffer cimport Buffer
|
||||
|
||||
cdef class Deserializer:
    # The cqltypes._CassandraType corresponding to this deserializer
    cdef object cqltype

    # Whether a zero-length value is passed on to `deserialize`.
    #
    # Strings may be empty, whereas other values may not be.
    # Other values may be NULL, in which case the integer length
    # of the binary data is negative. However, non-string types
    # may also return a zero length for legacy reasons
    # (see http://code.metager.de/source/xref/apache/cassandra/doc/native_protocol_v3.spec
    # paragraph 6)
    cdef bint empty_binary_ok

    # Convert the bytes viewed by `buf` into a Python value.
    cdef deserialize(self, Buffer *buf, int protocol_version)
    # cdef deserialize(self, CString byts, protocol_version)
|
||||
|
||||
|
||||
cdef inline object from_binary(Deserializer deserializer,
                               Buffer *buf,
                               int protocol_version):
    # Deserialize `buf` with `deserializer`, handling the two special sizes
    # first: a negative size yields None, and a zero size is delegated to
    # _ret_empty unless the deserializer accepts empty input.
    cdef Py_ssize_t length = buf.size

    if length < 0:
        return None
    if length == 0 and not deserializer.empty_binary_ok:
        return _ret_empty(deserializer, length)
    return deserializer.deserialize(buf, protocol_version)
|
||||
|
||||
cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size)
|
||||
512
cassandra/deserializers.pyx
Normal file
512
cassandra/deserializers.pyx
Normal file
@@ -0,0 +1,512 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from libc.stdint cimport int32_t, uint16_t
|
||||
|
||||
include 'cython_marshal.pyx'
|
||||
from cassandra.buffer cimport Buffer, to_bytes, slice_buffer
|
||||
from cassandra.cython_utils cimport datetime_from_timestamp
|
||||
|
||||
from cython.view cimport array as cython_array
|
||||
from cassandra.tuple cimport tuple_new, tuple_set
|
||||
|
||||
import socket
|
||||
from decimal import Decimal
|
||||
from uuid import UUID
|
||||
|
||||
from cassandra import cqltypes
|
||||
from cassandra import util
|
||||
|
||||
cdef bint PY2 = six.PY2
|
||||
|
||||
cdef class Deserializer:
    """Cython-based deserializer class for a cqltype"""

    def __init__(self, cqltype):
        # Cache the flag on the instance so from_binary can read it as a C
        # field and does not need a Python attribute lookup per value.
        self.empty_binary_ok = cqltype.empty_binary_ok
        self.cqltype = cqltype

    cdef deserialize(self, Buffer *buf, int protocol_version):
        """Convert the bytes in `buf` to a Python value; subclasses override this."""
        raise NotImplementedError
|
||||
|
||||
|
||||
cdef class DesBytesType(Deserializer):
    """Return the raw contents of the buffer as a bytes object."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return raw
|
||||
|
||||
|
||||
# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html
|
||||
cdef class DesDecimalType(Deserializer):
    """
    Build a Decimal from a 4-byte scale followed by a varint holding the
    unscaled value.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Buffer digits_buf

        # Everything after the 4-byte scale is the unscaled value.
        slice_buffer(buf, &digits_buf, 4, buf.size - 4)

        scale = int32_unpack(buf)
        unscaled = varint_unpack(&digits_buf)

        literal = '%de%d' % (unscaled, -scale)
        return Decimal(literal)
|
||||
|
||||
|
||||
cdef class DesUUIDType(Deserializer):
    """Build a uuid.UUID from the raw bytes of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return UUID(bytes=raw)
|
||||
|
||||
|
||||
cdef class DesBooleanType(Deserializer):
    """Interpret the first byte of the buffer as a boolean (non-zero is True)."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int8_t flag = int8_unpack(buf)
        if flag == 0:
            return False
        return True
|
||||
|
||||
|
||||
cdef class DesByteType(Deserializer):
    """Return the first byte of the buffer as a signed 8-bit integer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int8_t value = int8_unpack(buf)
        return value
|
||||
|
||||
|
||||
cdef class DesAsciiType(Deserializer):
    """Return bytes on Python 2 and an ASCII-decoded string otherwise."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        if not PY2:
            return to_bytes(buf).decode('ascii')
        return to_bytes(buf)
|
||||
|
||||
|
||||
cdef class DesFloatType(Deserializer):
    """Return the 4-byte float stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef float value = float_unpack(buf)
        return value
|
||||
|
||||
|
||||
cdef class DesDoubleType(Deserializer):
    """Return the 8-byte double stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef double value = double_unpack(buf)
        return value
|
||||
|
||||
|
||||
cdef class DesLongType(Deserializer):
    """Return the signed 64-bit integer stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int64_t value = int64_unpack(buf)
        return value
|
||||
|
||||
|
||||
cdef class DesInt32Type(Deserializer):
    """Return the signed 32-bit integer stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int32_t value = int32_unpack(buf)
        return value
|
||||
|
||||
|
||||
cdef class DesIntegerType(Deserializer):
    """Decode the whole buffer as a variable-sized integer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        value = varint_unpack(buf)
        return value
|
||||
|
||||
|
||||
cdef class DesInetAddressType(Deserializer):
    """Render a packed IP address as text, choosing the family by length."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes packed = to_bytes(buf)

        # TODO: optimize inet_ntop, inet_ntoa
        if buf.size != 16:
            # util.inet_pton could also handle, but this is faster
            # since we've already determined the AF
            return socket.inet_ntoa(packed)

        # A 16-byte payload is treated as an IPv6 address.
        return util.inet_ntop(socket.AF_INET6, packed)
|
||||
|
||||
|
||||
cdef class DesCounterColumnType(DesLongType):
    """Counter columns are deserialized exactly like DesLongType values."""
|
||||
|
||||
|
||||
cdef class DesDateType(Deserializer):
    """
    Read a signed 64-bit integer, divide it by 1000.0 and convert the
    resulting timestamp with datetime_from_timestamp.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int64_t raw = int64_unpack(buf)
        cdef double seconds = raw / 1000.0
        return datetime_from_timestamp(seconds)
|
||||
|
||||
|
||||
# find_deserializer() looks classes up as 'Des' + cqltype.__name__, so the
# deserializers must carry the 'Des' prefix to be selected at all.

cdef class DesTimestampType(DesDateType):
    """Timestamps are deserialized exactly like DesDateType values."""


cdef class DesTimeUUIDType(DesDateType):
    """Build a uuid.UUID from the raw bytes of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return UUID(bytes=to_bytes(buf))


# Backward-compatible aliases for the original, unprefixed class names.
TimestampType = DesTimestampType
TimeUUIDType = DesTimeUUIDType
|
||||
|
||||
|
||||
# Values of the 'date' type are encoded as 32-bit unsigned integers
|
||||
# representing a number of days with epoch (January 1st, 1970) at the center of the
|
||||
# range (2^31).
|
||||
EPOCH_OFFSET_DAYS = 2 ** 31
|
||||
|
||||
cdef class DesSimpleDateType(Deserializer):
    """
    Read an unsigned 32-bit day count, subtract EPOCH_OFFSET_DAYS and wrap
    the result in util.Date.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.Date(uint32_unpack(buf) - EPOCH_OFFSET_DAYS)
|
||||
|
||||
|
||||
cdef class DesShortType(Deserializer):
    """Return the signed 16-bit integer stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int16_t value = int16_unpack(buf)
        return value
|
||||
|
||||
|
||||
cdef class DesTimeType(Deserializer):
    """Wrap the signed 64-bit integer at the start of the buffer in util.Time."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int64_t raw = int64_unpack(buf)
        return util.Time(raw)
|
||||
|
||||
|
||||
cdef class DesUTF8Type(Deserializer):
    """Decode the contents of the buffer as UTF-8 text."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return to_bytes(buf).decode('utf8')
|
||||
|
||||
|
||||
cdef class DesVarcharType(DesUTF8Type):
    """Varchar values are deserialized exactly like DesUTF8Type values."""
|
||||
|
||||
|
||||
cdef class _DesParameterizedType(Deserializer):
    """Base class for deserializers of types that carry subtypes."""

    # The subtypes of the wrapped cqltype
    cdef object subtypes
    # One Deserializer per subtype, in the same order
    cdef Deserializer[::1] deserializers
    # Cached len(subtypes)
    cdef Py_ssize_t subtypes_len

    def __init__(self, cqltype):
        super().__init__(cqltype)
        subtypes = cqltype.subtypes
        self.subtypes = subtypes
        self.deserializers = make_deserializers(subtypes)
        self.subtypes_len = len(subtypes)
|
||||
|
||||
|
||||
cdef class _DesSingleParamType(_DesParameterizedType):
    """Base class for parameterized types with exactly one subtype."""

    # Deserializer for the single subtype
    cdef Deserializer deserializer

    def __init__(self, cqltype):
        subtypes = cqltype.subtypes
        assert subtypes and len(subtypes) == 1, subtypes
        super().__init__(cqltype)
        self.deserializer = self.deserializers[0]
|
||||
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
# List and set deserialization
|
||||
|
||||
cdef class DesListType(_DesSingleParamType):
    """Deserialize a collection into a Python list."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # These values only select the fused-type specialization; their
        # contents are never read.
        cdef uint16_t v2_and_below = 2
        cdef int32_t v3_and_above = 3

        if protocol_version < 3:
            return _deserialize_list_or_set[uint16_t](
                v2_and_below, buf, protocol_version, self.deserializer)

        return _deserialize_list_or_set[int32_t](
            v3_and_above, buf, protocol_version, self.deserializer)
|
||||
|
||||
cdef class DesSetType(DesListType):
    """Deserialize like a list, then wrap the items in util.sortedset."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        items = DesListType.deserialize(self, buf, protocol_version)
        return util.sortedset(items)
|
||||
|
||||
|
||||
ctypedef fused itemlen_t:
|
||||
uint16_t # protocol <= v2
|
||||
int32_t # protocol >= v3
|
||||
|
||||
cdef list _deserialize_list_or_set(itemlen_t dummy_version,
                                   Buffer *buf, int protocol_version,
                                   Deserializer deserializer):
    """
    Deserialize a list or set.

    The buffer starts with the number of elements, followed by each element
    prefixed with its own length. Both use the integer width `itemlen_t`.

    The 'dummy' parameter is needed to make fused types work, so that
    we can specialize on the protocol version.
    """
    cdef Buffer elem_buf
    cdef itemlen_t count
    # NOTE(review): the running offset has type itemlen_t, so in the uint16_t
    # specialization it presumably wraps for payloads above 64 KiB - confirm.
    cdef itemlen_t offset
    cdef list items = []

    _unpack_len[itemlen_t](0, &count, buf)
    offset = sizeof(itemlen_t)

    for _ in range(count):
        subelem(buf, &elem_buf, &offset)
        items.append(from_binary(deserializer, &elem_buf, protocol_version))

    return items
|
||||
|
||||
|
||||
cdef inline int subelem(
        Buffer *buf, Buffer *elem_buf, itemlen_t *idx_p) except -1:
    """
    Read the next length-prefixed element from `buf`.

    The element size (in bytes) is read at offset `idx_p[0]`; `elem_buf` is
    then set to a slice of that size located right after the length field,
    and `idx_p[0]` is advanced past both.

    NOTE: The offset is passed through a pointer because of a Cython bug
    with the combination fused types + 'except' clause, which prevents
    returning it directly.
    """
    cdef itemlen_t length

    _unpack_len[itemlen_t](idx_p[0], &length, buf)
    idx_p[0] += sizeof(itemlen_t)

    slice_buffer(buf, elem_buf, idx_p[0], length)
    idx_p[0] += length
    return 0
|
||||
|
||||
|
||||
cdef int _unpack_len(itemlen_t idx, itemlen_t *elemlen, Buffer *buf) except -1:
    """
    Read a length field of type `itemlen_t` located at offset `idx` in
    `buf` and store it in `elemlen[0]`.
    """
    cdef Buffer len_buf

    slice_buffer(buf, &len_buf, idx, sizeof(itemlen_t))

    # Pick the unpacker that matches the width of this specialization.
    if itemlen_t is uint16_t:
        elemlen[0] = uint16_unpack(&len_buf)
    else:
        elemlen[0] = int32_unpack(&len_buf)

    return 0
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
# Map deserialization
|
||||
|
||||
cdef class DesMapType(_DesParameterizedType):
    """Deserialize a map with one key subtype and one value subtype."""

    cdef Deserializer key_deserializer, val_deserializer

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.key_deserializer = self.deserializers[0]
        self.val_deserializer = self.deserializers[1]

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # These values only select the fused-type specialization; their
        # contents are never read.
        cdef uint16_t v2_and_below = 0
        cdef int32_t v3_and_above = 0
        key_type, val_type = self.cqltype.subtypes

        if protocol_version < 3:
            return _deserialize_map[uint16_t](
                v2_and_below, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)

        return _deserialize_map[int32_t](
            v3_and_above, buf, protocol_version,
            self.key_deserializer, self.val_deserializer,
            key_type, val_type)
|
||||
|
||||
|
||||
cdef _deserialize_map(itemlen_t dummy_version,
                      Buffer *buf, int protocol_version,
                      Deserializer key_deserializer, Deserializer val_deserializer,
                      key_type, val_type):
    """
    Deserialize a map into a util.OrderedMapSerializedKey.

    The buffer starts with the number of entries, followed by alternating
    length-prefixed keys and values. Each entry is inserted together with
    the serialized bytes of its key.
    """
    cdef Buffer key_buf, val_buf
    cdef itemlen_t count
    cdef itemlen_t offset

    _unpack_len[itemlen_t](0, &count, buf)
    offset = sizeof(itemlen_t)

    themap = util.OrderedMapSerializedKey(key_type, protocol_version)
    for _ in range(count):
        subelem(buf, &key_buf, &offset)
        subelem(buf, &val_buf, &offset)
        key = from_binary(key_deserializer, &key_buf, protocol_version)
        val = from_binary(val_deserializer, &val_buf, protocol_version)
        themap._insert_unchecked(key, to_bytes(&key_buf), val)

    return themap
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
cdef class DesTupleType(_DesParameterizedType):
    """Deserialize a sequence of length-prefixed items into a tuple."""

    # TODO: Use TupleRowParser to parse these tuples

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i
        cdef Py_ssize_t offset = 0
        cdef int32_t itemlen
        cdef tuple res = tuple_new(self.subtypes_len)
        cdef Buffer item_buf
        cdef Buffer itemlen_buf
        cdef Deserializer deserializer

        # collections inside UDTs are always encoded with at least the
        # version 3 format
        if protocol_version < 3:
            protocol_version = 3

        for i in range(self.subtypes_len):
            # Items missing from the end of the buffer, and items with a
            # negative length, are stored as None.
            item = None
            if offset < buf.size:
                slice_buffer(buf, &itemlen_buf, offset, 4)
                itemlen = int32_unpack(&itemlen_buf)
                offset += 4
                if itemlen >= 0:
                    slice_buffer(buf, &item_buf, offset, itemlen)
                    offset += itemlen

                    deserializer = self.deserializers[i]
                    item = from_binary(deserializer, &item_buf, protocol_version)

            tuple_set(res, i, item)

        return res
|
||||
|
||||
|
||||
cdef class DesUserType(DesTupleType):
    """
    Deserialize like a tuple, then build either the mapped class (with the
    values passed by field name) or the type's tuple_type.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        udt = self.cqltype
        values = DesTupleType.deserialize(self, buf, protocol_version)
        if not udt.mapped_class:
            return udt.tuple_type(*values)
        return udt.mapped_class(**dict(zip(udt.fieldnames, values)))
|
||||
|
||||
|
||||
cdef class DesCompositeType(_DesParameterizedType):
    """
    Deserialize a composite value into a tuple.

    Each element is encoded as an unsigned 16-bit length, the element bytes
    and one end-of-component byte. Trailing elements may be missing, in which
    case the returned tuple is shorter than the number of subtypes.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i, j, start
        cdef Buffer cursor
        cdef Buffer elem_buf
        # Unsigned, to match uint16_unpack: a signed 16-bit variable turns
        # lengths of 32768 and above into negative numbers.
        cdef uint16_t element_length
        cdef Deserializer deserializer
        cdef tuple res = tuple_new(self.subtypes_len)

        # Consume a private copy of the view. Shrinking the caller's Buffer
        # in place would corrupt any later use of it, e.g. the
        # to_bytes(&key_buf) call made after deserializing a map key.
        cursor.ptr = buf.ptr
        cursor.size = buf.size

        for i in range(self.subtypes_len):
            if not cursor.size:
                # CompositeType can have missing elements at the end

                # Fill the tuple with None values and slice it
                #
                # (I'm not sure a tuple needs to be fully initialized before
                # it can be destroyed, so play it safe)
                for j in range(i, self.subtypes_len):
                    tuple_set(res, j, None)
                res = res[:i]
                break

            element_length = uint16_unpack(&cursor)
            slice_buffer(&cursor, &elem_buf, 2, element_length)

            deserializer = self.deserializers[i]
            item = from_binary(deserializer, &elem_buf, protocol_version)
            tuple_set(res, i, item)

            # skip element length, element, and the EOC (one byte)
            start = 2 + element_length + 1
            slice_buffer(&cursor, &cursor, start, cursor.size - start)

        return res
|
||||
|
||||
|
||||
DesDynamicCompositeType = DesCompositeType
|
||||
|
||||
|
||||
cdef class DesReversedType(_DesSingleParamType):
    """Delegate to the deserializer of the single wrapped subtype."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Deserializer inner = self.deserializer
        return from_binary(inner, buf, protocol_version)
|
||||
|
||||
|
||||
cdef class DesFrozenType(_DesSingleParamType):
    """Delegate to the deserializer of the single wrapped subtype."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Deserializer inner = self.deserializer
        return from_binary(inner, buf, protocol_version)
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size):
    """
    Choose the value for a buffer whose size is zero or negative: EMPTY when
    the size is not negative and the cqltype supports empty values, None in
    every other case. This is used by from_binary in deserializers.pxd.
    """
    if buf_size >= 0 and deserializer.cqltype.support_empty_values:
        return cqltypes.EMPTY
    return None
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
# Generic deserialization
|
||||
|
||||
cdef class GenericDeserializer(Deserializer):
    """
    Fallback deserializer that copies the buffer into a bytes object and
    calls the `deserialize` method of the wrapped cqltype.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return self.cqltype.deserialize(raw, protocol_version)

    def __repr__(self):
        return "GenericDeserializer(%s)" % (self.cqltype,)
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
# Helper utilities
|
||||
|
||||
def make_deserializers(cqltypes):
    """
    Create an array of Deserializers, one for each cqltype in ``cqltypes``.

    Each type is resolved with find_deserializer and the results are packed
    with obj_array, in the same order as the input.
    """
    # NOTE: the parameter shadows the module-level ``cqltypes`` import inside
    # this function; only the argument is used here.
    return obj_array([find_deserializer(ct) for ct in cqltypes])
|
||||
|
||||
|
||||
cdef dict classes = globals()
|
||||
|
||||
cpdef Deserializer find_deserializer(cqltype):
    """
    Return a Deserializer instance for ``cqltype``.

    A module-level class named ``'Des' + cqltype.__name__`` is preferred.
    Failing that, the type is matched by subclass against the parameterized
    types below, and finally GenericDeserializer is used.
    """
    name = 'Des' + cqltype.__name__

    if name in globals():
        return classes[name](cqltype)

    if issubclass(cqltype, cqltypes.ListType):
        return DesListType(cqltype)
    if issubclass(cqltype, cqltypes.SetType):
        return DesSetType(cqltype)
    if issubclass(cqltype, cqltypes.MapType):
        return DesMapType(cqltype)
    # UserType is a subclass of TupleType, so should precede it
    if issubclass(cqltype, cqltypes.UserType):
        return DesUserType(cqltype)
    if issubclass(cqltype, cqltypes.TupleType):
        return DesTupleType(cqltype)
    # DynamicCompositeType is a subclass of CompositeType, so should precede it
    if issubclass(cqltype, cqltypes.DynamicCompositeType):
        return DesDynamicCompositeType(cqltype)
    if issubclass(cqltype, cqltypes.CompositeType):
        return DesCompositeType(cqltype)
    if issubclass(cqltype, cqltypes.ReversedType):
        return DesReversedType(cqltype)
    if issubclass(cqltype, cqltypes.FrozenType):
        return DesFrozenType(cqltype)

    return GenericDeserializer(cqltype)
|
||||
|
||||
|
||||
def obj_array(list objs):
    """Build a (Cython) object array holding the items of ``objs`` in order"""
    cdef Py_ssize_t idx
    cdef Py_ssize_t count = len(objs)
    cdef object[:] out = cython_array(shape=(count,), itemsize=sizeof(void *), format="O")
    # Slice assignment (out[:] = objs) does not work here (segmentation
    # faults), so the items are copied one at a time.
    for idx in range(count):
        out[idx] = objs[idx]
    return out
|
||||
47
cassandra/ioutils.pyx
Normal file
47
cassandra/ioutils.pyx
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
include 'cython_marshal.pyx'
|
||||
from cassandra.buffer cimport Buffer, from_ptr_and_size
|
||||
|
||||
from libc.stdint cimport int32_t
|
||||
from cassandra.bytesio cimport BytesIOReader
|
||||
|
||||
|
||||
cdef inline int get_buf(BytesIOReader reader, Buffer *buf_out) except -1:
    """
    Point ``buf_out`` at the next value in the stream served by ``reader``.

    The value's 4-byte length prefix is read first; the payload is then
    referenced in place rather than copied.

    BEWARE:
    If the next item has a zero or negative size, the pointer is set to NULL.
    A negative size happens when the value is NULL in the database, whereas a
    zero size may happen either for legacy reasons, or for data types such as
    strings (which may be empty).
    """
    cdef Py_ssize_t raw_val_size = read_int(reader)
    cdef char *ptr = NULL

    # Only positive sizes have a payload to read.
    if raw_val_size > 0:
        ptr = reader.read(raw_val_size)

    from_ptr_and_size(ptr, raw_val_size, buf_out)
    return 0
|
||||
|
||||
cdef inline int32_t read_int(BytesIOReader reader) except ?0xDEAD:
    """Read the next 4 bytes from ``reader`` and decode them with int32_unpack"""
    cdef Buffer four_bytes
    four_bytes.size = 4
    four_bytes.ptr = reader.read(4)
    return int32_unpack(&four_bytes)
|
||||
1
cassandra/numpyFlags.h
Normal file
1
cassandra/numpyFlags.h
Normal file
@@ -0,0 +1 @@
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
185
cassandra/numpy_parser.pyx
Normal file
185
cassandra/numpy_parser.pyx
Normal file
@@ -0,0 +1,185 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This module provides an optional protocol parser that returns
|
||||
NumPy arrays.
|
||||
|
||||
=============================================================================
|
||||
This module should not be imported by any of the main python-driver modules,
|
||||
as numpy is an optional dependency.
|
||||
=============================================================================
|
||||
"""
|
||||
|
||||
include "ioutils.pyx"
|
||||
|
||||
cimport cython
|
||||
from libc.stdint cimport uint64_t
|
||||
from cpython.ref cimport Py_INCREF, PyObject
|
||||
|
||||
from cassandra.bytesio cimport BytesIOReader
|
||||
from cassandra.deserializers cimport Deserializer, from_binary
|
||||
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
|
||||
from cassandra import cqltypes
|
||||
from cassandra.util import is_little_endian
|
||||
|
||||
import numpy as np
|
||||
# import pandas as pd
|
||||
|
||||
cdef extern from "numpyFlags.h":
|
||||
# Include 'numpyFlags.h' into the generated C code to disable the
|
||||
# deprecated NumPy API
|
||||
pass
|
||||
|
||||
cdef extern from "Python.h":
|
||||
# An integer type large enough to hold a pointer
|
||||
ctypedef uint64_t Py_uintptr_t
|
||||
|
||||
|
||||
# Simple array descriptor, useful to parse rows into a NumPy array
|
||||
ctypedef struct ArrDesc:
|
||||
Py_uintptr_t buf_ptr
|
||||
int stride # should be large enough as we allocate contiguous arrays
|
||||
int is_object
|
||||
|
||||
arrDescDtype = np.dtype(
|
||||
[ ('buf_ptr', np.uintp)
|
||||
, ('stride', np.dtype('i'))
|
||||
, ('is_object', np.dtype('i'))
|
||||
])
|
||||
|
||||
_cqltype_to_numpy = {
|
||||
cqltypes.LongType: np.dtype('>i8'),
|
||||
cqltypes.CounterColumnType: np.dtype('>i8'),
|
||||
cqltypes.Int32Type: np.dtype('>i4'),
|
||||
cqltypes.ShortType: np.dtype('>i2'),
|
||||
cqltypes.FloatType: np.dtype('>f4'),
|
||||
cqltypes.DoubleType: np.dtype('>f8'),
|
||||
}
|
||||
|
||||
obj_dtype = np.dtype('O')
|
||||
|
||||
|
||||
cdef class NumpyParser(ColumnParser):
    """Decode a ResultMessage into a bunch of NumPy arrays"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        """
        Read the row count and then every row from ``reader``.

        Returns a dict mapping each name in ``desc.colnames`` to the array
        holding that column's values, after make_native_byteorder has been
        applied to each array.
        """
        cdef Py_ssize_t rowcount = read_int(reader)
        cdef ArrDesc[::1] array_descs
        cdef ArrDesc *arrs

        array_descs, arrays = make_arrays(desc, rowcount)
        arrs = &array_descs[0]

        # Fill the preallocated arrays in place, one row at a time.
        _parse_rows(reader, desc, arrs, rowcount)

        native_arrays = [make_native_byteorder(arr) for arr in arrays]
        return dict(zip(desc.colnames, native_arrays))
|
||||
|
||||
|
||||
cdef _parse_rows(BytesIOReader reader, ParseDesc desc,
                 ArrDesc *arrs, Py_ssize_t rowcount):
    """Call unpack_row ``rowcount`` times, filling the arrays described by ``arrs``"""
    cdef Py_ssize_t remaining = rowcount

    while remaining > 0:
        unpack_row(reader, desc, arrs)
        remaining -= 1
|
||||
|
||||
|
||||
### Helper functions to create NumPy arrays and array descriptors
|
||||
|
||||
def make_arrays(ParseDesc desc, array_size):
    """
    Allocate one array of ``array_size`` elements for each result column.

    Returns a tuple ``(array_descs, arrays)`` where

    - ``array_descs`` is a NumPy array of dtype ``arrDescDtype`` with one
      record per column (data pointer, stride of the first axis, and an
      ``is_object`` flag); it is what the row unpacking code writes through.
    - ``arrays`` is a list of the allocated arrays in column order.
      It is not a dict: the caller pairs it with the column names.
    """
    array_descs = np.empty((desc.rowsize,), arrDescDtype)
    arrays = []

    for i, coltype in enumerate(desc.coltypes):
        arr = make_array(coltype, array_size)
        array_descs[i]['buf_ptr'] = arr.ctypes.data
        array_descs[i]['stride'] = arr.strides[0]
        # Any type without a fixed-width NumPy dtype is stored as Python objects.
        array_descs[i]['is_object'] = coltype not in _cqltype_to_numpy
        arrays.append(arr)

    return array_descs, arrays
|
||||
|
||||
|
||||
def make_array(coltype, array_size):
    """
    Return an uninitialised one-dimensional array with ``array_size`` elements.

    The dtype comes from ``_cqltype_to_numpy``; column types that are not
    listed there get the object dtype.
    """
    if coltype in _cqltype_to_numpy:
        dtype = _cqltype_to_numpy[coltype]
    else:
        dtype = obj_dtype
    return np.empty((array_size,), dtype=dtype)
|
||||
|
||||
|
||||
#### Parse rows into NumPy arrays
|
||||
|
||||
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline int unpack_row(
        BytesIOReader reader, ParseDesc desc, ArrDesc *arrays) except -1:
    """
    Read one row from ``reader`` and store each value in its column array.

    Object columns are deserialized with from_binary and stored as Python
    objects; this includes NULL and empty values, which from_binary is given
    in the same way as by the tuple row parser. All other columns are copied
    byte-for-byte into the array.

    Raises ValueError when a non-object column holds a NULL (negative size)
    or empty (zero size) value, since there are no bytes to copy into the
    array slot. Previously only a size of exactly zero was rejected, so a
    NULL left the slot uninitialised.

    After each value, the column's ``buf_ptr`` is advanced by its stride so
    that the next row lands in the next slot.
    """
    cdef Buffer buf
    cdef Py_ssize_t i, rowsize = desc.rowsize
    cdef ArrDesc arr
    cdef Deserializer deserializer

    for i in range(rowsize):
        get_buf(reader, &buf)
        arr = arrays[i]

        if arr.is_object:
            deserializer = desc.deserializers[i]
            val = from_binary(deserializer, &buf, desc.protocol_version)
            # The array slot holds its own reference to the object.
            Py_INCREF(val)
            (<PyObject **> arr.buf_ptr)[0] = <PyObject *> val
        elif buf.size <= 0:
            # get_buf reports NULL as a negative size and empty as zero;
            # neither can fill a fixed-width numeric slot.
            raise ValueError("Cannot handle NULL value")
        else:
            memcopy(buf.ptr, <char *> arr.buf_ptr, buf.size)

        # Update the pointer into the array for the next time
        arrays[i].buf_ptr += arr.stride

    return 0
|
||||
|
||||
|
||||
cdef inline void memcopy(char *src, char *dst, Py_ssize_t size):
    """
    Copy ``size`` bytes from ``src`` to ``dst``, one byte at a time.

    Kept as a tiny hand-written loop so that it can be inlined; the values
    copied are only a few bytes each. A size of zero or less copies nothing.
    """
    cdef Py_ssize_t offset = 0
    while offset < size:
        dst[offset] = src[offset]
        offset += 1
|
||||
|
||||
|
||||
def make_native_byteorder(arr):
    """
    Make sure all values have a native endian in the NumPy arrays.

    On little-endian machines a non-object array is returned as a
    byte-swapped copy whose dtype has the opposite byte order (e.g. '>i8'
    becomes '<i8'), so the values are unchanged. Object arrays, and every
    array on big-endian machines, are returned as they are.
    """
    if is_little_endian and not arr.dtype.kind == 'O':
        # We have arrays in big-endian order. First swap the bytes
        # into little endian order, and then update the numpy dtype
        # accordingly (e.g. from '>i8' to '<i8')
        #
        # ndarray.newbyteorder() was removed in NumPy 2.0; viewing the data
        # through the swapped dtype is equivalent and works on all versions.
        swapped = arr.byteswap()
        return swapped.view(arr.dtype.newbyteorder())
    return arr
|
||||
75
cassandra/obj_parser.pyx
Normal file
75
cassandra/obj_parser.pyx
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
include "ioutils.pyx"
|
||||
|
||||
from cassandra.bytesio cimport BytesIOReader
|
||||
from cassandra.deserializers cimport Deserializer, from_binary
|
||||
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
|
||||
from cassandra.tuple cimport tuple_new, tuple_set
|
||||
|
||||
|
||||
cdef class ListParser(ColumnParser):
    """Decode a ResultMessage into a list of tuples (or other objects)"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        """Read the row count, then eagerly unpack that many rows into a list"""
        cdef Py_ssize_t i, rowcount
        cdef RowParser rowparser
        cdef list rows = []

        rowcount = read_int(reader)
        rowparser = TupleRowParser()
        for i in range(rowcount):
            rows.append(rowparser.unpack_row(reader, desc))
        return rows
|
||||
|
||||
|
||||
cdef class LazyParser(ColumnParser):
    """Decode a ResultMessage lazily using a generator"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        """Return the lazy row iterable produced by parse_rows_lazy"""
        # Closures (generators) are not supported in cpdef methods, so the
        # generator is created by a plain helper function instead.
        lazy_rows = parse_rows_lazy(reader, desc)
        return lazy_rows
|
||||
|
||||
|
||||
def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
    """
    Return a generator yielding one unpacked row per iteration.

    The row count is read from ``reader`` immediately; the rows themselves
    are only read as the generator is consumed.
    """
    cdef Py_ssize_t i, rowcount
    rowcount = read_int(reader)
    cdef RowParser row_parser = TupleRowParser()
    lazy_rows = (row_parser.unpack_row(reader, desc) for i in range(rowcount))
    return lazy_rows
|
||||
|
||||
|
||||
cdef class TupleRowParser(RowParser):
    """
    Parse a single returned row into a tuple of objects:

        (obj1, ..., objN)
    """

    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
        """Read ``desc.rowsize`` values from ``reader`` and return them as a tuple"""
        assert desc.rowsize >= 0

        cdef Buffer buf
        cdef Py_ssize_t col, ncols = desc.rowsize
        cdef Deserializer deserializer
        cdef tuple row = tuple_new(desc.rowsize)

        for col in range(ncols):
            # Locate the bytes of the next value
            get_buf(reader, &buf)

            # Decode them with the column's deserializer and store the
            # resulting object in the (still unfilled) tuple slot
            deserializer = desc.deserializers[col]
            tuple_set(row, col,
                      from_binary(deserializer, &buf, desc.protocol_version))

        return row
|
||||
30
cassandra/parsing.pxd
Normal file
30
cassandra/parsing.pxd
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from cassandra.bytesio cimport BytesIOReader
|
||||
from cassandra.deserializers cimport Deserializer
|
||||
|
||||
cdef class ParseDesc:
|
||||
cdef public object colnames
|
||||
cdef public object coltypes
|
||||
cdef Deserializer[::1] deserializers
|
||||
cdef public int protocol_version
|
||||
cdef Py_ssize_t rowsize
|
||||
|
||||
cdef class ColumnParser:
|
||||
cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc)
|
||||
|
||||
cdef class RowParser:
|
||||
cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc)
|
||||
|
||||
44
cassandra/parsing.pyx
Normal file
44
cassandra/parsing.pyx
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Module containing the definitions and declarations (parsing.pxd) for parsers.
|
||||
"""
|
||||
|
||||
cdef class ParseDesc:
    """
    Description of what structure to parse: the column names and types, the
    matching deserializers, and the protocol version in use.
    """

    def __init__(self, colnames, coltypes, deserializers, protocol_version):
        self.colnames, self.coltypes = colnames, coltypes
        self.deserializers = deserializers
        self.protocol_version = protocol_version
        # One value per column name in every row.
        self.rowsize = len(colnames)
|
||||
|
||||
|
||||
cdef class ColumnParser:
    """Decode a ResultMessage into a set of columns"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        """Parse all rows from ``reader``; subclasses must override this"""
        raise NotImplementedError()
|
||||
|
||||
|
||||
cdef class RowParser:
    """Parser for a single row"""

    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
        """
        Unpack a single row of data in a ResultMessage.

        Subclasses must override this.
        """
        raise NotImplementedError()
|
||||
@@ -22,6 +22,7 @@ import six
|
||||
from six.moves import range
|
||||
import io
|
||||
|
||||
from cassandra import type_codes
|
||||
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
|
||||
WriteFailure, ReadFailure, FunctionFailure,
|
||||
AlreadyExists, InvalidRequest, Unauthorized,
|
||||
@@ -35,10 +36,11 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
|
||||
DoubleType, FloatType, Int32Type,
|
||||
InetAddressType, IntegerType, ListType,
|
||||
LongType, MapType, SetType, TimeUUIDType,
|
||||
UTF8Type, UUIDType, UserType,
|
||||
UTF8Type, VarcharType, UUIDType, UserType,
|
||||
TupleType, lookup_casstype, SimpleDateType,
|
||||
TimeType, ByteType, ShortType)
|
||||
from cassandra.policies import WriteType
|
||||
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
|
||||
from cassandra import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -68,10 +70,16 @@ _message_types_by_opcode = {}
|
||||
|
||||
_UNSET_VALUE = object()
|
||||
|
||||
def register_class(cls):
|
||||
_message_types_by_opcode[cls.opcode] = cls
|
||||
|
||||
def get_registered_classes():
|
||||
return _message_types_by_opcode.copy()
|
||||
|
||||
class _RegisterMessageType(type):
|
||||
def __init__(cls, name, bases, dct):
|
||||
if not name.startswith('_'):
|
||||
_message_types_by_opcode[cls.opcode] = cls
|
||||
register_class(cls)
|
||||
|
||||
|
||||
@six.add_metaclass(_RegisterMessageType)
|
||||
@@ -531,35 +539,6 @@ RESULT_KIND_SET_KEYSPACE = 0x0003
|
||||
RESULT_KIND_PREPARED = 0x0004
|
||||
RESULT_KIND_SCHEMA_CHANGE = 0x0005
|
||||
|
||||
class CassandraTypeCodes(object):
|
||||
CUSTOM_TYPE = 0x0000
|
||||
AsciiType = 0x0001
|
||||
LongType = 0x0002
|
||||
BytesType = 0x0003
|
||||
BooleanType = 0x0004
|
||||
CounterColumnType = 0x0005
|
||||
DecimalType = 0x0006
|
||||
DoubleType = 0x0007
|
||||
FloatType = 0x0008
|
||||
Int32Type = 0x0009
|
||||
UTF8Type = 0x000A
|
||||
DateType = 0x000B
|
||||
UUIDType = 0x000C
|
||||
UTF8Type = 0x000D
|
||||
IntegerType = 0x000E
|
||||
TimeUUIDType = 0x000F
|
||||
InetAddressType = 0x0010
|
||||
SimpleDateType = 0x0011
|
||||
TimeType = 0x0012
|
||||
ShortType = 0x0013
|
||||
ByteType = 0x0014
|
||||
ListType = 0x0020
|
||||
MapType = 0x0021
|
||||
SetType = 0x0022
|
||||
UserType = 0x0030
|
||||
TupleType = 0x0031
|
||||
|
||||
|
||||
class ResultMessage(_MessageType):
|
||||
opcode = 0x08
|
||||
name = 'RESULT'
|
||||
@@ -569,7 +548,7 @@ class ResultMessage(_MessageType):
|
||||
paging_state = None
|
||||
|
||||
# Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
|
||||
type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in CassandraTypeCodes.__dict__.items() if not k.startswith('_'))
|
||||
type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))
|
||||
|
||||
_FLAGS_GLOBAL_TABLES_SPEC = 0x0001
|
||||
_HAS_MORE_PAGES_FLAG = 0x0002
|
||||
@@ -1017,6 +996,63 @@ class ProtocolHandler(object):
|
||||
return msg
|
||||
|
||||
|
||||
def cython_protocol_handler(colparser):
    """
    Given a column parser to deserialize ResultMessages, return a suitable
    Cython-based protocol handler.

    There are three Cython-based protocol handlers (least to most performant):

        1. obj_parser.ListParser
            this parser decodes result messages into a list of tuples

        2. obj_parser.LazyParser
            this parser decodes result messages lazily by returning an iterator

        3. numpy_parser.NumpyParser
            this parser decodes result messages into NumPy arrays

    The default is to use obj_parser.ListParser
    """
    # Imported here rather than at module level, so that merely importing
    # this module does not require the compiled row_parser extension.
    from cassandra.row_parser import make_recv_results_rows

    class FastResultMessage(ResultMessage):
        """
        Cython version of Result Message that has a faster implementation of
        recv_results_rows.
        """
        code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items())
        recv_results_rows = classmethod(make_recv_results_rows(colparser))

    class CythonProtocolHandler(ProtocolHandler):
        """
        Use FastResultMessage to decode query result messages.
        """

        # Copy the opcode table so the base handler's mapping is left untouched.
        my_opcodes = ProtocolHandler.message_types_by_opcode.copy()
        my_opcodes[FastResultMessage.opcode] = FastResultMessage
        message_types_by_opcode = my_opcodes

    return CythonProtocolHandler
|
||||
|
||||
|
||||
if HAVE_CYTHON:
|
||||
from cassandra.obj_parser import ListParser, LazyParser
|
||||
ProtocolHandler = cython_protocol_handler(ListParser())
|
||||
LazyProtocolHandler = cython_protocol_handler(LazyParser())
|
||||
else:
|
||||
# Use Python-based ProtocolHandler
|
||||
LazyProtocolHandler = None
|
||||
|
||||
|
||||
if HAVE_CYTHON and HAVE_NUMPY:
|
||||
from cassandra.numpy_parser import NumpyParser
|
||||
NumpyProtocolHandler = cython_protocol_handler(NumpyParser())
|
||||
else:
|
||||
NumpyProtocolHandler = None
|
||||
|
||||
|
||||
def read_byte(f):
|
||||
return int8_unpack(f.read(1))
|
||||
|
||||
|
||||
38
cassandra/row_parser.pyx
Normal file
38
cassandra/row_parser.pyx
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from cassandra.parsing cimport ParseDesc, ColumnParser
|
||||
from cassandra.deserializers import make_deserializers
|
||||
|
||||
include "ioutils.pyx"
|
||||
|
||||
def make_recv_results_rows(ColumnParser colparser):
    """
    Build a ``recv_results_rows`` function that parses rows with ``colparser``.
    """
    def recv_results_rows(cls, f, int protocol_version, user_type_map):
        """
        Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples)
        This is used as the recv_results_rows method of (Fast)ResultMessage
        """
        paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)

        # Each metadata entry carries the column name at index 2 and the
        # column type at index 3.
        colnames = []
        coltypes = []
        for column in column_metadata:
            colnames.append(column[2])
            coltypes.append(column[3])

        deserializers = make_deserializers(coltypes)
        desc = ParseDesc(colnames, coltypes, deserializers, protocol_version)

        # Everything left in f is handed to the parser.
        reader = BytesIOReader(f.read())
        parsed_rows = colparser.parse_rows(reader, desc)

        return (paging_state, (colnames, parsed_rows))

    return recv_results_rows
|
||||
41
cassandra/tuple.pxd
Normal file
41
cassandra/tuple.pxd
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from cpython.tuple cimport (
|
||||
PyTuple_New,
|
||||
# Return value: New reference.
|
||||
# Return a new tuple object of size len, or NULL on failure.
|
||||
PyTuple_SET_ITEM,
|
||||
# Like PyTuple_SetItem(), but does no error checking, and should
|
||||
# only be used to fill in brand new tuples. Note: This function
|
||||
# ``steals'' a reference to o.
|
||||
)
|
||||
|
||||
from cpython.ref cimport (
|
||||
Py_INCREF
|
||||
# void Py_INCREF(object o)
|
||||
# Increment the reference count for object o. The object must not
|
||||
# be NULL; if you aren't sure that it isn't NULL, use
|
||||
# Py_XINCREF().
|
||||
)
|
||||
|
||||
cdef inline tuple tuple_new(Py_ssize_t n):
|
||||
"""Allocate a new tuple object"""
|
||||
return PyTuple_New(n)
|
||||
|
||||
cdef inline void tuple_set(tuple tup, Py_ssize_t idx, object item):
|
||||
"""Insert new object into tuple. No item must have been set yet."""
|
||||
# PyTuple_SET_ITEM steals a reference, so we need to INCREF
|
||||
Py_INCREF(item)
|
||||
PyTuple_SET_ITEM(tup, idx, item)
|
||||
42
cassandra/type_codes.pxd
Normal file
42
cassandra/type_codes.pxd
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Type codes, with the same values as the constants in cassandra/type_codes.py.
# The values are spelled out because an enum without explicit values numbers
# its members consecutively from zero, which would give the collection,
# user and tuple types the wrong codes (21..25 instead of 0x20..0x31).
cdef enum:
    CUSTOM_TYPE = 0x0000
    AsciiType = 0x0001
    LongType = 0x0002
    BytesType = 0x0003
    BooleanType = 0x0004
    CounterColumnType = 0x0005
    DecimalType = 0x0006
    DoubleType = 0x0007
    FloatType = 0x0008
    Int32Type = 0x0009
    UTF8Type = 0x000A
    DateType = 0x000B
    UUIDType = 0x000C
    VarcharType = 0x000D
    IntegerType = 0x000E
    TimeUUIDType = 0x000F
    InetAddressType = 0x0010
    SimpleDateType = 0x0011
    TimeType = 0x0012
    ShortType = 0x0013
    ByteType = 0x0014
    ListType = 0x0020
    MapType = 0x0021
    SetType = 0x0022
    UserType = 0x0030
    TupleType = 0x0031
|
||||
|
||||
62
cassandra/type_codes.py
Normal file
62
cassandra/type_codes.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Module with constants for Cassandra type codes.
|
||||
|
||||
These constants are useful for
|
||||
|
||||
a) mapping messages to cqltypes (cassandra/cqltypes.py)
|
||||
b) optimized dispatching for (de)serialization (cassandra/encoding.py)
|
||||
|
||||
Type codes are repeated here from the Cassandra binary protocol specification:
|
||||
|
||||
0x0000 Custom: the value is a [string], see above.
|
||||
0x0001 Ascii
|
||||
0x0002 Bigint
|
||||
0x0003 Blob
|
||||
0x0004 Boolean
|
||||
0x0005 Counter
|
||||
0x0006 Decimal
|
||||
0x0007 Double
|
||||
0x0008 Float
|
||||
0x0009 Int
|
||||
0x000A Text
|
||||
0x000B Timestamp
|
||||
0x000C Uuid
|
||||
0x000D Varchar
|
||||
0x000E Varint
|
||||
0x000F Timeuuid
|
||||
0x0010 Inet
|
||||
0x0020 List: the value is an [option], representing the type
|
||||
of the elements of the list.
|
||||
0x0021 Map: the value is two [option], representing the types of the
|
||||
keys and values of the map
|
||||
0x0022 Set: the value is an [option], representing the type
|
||||
of the elements of the set
|
||||
"""
|
||||
|
||||
CUSTOM_TYPE = 0x0000
|
||||
AsciiType = 0x0001
|
||||
LongType = 0x0002
|
||||
BytesType = 0x0003
|
||||
BooleanType = 0x0004
|
||||
CounterColumnType = 0x0005
|
||||
DecimalType = 0x0006
|
||||
DoubleType = 0x0007
|
||||
FloatType = 0x0008
|
||||
Int32Type = 0x0009
|
||||
UTF8Type = 0x000A
|
||||
DateType = 0x000B
|
||||
UUIDType = 0x000C
|
||||
VarcharType = 0x000D
|
||||
IntegerType = 0x000E
|
||||
TimeUUIDType = 0x000F
|
||||
InetAddressType = 0x0010
|
||||
SimpleDateType = 0x0011
|
||||
TimeType = 0x0012
|
||||
ShortType = 0x0013
|
||||
ByteType = 0x0014
|
||||
ListType = 0x0020
|
||||
MapType = 0x0021
|
||||
SetType = 0x0022
|
||||
UserType = 0x0030
|
||||
TupleType = 0x0031
|
||||
|
||||
@@ -1,12 +1,29 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import with_statement
|
||||
import calendar
|
||||
import datetime
|
||||
import random
|
||||
import six
|
||||
import uuid
|
||||
import sys
|
||||
|
||||
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
|
||||
|
||||
assert sys.byteorder in ('little', 'big')
|
||||
is_little_endian = sys.byteorder == 'little'
|
||||
|
||||
def datetime_from_timestamp(timestamp):
|
||||
"""
|
||||
@@ -490,8 +507,7 @@ except ImportError:
|
||||
|
||||
def __init__(self, iterable=()):
|
||||
self._items = []
|
||||
for i in iterable:
|
||||
self.add(i)
|
||||
self.update(iterable)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._items)
|
||||
@@ -564,6 +580,10 @@ except ImportError:
|
||||
else:
|
||||
self._items.append(item)
|
||||
|
||||
def update(self, iterable):
|
||||
for i in iterable:
|
||||
self.add(i)
|
||||
|
||||
def clear(self):
|
||||
del self._items[:]
|
||||
|
||||
|
||||
@@ -24,3 +24,17 @@ See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseF
|
||||
.. automethod:: encode_message
|
||||
|
||||
.. automethod:: decode_message
|
||||
|
||||
Faster Deserialization
|
||||
----------------------
|
||||
When python-driver is compiled with Cython, it uses a Cython-based deserialization path
|
||||
to deserialize messages. There are two additional ProtocolHandler classes that can be
|
||||
used to deserialize response messages: the first is ``LazyProtocolHandler`` and the
|
||||
second is ``NumpyProtocolHandler``. They can be used as follows:
|
||||
|
||||
.. code:: python
|
||||
|
||||
from cassandra.protocol import NumpyProtocolHandler, LazyProtocolHandler
|
||||
s.client_protocol_handler = LazyProtocolHandler # for a result iterator
|
||||
s.client_protocol_handler = NumpyProtocolHandler # for a dict of NumPy arrays as result
|
||||
|
||||
|
||||
11
setup.py
11
setup.py
@@ -37,7 +37,6 @@ from distutils.errors import (CCompilerError, DistutilsPlatformError,
|
||||
DistutilsExecError)
|
||||
from distutils.cmd import Command
|
||||
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
has_subprocess = True
|
||||
@@ -71,6 +70,7 @@ if __name__ == '__main__' and sys.argv[1] == "install":
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
PROFILING = False
|
||||
|
||||
class DocCommand(Command):
|
||||
|
||||
@@ -263,11 +263,16 @@ if "--no-libev" not in sys.argv and not is_windows:
|
||||
if "--no-cython" not in sys.argv:
|
||||
try:
|
||||
from Cython.Build import cythonize
|
||||
cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'marshal', 'metadata', 'pool', 'protocol', 'query', 'util']
|
||||
cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata',
|
||||
'pool', 'protocol', 'query', 'util']
|
||||
compile_args = [] if is_windows else ['-Wno-unused-function']
|
||||
extensions.extend(cythonize(
|
||||
[Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], extra_compile_args=compile_args) for m in cython_candidates],
|
||||
[Extension('cassandra.%s' % m, ['cassandra/%s.py' % m],
|
||||
extra_compile_args=compile_args)
|
||||
for m in cython_candidates],
|
||||
exclude_failures=True))
|
||||
extensions.extend(cythonize("cassandra/*.pyx"))
|
||||
extensions.extend(cythonize("tests/unit/cython/*.pyx"))
|
||||
except ImportError:
|
||||
sys.stderr.write("Cython is not installed. Not compiling core driver files as extensions (optional).")
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ def remove_cluster():
|
||||
CCM_CLUSTER.remove()
|
||||
CCM_CLUSTER = None
|
||||
return
|
||||
except WindowsError:
|
||||
except OSError:
|
||||
ex_type, ex, tb = sys.exc_info()
|
||||
log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb)))
|
||||
del tb
|
||||
|
||||
@@ -386,7 +386,8 @@ class TestMapColumn(BaseCassEngTestCase):
|
||||
k2 = uuid4()
|
||||
now = datetime.now()
|
||||
then = now + timedelta(days=1)
|
||||
m1 = TestMapModel.create(int_map={1: k1, 2: k2}, text_map={'now': now, 'then': then})
|
||||
m1 = TestMapModel.create(int_map={1: k1, 2: k2},
|
||||
text_map={'now': now, 'then': then})
|
||||
m2 = TestMapModel.get(partition=m1.partition)
|
||||
|
||||
self.assertTrue(isinstance(m2.int_map, dict))
|
||||
|
||||
@@ -629,7 +629,7 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
|
||||
# test kwarg filtering
|
||||
q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
|
||||
q = [d for d in q]
|
||||
assert len(q) == 2
|
||||
self.assertEqual(len(q), 2, msg="Got: %s" % q)
|
||||
datas = [d.data for d in q]
|
||||
assert '1' in datas
|
||||
assert '2' in datas
|
||||
|
||||
@@ -52,17 +52,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
self.assertEqual(row.cluster, i)
|
||||
self.assertEqual(row.count, i)
|
||||
self.assertEqual(row.text, str(i))
|
||||
|
||||
# perform update
|
||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)
|
||||
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == (6 if i == 3 else i)
|
||||
assert row.text == str(i)
|
||||
self.assertEqual(row.cluster, i)
|
||||
self.assertEqual(row.count, 6 if i == 3 else i)
|
||||
self.assertEqual(row.text, str(i))
|
||||
|
||||
def test_update_values_validation(self):
|
||||
""" tests calling update on models with values passed in """
|
||||
@@ -72,9 +72,9 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
self.assertEqual(row.cluster, i)
|
||||
self.assertEqual(row.count, i)
|
||||
self.assertEqual(row.text, str(i))
|
||||
|
||||
# perform update
|
||||
with self.assertRaises(ValidationError):
|
||||
@@ -98,17 +98,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
self.assertEqual(row.cluster, i)
|
||||
self.assertEqual(row.count, i)
|
||||
self.assertEqual(row.text, str(i))
|
||||
|
||||
# perform update
|
||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None)
|
||||
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == (None if i == 3 else str(i))
|
||||
self.assertEqual(row.cluster, i)
|
||||
self.assertEqual(row.count, i)
|
||||
self.assertEqual(row.text, None if i == 3 else str(i))
|
||||
|
||||
def test_mixed_value_and_null_update(self):
|
||||
""" tests that updating a columns value, and removing another works properly """
|
||||
@@ -118,17 +118,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
self.assertEqual(row.cluster, i)
|
||||
self.assertEqual(row.count, i)
|
||||
self.assertEqual(row.text, str(i))
|
||||
|
||||
# perform update
|
||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None)
|
||||
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == (6 if i == 3 else i)
|
||||
assert row.text == (None if i == 3 else str(i))
|
||||
self.assertEqual(row.cluster, i)
|
||||
self.assertEqual(row.count, 6 if i == 3 else i)
|
||||
self.assertEqual(row.text, None if i == 3 else str(i))
|
||||
|
||||
def test_counter_updates(self):
|
||||
pass
|
||||
|
||||
@@ -85,11 +85,11 @@ class SchemaTests(unittest.TestCase):
|
||||
|
||||
session = self.session
|
||||
|
||||
for i in xrange(30):
|
||||
for i in range(30):
|
||||
execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i))
|
||||
execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i))
|
||||
|
||||
for j in xrange(100):
|
||||
for j in range(100):
|
||||
execute_until_pass(session, "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format(i, j))
|
||||
|
||||
execute_until_pass(session, "DROP KEYSPACE test_{0}".format(i))
|
||||
@@ -102,7 +102,7 @@ class SchemaTests(unittest.TestCase):
|
||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||
session = cluster.connect()
|
||||
|
||||
for i in xrange(30):
|
||||
for i in range(30):
|
||||
try:
|
||||
execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}")
|
||||
except AlreadyExists:
|
||||
@@ -111,7 +111,7 @@ class SchemaTests(unittest.TestCase):
|
||||
|
||||
execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)")
|
||||
|
||||
for j in xrange(100):
|
||||
for j in range(100):
|
||||
execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j))
|
||||
|
||||
execute_until_pass(session, "DROP KEYSPACE test")
|
||||
|
||||
@@ -11,6 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
try:
|
||||
from ccmlib import common
|
||||
except ImportError as e:
|
||||
|
||||
@@ -25,6 +25,8 @@ from cassandra.query import tuple_factory, SimpleStatement
|
||||
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||
|
||||
from six import next
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
|
||||
@@ -21,7 +21,8 @@ from cassandra.protocol import ProtocolHandler, ResultMessage, UUIDType, read_in
|
||||
from cassandra.query import tuple_factory
|
||||
from cassandra.cluster import Cluster
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass
|
||||
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, get_sample
|
||||
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES
|
||||
from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params
|
||||
from six import binary_type
|
||||
|
||||
import uuid
|
||||
@@ -62,7 +63,7 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
||||
"""
|
||||
|
||||
# Ensure that we get normal uuid back first
|
||||
session = Cluster().connect()
|
||||
session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes")
|
||||
session.row_factory = tuple_factory
|
||||
result_set = session.execute("SELECT schema_version FROM system.local")
|
||||
result = result_set.pop()
|
||||
@@ -102,15 +103,16 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
||||
@test_category data_types:serialization
|
||||
"""
|
||||
# Connect using a custom protocol handler that tracks the various types the result message is used with.
|
||||
session = Cluster().connect(keyspace="custserdes")
|
||||
session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes")
|
||||
session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked
|
||||
session.row_factory = tuple_factory
|
||||
|
||||
columns_string = create_table_with_all_types("test_table", session)
|
||||
colnames = create_table_with_all_types("alltypes", session, 1)
|
||||
columns_string = ", ".join(colnames)
|
||||
|
||||
# verify data
|
||||
params = get_all_primitive_params()
|
||||
results = session.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0]
|
||||
params = get_all_primitive_params(0)
|
||||
results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0]
|
||||
for expected, actual in zip(params, results):
|
||||
self.assertEqual(actual, expected)
|
||||
# Ensure we have covered the various primitive types
|
||||
@@ -118,43 +120,6 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
||||
session.shutdown()
|
||||
|
||||
|
||||
def create_table_with_all_types(table_name, session):
|
||||
"""
|
||||
Method that given a table_name and session construct a table that contains all possible primitive types
|
||||
:param table_name: Name of table to create
|
||||
:param session: session to use for table creation
|
||||
:return: a string containing and columns. This can be used to query the table.
|
||||
"""
|
||||
# create table
|
||||
alpha_type_list = ["zz int PRIMARY KEY"]
|
||||
col_names = ["zz"]
|
||||
start_index = ord('a')
|
||||
for i, datatype in enumerate(PRIMITIVE_DATATYPES):
|
||||
alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
|
||||
col_names.append(chr(start_index + i))
|
||||
|
||||
session.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)), timeout=120)
|
||||
|
||||
# create the input
|
||||
params = get_all_primitive_params()
|
||||
|
||||
# insert into table as a simple statement
|
||||
columns_string = ', '.join(col_names)
|
||||
placeholders = ', '.join(["%s"] * len(col_names))
|
||||
session.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params, timeout=120)
|
||||
return columns_string
|
||||
|
||||
|
||||
def get_all_primitive_params():
|
||||
"""
|
||||
Simple utility method used to give back a list of all possible primitive data sample types.
|
||||
"""
|
||||
params = [0]
|
||||
for datatype in PRIMITIVE_DATATYPES:
|
||||
params.append((get_sample(datatype)))
|
||||
return params
|
||||
|
||||
|
||||
class CustomResultMessageRaw(ResultMessage):
|
||||
"""
|
||||
This is a custom Result Message that is used to return raw results, rather than
|
||||
|
||||
129
tests/integration/standard/test_cython_protocol_handlers.py
Normal file
129
tests/integration/standard/test_cython_protocol_handlers.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Test the various Cython-based message deserializers"""
|
||||
|
||||
# Based on test_custom_protocol_handler.py
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest
|
||||
|
||||
from cassandra.query import tuple_factory
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler
|
||||
|
||||
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||
from tests.integration.datatype_utils import update_datatypes
|
||||
from tests.integration.standard.utils import (
|
||||
create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes)
|
||||
|
||||
from tests.unit.cython.utils import cythontest, numpytest
|
||||
|
||||
def setup_module():
    """Module-level fixture: runs use_singledc() followed by update_datatypes()."""
    use_singledc()
    update_datatypes()
|
||||
|
||||
|
||||
class CythonProtocolHandlerTest(unittest.TestCase):
    """
    Reads the rows written by create_table_with_all_types() back through
    ProtocolHandler, LazyProtocolHandler and NumpyProtocolHandler and
    compares them with the expected sample values.
    """

    # Number of rows inserted into the test table
    N_ITEMS = 10

    @classmethod
    def setUpClass(cls):
        """Create the ``testspace`` keyspace and fill ``test_table`` with N_ITEMS rows."""
        cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
        cls.session = cls.cluster.connect()
        create_keyspace = ("CREATE KEYSPACE testspace WITH replication = "
                           "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
        cls.session.execute(create_keyspace)
        cls.session.set_keyspace("testspace")
        cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS)

    @classmethod
    def tearDownClass(cls):
        """Drop the ``testspace`` keyspace and shut the cluster down."""
        cls.session.execute("DROP KEYSPACE testspace")
        cls.cluster.shutdown()

    @cythontest
    def test_cython_parser(self):
        """
        Test Cython-based parser that returns a list of tuples
        """
        rows = get_data(ProtocolHandler)
        verify_iterator_data(self.assertEqual, rows)

    @cythontest
    def test_cython_lazy_parser(self):
        """
        Test Cython-based parser that returns an iterator of tuples
        """
        rows = get_data(LazyProtocolHandler)
        verify_iterator_data(self.assertEqual, rows)

    @numpytest
    def test_numpy_parser(self):
        """
        Test Numpy-based parser that returns a NumPy array
        """
        # arrays = { 'a': arr1, 'b': arr2, ... }
        arrays = get_data(NumpyProtocolHandler)

        colnames = self.colnames
        # Check each column's dtype against the CQL type it was declared with
        for colname, datatype in zip(colnames, get_primitive_datatypes()):
            self.match_dtype(datatype, arrays[colname].dtype)

        rows = arrays_to_list_of_tuples(arrays, colnames)
        verify_iterator_data(self.assertEqual, rows)

    def match_dtype(self, datatype, dtype):
        """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
        # cqltype -> (dtype.kind, dtype.itemsize)
        numeric_layouts = {
            'smallint': ('i', 2),
            'int': ('i', 4),
            'bigint': ('i', 8),
            'counter': ('i', 8),
            'float': ('f', 4),
            'double': ('f', 8),
        }
        layout = numeric_layouts.get(datatype)
        if layout is None:
            # Every other cqltype is expected to be stored as an object column
            self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype))
            return
        kind, size = layout
        self.match_dtype_props(dtype, kind, size)

    def match_dtype_props(self, dtype, kind, size, signed=None):
        """Assert that ``dtype`` has the given kind and itemsize (``signed`` is not used)."""
        self.assertEqual(dtype.kind, kind, msg=dtype)
        self.assertEqual(dtype.itemsize, size, msg=dtype)
|
||||
|
||||
|
||||
def arrays_to_list_of_tuples(arrays, colnames):
    """
    Turn a mapping of column name -> indexable column into a list of row tuples.

    The number of rows is the length of the column named by ``colnames[0]``;
    each tuple holds one value per entry of ``colnames``, in that order.
    """
    n_rows = len(arrays[colnames[0]])
    rows = []
    for row_index in range(n_rows):
        rows.append(tuple(arrays[name][row_index] for name in colnames))
    return rows
|
||||
|
||||
|
||||
def get_data(protocol_handler):
    """
    Get all rows from ``test_table`` in the ``testspace`` keyspace.

    A fresh cluster/session pair is created for the query so that the given
    protocol handler does not affect any other session.

    :param protocol_handler: protocol handler class to install on the session
        as ``client_protocol_handler``
    :return: whatever ``session.execute`` returns for the SELECT under the
        given handler (rows are requested via ``tuple_factory``)
    """
    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    try:
        session = cluster.connect(keyspace="testspace")

        # use our custom protocol handler
        session.client_protocol_handler = protocol_handler
        session.row_factory = tuple_factory

        return session.execute("SELECT * FROM test_table")
    finally:
        # Shut down the whole cluster, not just the session: previously the
        # Cluster object created here was never released.
        # NOTE(review): relies on Cluster.shutdown() also closing its sessions - confirm
        cluster.shutdown()
|
||||
|
||||
|
||||
def verify_iterator_data(assertEqual, results):
    """
    Compare every row in ``results`` (a list or iterator of tuples, as
    returned by get_data()) with the values get_all_primitive_params()
    produces for the row's first element.

    :param assertEqual: assertion callable, e.g. ``TestCase.assertEqual``
    :param results: iterable of row tuples
    """
    for row in results:
        expected_row = get_all_primitive_params(row[0])
        assertEqual(len(expected_row), len(row),
                    msg="Not the right number of columns?")
        for expected, actual in zip(expected_row, row):
            assertEqual(actual, expected)
|
||||
@@ -301,8 +301,8 @@ class BatchStatementTests(unittest.TestCase):
|
||||
keys.add(result.k)
|
||||
values.add(result.v)
|
||||
|
||||
self.assertEqual(set(range(10)), keys)
|
||||
self.assertEqual(set(range(10)), values)
|
||||
self.assertEqual(set(range(10)), keys, msg=results)
|
||||
self.assertEqual(set(range(10)), values, msg=results)
|
||||
|
||||
def test_string_statements(self):
|
||||
batch = BatchStatement(BatchType.LOGGED)
|
||||
@@ -370,6 +370,11 @@ class BatchStatementTests(unittest.TestCase):
|
||||
self.session.execute(batch)
|
||||
self.confirm_results()
|
||||
|
||||
def test_no_parameters_many_times(self):
|
||||
for i in range(1000):
|
||||
self.test_no_parameters()
|
||||
self.session.execute("TRUNCATE test3rf.test")
|
||||
|
||||
|
||||
class SerialConsistencyTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
53
tests/integration/standard/utils.py
Normal file
53
tests/integration/standard/utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Helper module to populate dummy Cassandra tables with data.
|
||||
"""
|
||||
|
||||
from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample
|
||||
|
||||
def create_table_with_all_types(table_name, session, N):
    """
    Method that given a table_name and session construct a table that contains
    all possible primitive types.

    The table has an ``int`` primary key column ``primkey`` followed by one
    column per entry of PRIMITIVE_DATATYPES, named 'a', 'b', 'c', ...

    :param table_name: Name of table to create
    :param session: session to use for table creation
    :param N: the number of items to insert into the table

    :return: a list of column names
    """
    # create table
    alpha_type_list = ["primkey int PRIMARY KEY"]
    col_names = ["primkey"]
    start_index = ord('a')
    for i, datatype in enumerate(PRIMITIVE_DATATYPES):
        col_name = chr(start_index + i)
        alpha_type_list.append("{0} {1}".format(col_name, datatype))
        col_names.append(col_name)

    session.execute("CREATE TABLE {0} ({1})".format(
        table_name, ', '.join(alpha_type_list)), timeout=120)

    # The INSERT text is the same for every row, so build it once instead of
    # on each loop iteration.
    columns_string = ', '.join(col_names)
    placeholders = ', '.join(["%s"] * len(col_names))
    insert_statement = "INSERT INTO {0} ({1}) VALUES ({2})".format(
        table_name, columns_string, placeholders)

    # insert one row per key as a simple statement
    for key in range(N):
        params = get_all_primitive_params(key)
        session.execute(insert_statement, params, timeout=120)
    return col_names
|
||||
|
||||
|
||||
def get_all_primitive_params(key):
    """
    Build the list of values for one row: ``key`` followed by one
    get_sample() value for each entry of PRIMITIVE_DATATYPES, in order.
    """
    samples = [get_sample(datatype) for datatype in PRIMITIVE_DATATYPES]
    return [key] + samples
|
||||
|
||||
|
||||
def get_primitive_datatypes():
    """Return 'int' followed by every entry of PRIMITIVE_DATATYPES, as a new list."""
    datatypes = ['int']
    datatypes.extend(list(PRIMITIVE_DATATYPES))
    return datatypes
|
||||
@@ -75,7 +75,7 @@ class StressInsertsTests(unittest.TestCase):
|
||||
break
|
||||
for conn in pool.get_connections():
|
||||
if conn.in_flight > 1:
|
||||
print self.session.get_pool_state()
|
||||
print(self.session.get_pool_state())
|
||||
leaking_connections = True
|
||||
break
|
||||
i = i + 1
|
||||
|
||||
14
tests/unit/cython/__init__.py
Normal file
14
tests/unit/cython/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
44
tests/unit/cython/bytesio_testhelper.pyx
Normal file
44
tests/unit/cython/bytesio_testhelper.pyx
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from cassandra.bytesio cimport BytesIOReader
|
||||
|
||||
def test_read1(assert_equal, assert_raises):
    """Read b'abcdef' in chunks of 2, 2, 0 and 2 bytes and compare each chunk."""
    cdef BytesIOReader stream = BytesIOReader(b'abcdef')
    # each read() result is sliced to the requested length before comparing
    assert_equal(stream.read(2)[:2], b'ab')
    assert_equal(stream.read(2)[:2], b'cd')
    assert_equal(stream.read(0)[:0], b'')
    assert_equal(stream.read(2)[:2], b'ef')
|
||||
|
||||
def test_read2(assert_equal, assert_raises):
    """Read b'abcdef' as a 5-byte read followed by a 1-byte read."""
    cdef BytesIOReader stream = BytesIOReader(b'abcdef')
    stream.read(5)
    stream.read(1)
|
||||
|
||||
def test_read3(assert_equal, assert_raises):
    """Read all six bytes of b'abcdef' in a single call."""
    cdef BytesIOReader stream = BytesIOReader(b'abcdef')
    stream.read(6)
|
||||
|
||||
def test_read_eof(assert_equal, assert_raises):
    """
    After reading 5 of 6 bytes, a 2-byte read must raise EOFError and a
    following 1-byte read must still be possible.
    """
    cdef BytesIOReader stream = BytesIOReader(b'abcdef')
    stream.read(5)
    # cannot convert reader.read to an object, do it manually
    # assert_raises(EOFError, reader.read, 2)
    got_eof = False
    try:
        stream.read(2)
    except EOFError:
        got_eof = True
    if not got_eof:
        raise Exception("Expected an EOFError")
    stream.read(1)  # see that we can still read this
|
||||
35
tests/unit/cython/test_bytesio.py
Normal file
35
tests/unit/cython/test_bytesio.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tests.unit.cython.utils import cyimport, cythontest
|
||||
bytesio_testhelper = cyimport('tests.unit.cython.bytesio_testhelper')
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
|
||||
class BytesIOTest(unittest.TestCase):
    """Test Cython BytesIO proxy"""

    @cythontest
    def test_reading(self):
        """Run the helper checks that read the buffer without running past its end."""
        check_equal, check_raises = self.assertEqual, self.assertRaises
        bytesio_testhelper.test_read1(check_equal, check_raises)
        bytesio_testhelper.test_read2(check_equal, check_raises)
        bytesio_testhelper.test_read3(check_equal, check_raises)

    @cythontest
    def test_reading_error(self):
        """Run the helper check for reading past the end of the buffer."""
        bytesio_testhelper.test_read_eof(self.assertEqual, self.assertRaises)
|
||||
38
tests/unit/cython/utils.py
Normal file
38
tests/unit/cython/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright 2013-2015 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
def cyimport(import_path):
    """
    Import a Cython module if available, otherwise return None
    (and skip any relevant tests).

    :param import_path: dotted path of the module to import
    :return: the imported module, or None when the import fails and
        HAVE_CYTHON is false
    :raises ImportError: when the import fails although HAVE_CYTHON is true
    """
    # Function-scope import keeps this helper self-contained.
    import importlib

    try:
        # import_module() returns the leaf module of a dotted path. The
        # previous __import__(..., fromlist=[True]) call passed a non-string
        # fromlist entry, which raises TypeError on recent Python 3 releases
        # when the target is a package.
        return importlib.import_module(import_path)
    except ImportError:
        if HAVE_CYTHON:
            # Cython is present, so a failed import is a real error.
            raise
        return None
|
||||
|
||||
|
||||
# Skip decorators for tests that need the compiled extensions. Usage:
#
#     @cythontest
#     def test_something(self): ...
#
# numpytest requires both HAVE_CYTHON and HAVE_NUMPY to be true.
cythontest = unittest.skipUnless(
    HAVE_CYTHON, 'Cython is not available')
numpytest = unittest.skipUnless(
    HAVE_CYTHON and HAVE_NUMPY, 'NumPy is not available')
|
||||
Reference in New Issue
Block a user