Merge branch 'cython_serdes'
Conflicts: .gitignore
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,6 +4,7 @@
|
|||||||
*.so
|
*.so
|
||||||
*.egg
|
*.egg
|
||||||
*.egg-info
|
*.egg-info
|
||||||
|
*.attr
|
||||||
.tox
|
.tox
|
||||||
.python-version
|
.python-version
|
||||||
build
|
build
|
||||||
@@ -19,6 +20,7 @@ setuptools*.egg
|
|||||||
|
|
||||||
cassandra/*.c
|
cassandra/*.c
|
||||||
!cassandra/cmurmur3.c
|
!cassandra/cmurmur3.c
|
||||||
|
cassandra/*.html
|
||||||
|
|
||||||
# OSX
|
# OSX
|
||||||
.DS_Store
|
.DS_Store
|
||||||
@@ -38,3 +40,4 @@ cassandra/*.c
|
|||||||
|
|
||||||
#iPython
|
#iPython
|
||||||
*.ipynb
|
*.ipynb
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
include setup.py README.rst MANIFEST.in LICENSE ez_setup.py
|
include setup.py README.rst MANIFEST.in LICENSE ez_setup.py
|
||||||
|
include cassandra/*.pyx
|
||||||
|
include cassandra/*.pxd
|
||||||
|
include tests/unit/cython/*.pyx
|
||||||
|
|||||||
58
cassandra/buffer.pxd
Normal file
58
cassandra/buffer.pxd
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Simple buffer data structure that provides a view on existing memory
|
||||||
|
(e.g. from a bytes object). This memory must stay alive while the
|
||||||
|
buffer is in use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from cpython.bytes cimport PyBytes_AS_STRING
|
||||||
|
# char* PyBytes_AS_STRING(object string)
|
||||||
|
# Macro form of PyBytes_AsString() but without error
|
||||||
|
# checking. Only string objects are supported; no Unicode objects
|
||||||
|
# should be passed.
|
||||||
|
|
||||||
|
|
||||||
|
cdef struct Buffer:
    # Pointer to the first byte of the viewed memory. The struct does not
    # own this memory; it must stay alive while the buffer is in use.
    char *ptr
    # Number of bytes reachable through `ptr`.
    Py_ssize_t size
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline bytes to_bytes(Buffer *buf):
    """Return the `buf.size` bytes starting at `buf.ptr` as a bytes object."""
    cdef char *first = buf.ptr
    return first[:buf.size]
|
||||||
|
|
||||||
|
cdef inline char *buf_ptr(Buffer *buf):
    """Return the raw pointer to the start of the buffer's memory."""
    cdef char *first = buf.ptr
    return first
|
||||||
|
|
||||||
|
cdef inline char *buf_read(Buffer *buf, Py_ssize_t size) except NULL:
    """
    Return the start pointer of `buf` after checking that at least `size`
    bytes are available.

    Raises IndexError when `size` exceeds the buffer length.
    """
    if buf.size >= size:
        return buf.ptr
    raise IndexError("Requested more than length of buffer")
|
||||||
|
|
||||||
|
cdef inline int slice_buffer(Buffer *buf, Buffer *out,
                             Py_ssize_t start, Py_ssize_t size) except -1:
    """
    Make `out` a view on `size` bytes of `buf`, beginning at offset `start`.

    No data is copied: `out.ptr` points into the memory viewed by `buf`.

    Raises ValueError when `size` is negative and IndexError when the
    requested range does not lie entirely inside `buf`.
    Returns 0 on success (-1 signals a raised exception to Cython).
    """
    if size < 0:
        raise ValueError("Length must be positive")

    # A negative offset would produce a pointer before the buffer. The upper
    # bound is written as a subtraction so `start + size` cannot overflow.
    if start < 0 or start > buf.size or size > buf.size - start:
        raise IndexError("Buffer slice out of bounds")

    out.ptr = buf.ptr + start
    out.size = size
    return 0
|
||||||
|
|
||||||
|
cdef inline void from_ptr_and_size(char *ptr, Py_ssize_t size, Buffer *out):
    """Fill `out` so that it views `size` bytes starting at `ptr`."""
    out.size = size
    out.ptr = ptr
|
||||||
20
cassandra/bytesio.pxd
Normal file
20
cassandra/bytesio.pxd
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
cdef class BytesIOReader:
    # The bytes object being read
    cdef bytes buf
    # Raw pointer to the contents of `buf`
    cdef char *buf_ptr
    # Current read offset into the buffer
    cdef Py_ssize_t pos
    # Total number of bytes in `buf`
    cdef Py_ssize_t size

    # Return a pointer to the next `n` bytes and advance the read offset
    cdef char *read(self, Py_ssize_t n = ?) except NULL
|
||||||
44
cassandra/bytesio.pyx
Normal file
44
cassandra/bytesio.pyx
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
cdef class BytesIOReader:
    """
    Reader over a 'bytes' object that hands out char * pointers directly into
    its contents, so that no intermediate objects are allocated.
    """

    def __init__(self, bytes buf):
        # Keep a reference to the bytes object next to the raw pointer.
        self.buf = buf
        self.size = len(buf)
        self.buf_ptr = self.buf

    cdef char *read(self, Py_ssize_t n = -1) except NULL:
        """
        Return a pointer to the current read position and advance it.

        With a non-negative `n` the position moves forward by `n` bytes;
        EOFError is raised (and the position left untouched) if that would
        go past the end of the buffer. With a negative or omitted `n` the
        position moves to the end of the buffer.
        """
        cdef Py_ssize_t start = self.pos
        cdef Py_ssize_t stop

        if n < 0:
            stop = self.size
        else:
            stop = start + n
            if stop > self.size:
                # The caller must never consume past the end of the buffer
                raise EOFError("Cannot read past the end of the file")

        self.pos = stop
        return self.buf_ptr + start
|
||||||
@@ -288,7 +288,7 @@ class _CassandraType(object):
|
|||||||
Given a set of other CassandraTypes, create a new subtype of this type
|
Given a set of other CassandraTypes, create a new subtype of this type
|
||||||
using them as parameters. This is how composite types are constructed.
|
using them as parameters. This is how composite types are constructed.
|
||||||
|
|
||||||
>>> MapType.apply_parameters(DateType, BooleanType)
|
>>> MapType.apply_parameters([DateType, BooleanType])
|
||||||
<class 'cassandra.types.MapType(DateType, BooleanType)'>
|
<class 'cassandra.types.MapType(DateType, BooleanType)'>
|
||||||
|
|
||||||
`subtypes` will be a sequence of CassandraTypes. If provided, `names`
|
`subtypes` will be a sequence of CassandraTypes. If provided, `names`
|
||||||
@@ -884,27 +884,7 @@ class UserType(TupleType):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize_safe(cls, byts, protocol_version):
|
def deserialize_safe(cls, byts, protocol_version):
|
||||||
proto_version = max(3, protocol_version)
|
values = super(UserType, cls).deserialize_safe(byts, protocol_version)
|
||||||
p = 0
|
|
||||||
values = []
|
|
||||||
for col_type in cls.subtypes:
|
|
||||||
if p == len(byts):
|
|
||||||
break
|
|
||||||
itemlen = int32_unpack(byts[p:p + 4])
|
|
||||||
p += 4
|
|
||||||
if itemlen >= 0:
|
|
||||||
item = byts[p:p + itemlen]
|
|
||||||
p += itemlen
|
|
||||||
else:
|
|
||||||
item = None
|
|
||||||
# collections inside UDTs are always encoded with at least the
|
|
||||||
# version 3 format
|
|
||||||
values.append(col_type.from_binary(item, proto_version))
|
|
||||||
|
|
||||||
if len(values) < len(cls.subtypes):
|
|
||||||
nones = [None] * (len(cls.subtypes) - len(values))
|
|
||||||
values = values + nones
|
|
||||||
|
|
||||||
if cls.mapped_class:
|
if cls.mapped_class:
|
||||||
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
return cls.mapped_class(**dict(zip(cls.fieldnames, values)))
|
||||||
else:
|
else:
|
||||||
|
|||||||
11
cassandra/cython_deps.py
Normal file
11
cassandra/cython_deps.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Probe for the compiled row parser; HAVE_CYTHON records whether it imported.
try:
    from cassandra.row_parser import make_recv_results_rows
except ImportError:
    HAVE_CYTHON = False
else:
    HAVE_CYTHON = True

# Probe for numpy; HAVE_NUMPY records whether it imported.
try:
    import numpy
except ImportError:
    HAVE_NUMPY = False
else:
    HAVE_NUMPY = True
|
||||||
123
cassandra/cython_marshal.pyx
Normal file
123
cassandra/cython_marshal.pyx
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# -- cython: profile=True
|
||||||
|
#
|
||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t,
|
||||||
|
uint8_t, uint16_t, uint32_t, uint64_t)
|
||||||
|
from cassandra.buffer cimport Buffer, buf_read, to_bytes
|
||||||
|
|
||||||
|
cdef bint is_little_endian
|
||||||
|
from cassandra.util import is_little_endian
|
||||||
|
|
||||||
|
cdef bint PY3 = six.PY3
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline void swap_order(char *buf, Py_ssize_t size):
    """
    Swap the byteorder of `buf` in-place on little-endian platforms
    (reverse all the bytes). Nothing happens when `is_little_endian` is false.

    There are functions ntohl etc, but these may be POSIX-dependent.
    """
    cdef Py_ssize_t i, end
    cdef char c

    if is_little_endian:
        # Exchange byte i with its mirror; the first half of the range
        # covers every pair exactly once.
        for i in range(div2(size)):
            end = size - i - 1
            c = buf[i]
            buf[i] = buf[end]
            buf[end] = c
|
||||||
|
|
||||||
|
cdef inline Py_ssize_t div2(Py_ssize_t x):
    """Halve `x` by shifting it right by one bit."""
    cdef Py_ssize_t half = x >> 1
    return half
|
||||||
|
|
||||||
|
### Unpacking of signed integers
|
||||||
|
|
||||||
|
cdef inline int64_t int64_unpack(Buffer *buf) except ?0xDEAD:
    """
    Read an 8-byte signed integer from the start of `buf`, reversing the
    bytes on little-endian platforms.

    An IndexError from buf_read propagates when fewer than 8 bytes are
    available; 0xDEAD is the return value Cython checks for a pending error.
    """
    cdef int64_t x = (<int64_t *> buf_read(buf, 8))[0]
    swap_order(<char *> &x, 8)
    return x
|
||||||
|
|
||||||
|
cdef inline int32_t int32_unpack(Buffer *buf) except ?0xDEAD:
    """
    Read a 4-byte signed integer from the start of `buf`, reversing the
    bytes on little-endian platforms.

    An IndexError from buf_read propagates when fewer than 4 bytes are
    available; 0xDEAD is the return value Cython checks for a pending error.
    """
    cdef int32_t x = (<int32_t *> buf_read(buf, 4))[0]
    swap_order(<char *> &x, 4)
    return x
|
||||||
|
|
||||||
|
cdef inline int16_t int16_unpack(Buffer *buf) except ?0xDED:
    """
    Read a 2-byte signed integer from the start of `buf`, reversing the
    bytes on little-endian platforms.
    """
    cdef char *raw = buf_read(buf, 2)
    cdef int16_t value = (<int16_t *> raw)[0]
    swap_order(<char *> &value, 2)
    return value
|
||||||
|
|
||||||
|
cdef inline int8_t int8_unpack(Buffer *buf) except ?80:
    """Read a single signed byte from the start of `buf`."""
    cdef char *raw = buf_read(buf, 1)
    return (<int8_t *> raw)[0]
|
||||||
|
|
||||||
|
cdef inline uint64_t uint64_unpack(Buffer *buf) except ?0xDEAD:
    """
    Read an 8-byte unsigned integer from the start of `buf`, reversing the
    bytes on little-endian platforms.
    """
    cdef char *raw = buf_read(buf, 8)
    cdef uint64_t value = (<uint64_t *> raw)[0]
    swap_order(<char *> &value, 8)
    return value
|
||||||
|
|
||||||
|
cdef inline uint32_t uint32_unpack(Buffer *buf) except ?0xDEAD:
    """
    Read a 4-byte unsigned integer from the start of `buf`, reversing the
    bytes on little-endian platforms.
    """
    cdef char *raw = buf_read(buf, 4)
    cdef uint32_t value = (<uint32_t *> raw)[0]
    swap_order(<char *> &value, 4)
    return value
|
||||||
|
|
||||||
|
cdef inline uint16_t uint16_unpack(Buffer *buf) except ?0xDEAD:
    """
    Read a 2-byte unsigned integer from the start of `buf`, reversing the
    bytes on little-endian platforms.
    """
    cdef char *raw = buf_read(buf, 2)
    cdef uint16_t value = (<uint16_t *> raw)[0]
    swap_order(<char *> &value, 2)
    return value
|
||||||
|
|
||||||
|
cdef inline uint8_t uint8_unpack(Buffer *buf) except ?0xff:
    """Read a single unsigned byte from the start of `buf`."""
    cdef char *raw = buf_read(buf, 1)
    return (<uint8_t *> raw)[0]
|
||||||
|
|
||||||
|
cdef inline double double_unpack(Buffer *buf) except ?1.74:
    """
    Read an 8-byte double from the start of `buf`, reversing the bytes on
    little-endian platforms.
    """
    cdef char *raw = buf_read(buf, 8)
    cdef double value = (<double *> raw)[0]
    swap_order(<char *> &value, 8)
    return value
|
||||||
|
|
||||||
|
cdef inline float float_unpack(Buffer *buf) except ?1.74:
    """
    Read a 4-byte float from the start of `buf`, reversing the bytes on
    little-endian platforms.
    """
    cdef char *raw = buf_read(buf, 4)
    cdef float value = (<float *> raw)[0]
    swap_order(<char *> &value, 4)
    return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef varint_unpack(Buffer *term):
    """
    Unpack a variable-sized integer.

    The buffer contents are copied into a bytes object and handed to the
    implementation matching the running Python major version.
    """
    cdef bytes raw = to_bytes(term)
    if not PY3:
        return varint_unpack_py2(raw)
    return varint_unpack_py3(raw)
|
||||||
|
|
||||||
|
# TODO: Optimize these two functions
|
||||||
|
cdef varint_unpack_py3(bytes term):
    """
    Decode `term` as a two's-complement integer of arbitrary length
    (Python 3 implementation).

    The bytes are first read as an unsigned hexadecimal number; if the top
    bit of the first byte is set, 2 ** (8 * len(term)) is subtracted to
    obtain the negative value.
    """
    # Keep both shift operands as Python objects: a C-typed '1' (or a literal
    # that Cython folds into C code) makes this a native shift, which
    # overflows once the varint is 8 bytes or longer.
    cdef object one = 1
    cdef object nbits = len(term) * 8

    val = int(''.join("%02x" % i for i in term), 16)
    if (term[0] & 128) != 0:
        val -= one << nbits
    return val
|
||||||
|
|
||||||
|
cdef varint_unpack_py2(bytes term):  # noqa
    """
    Decode `term` as a two's-complement integer of arbitrary length
    (Python 2 implementation).

    The bytes are first read as an unsigned hexadecimal number; if the top
    bit of the first byte is set, 2 ** (8 * len(term)) is subtracted to
    obtain the negative value.
    """
    # Keep both shift operands as Python objects so the shift has arbitrary
    # precision; a C int64_t operand overflows for varints of 8 bytes or more.
    cdef object one = 1
    cdef object nbits = len(term) * 8

    val = int(term.encode('hex'), 16)
    if (ord(term[0]) & 128) != 0:
        val = val - (one << nbits)
    return val
|
||||||
2
cassandra/cython_utils.pxd
Normal file
2
cassandra/cython_utils.pxd
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from libc.stdint cimport int64_t
|
||||||
|
cdef datetime_from_timestamp(double timestamp)
|
||||||
44
cassandra/cython_utils.pyx
Normal file
44
cassandra/cython_utils.pyx
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Duplicate module of util.py, with some accelerated functions
|
||||||
|
used for deserialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from cpython.datetime cimport (
|
||||||
|
timedelta_new,
|
||||||
|
# cdef inline object timedelta_new(int days, int seconds, int useconds)
|
||||||
|
# Create timedelta object using DateTime CAPI factory function.
|
||||||
|
# Note, there are no range checks for any of the arguments.
|
||||||
|
import_datetime,
|
||||||
|
# Datetime C API initialization function.
|
||||||
|
# You have to call it before any usage of DateTime CAPI functions.
|
||||||
|
)
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import sys
|
||||||
|
|
||||||
|
cdef bint is_little_endian
|
||||||
|
from cassandra.util import is_little_endian
|
||||||
|
|
||||||
|
import_datetime()
|
||||||
|
|
||||||
|
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
cdef datetime_from_timestamp(double timestamp):
    """
    Convert `timestamp` (seconds relative to 1970-01-01, may be negative or
    fractional) into a datetime by adding it to DATETIME_EPOC.

    The value is split into days, seconds and microseconds using 64-bit
    arithmetic, so timestamps beyond the range of a C int (year 2038) do not
    overflow. Fractions below one microsecond are truncated.
    """
    cdef int64_t total_micros = <int64_t> (timestamp * 1000000)

    # Floor division and modulo keep seconds and microseconds non-negative
    # and let a negative day count carry the sign, which is the normalized
    # form timedelta expects (timedelta_new does no range checks).
    # NOTE(review): relies on Cython's default Python-style '//' and '%'
    # for C integers (cdivision disabled) -- confirm no directive changes it.
    cdef int64_t total_seconds = total_micros // 1000000
    cdef int microseconds = <int> (total_micros % 1000000)
    cdef int days = <int> (total_seconds // 86400)
    cdef int seconds = <int> (total_seconds % 86400)

    return DATETIME_EPOC + timedelta_new(days, seconds, microseconds)
|
||||||
43
cassandra/deserializers.pxd
Normal file
43
cassandra/deserializers.pxd
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from cassandra.buffer cimport Buffer
|
||||||
|
|
||||||
|
cdef class Deserializer:
    # The cqltypes._CassandraType this deserializer belongs to
    cdef object cqltype

    # Strings may legitimately be empty, other values may not. Other values
    # may be NULL, in which case the integer length of the binary data is
    # negative. For legacy reasons non-string types may nevertheless arrive
    # with a zero length
    # (see http://code.metager.de/source/xref/apache/cassandra/doc/native_protocol_v3.spec
    # paragraph 6)
    cdef bint empty_binary_ok

    # Turn the bytes viewed by `buf` into a Python value
    cdef deserialize(self, Buffer *buf, int protocol_version)
|
||||||
|
# cdef deserialize(self, CString byts, protocol_version)
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline object from_binary(Deserializer deserializer,
                               Buffer *buf,
                               int protocol_version):
    """
    Deserialize `buf` with `deserializer`, handling the special lengths.

    A negative size yields None. A zero size is passed to _ret_empty when
    the deserializer does not have `empty_binary_ok` set. Everything else
    goes to `deserializer.deserialize`.
    """
    cdef Py_ssize_t nbytes = buf.size

    if nbytes < 0:
        return None
    if nbytes == 0 and not deserializer.empty_binary_ok:
        return _ret_empty(deserializer, nbytes)
    return deserializer.deserialize(buf, protocol_version)
|
||||||
|
|
||||||
|
cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size)
|
||||||
512
cassandra/deserializers.pyx
Normal file
512
cassandra/deserializers.pyx
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from libc.stdint cimport int32_t, uint16_t
|
||||||
|
|
||||||
|
include 'cython_marshal.pyx'
|
||||||
|
from cassandra.buffer cimport Buffer, to_bytes, slice_buffer
|
||||||
|
from cassandra.cython_utils cimport datetime_from_timestamp
|
||||||
|
|
||||||
|
from cython.view cimport array as cython_array
|
||||||
|
from cassandra.tuple cimport tuple_new, tuple_set
|
||||||
|
|
||||||
|
import socket
|
||||||
|
from decimal import Decimal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cassandra import cqltypes
|
||||||
|
from cassandra import util
|
||||||
|
|
||||||
|
cdef bint PY2 = six.PY2
|
||||||
|
|
||||||
|
cdef class Deserializer:
    """
    Base class of the Cython deserializers; each instance is bound to one
    cqltype. Subclasses override `deserialize`.
    """

    def __init__(self, cqltype):
        # Cache the flag on the instance so it is available as a C attribute.
        self.cqltype = cqltype
        self.empty_binary_ok = cqltype.empty_binary_ok

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # The base class has no wire format of its own.
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesBytesType(Deserializer):
    """Returns the raw buffer contents as a bytes object."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return raw
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html
|
||||||
|
cdef class DesDecimalType(Deserializer):
    """
    Builds a Decimal from a 4-byte scale followed by a varint holding the
    unscaled value.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Buffer digits_buf

        # Everything after the first four bytes is the unscaled varint.
        slice_buffer(buf, &digits_buf, 4, buf.size - 4)

        exponent = int32_unpack(buf)
        digits = varint_unpack(&digits_buf)

        return Decimal('%de%d' % (digits, -exponent))
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesUUIDType(Deserializer):
    """Builds a uuid.UUID from the buffer contents."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return UUID(bytes=raw)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesBooleanType(Deserializer):
    """Interprets the first byte as a boolean: non-zero means True."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int8_t flag = int8_unpack(buf)
        if flag == 0:
            return False
        return True
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesByteType(Deserializer):
    """Returns the first byte of the buffer as a signed integer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int8_t value = int8_unpack(buf)
        return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesAsciiType(Deserializer):
    """
    Returns the buffer contents as bytes when PY2 is set, and decoded as
    ASCII otherwise.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        raw = to_bytes(buf)
        if not PY2:
            return raw.decode('ascii')
        return raw
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesFloatType(Deserializer):
    """Returns the 4-byte float stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef float value = float_unpack(buf)
        return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesDoubleType(Deserializer):
    """Returns the 8-byte double stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef double value = double_unpack(buf)
        return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesLongType(Deserializer):
    """Returns the 8-byte signed integer stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int64_t value = int64_unpack(buf)
        return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesInt32Type(Deserializer):
    """Returns the 4-byte signed integer stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int32_t value = int32_unpack(buf)
        return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesIntegerType(Deserializer):
    """Decodes the whole buffer as a variable-sized integer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        value = varint_unpack(buf)
        return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesInetAddressType(Deserializer):
    """
    Formats the buffer contents as an IP address string: 16-byte values are
    treated as IPv6, everything else is passed to socket.inet_ntoa.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes packed = to_bytes(buf)

        # TODO: optimize inet_ntop, inet_ntoa
        if buf.size != 16:
            # util.inet_ntop could also handle, but this is faster
            # since we've already determined the AF
            return socket.inet_ntoa(packed)
        return util.inet_ntop(socket.AF_INET6, packed)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesCounterColumnType(DesLongType):
    """Counter columns are deserialized exactly like DesLongType."""
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesDateType(Deserializer):
    """
    Reads an 8-byte signed integer, divides it by 1000.0 and converts the
    resulting timestamp with datetime_from_timestamp.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int64_t raw = int64_unpack(buf)
        cdef double seconds = raw / 1000.0
        return datetime_from_timestamp(seconds)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): unlike its siblings this class has no 'Des' prefix --
# confirm that whatever looks deserializers up by name expects this spelling.
cdef class TimestampType(DesDateType):
    """Timestamps are deserialized exactly like DesDateType."""
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): unlike its siblings this class has no 'Des' prefix --
# confirm that whatever looks deserializers up by name expects this spelling.
cdef class TimeUUIDType(DesDateType):
    """
    Builds a uuid.UUID from the buffer contents, overriding the datetime
    conversion inherited from DesDateType.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return UUID(bytes=raw)
|
||||||
|
|
||||||
|
|
||||||
|
# Values of the 'date' type are encoded as 32-bit unsigned integers
|
||||||
|
# representing a number of days with epoch (January 1st, 1970) at the center of the
|
||||||
|
# range (2^31).
|
||||||
|
EPOCH_OFFSET_DAYS = 2 ** 31
|
||||||
|
|
||||||
|
cdef class DesSimpleDateType(Deserializer):
    """
    Reads a 4-byte unsigned day count, subtracts EPOCH_OFFSET_DAYS and wraps
    the result in util.Date.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.Date(uint32_unpack(buf) - EPOCH_OFFSET_DAYS)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesShortType(Deserializer):
    """Returns the 2-byte signed integer stored at the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int16_t value = int16_unpack(buf)
        return value
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesTimeType(Deserializer):
    """Wraps the 8-byte signed integer from the buffer in util.Time."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef int64_t raw = int64_unpack(buf)
        return util.Time(raw)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesUTF8Type(Deserializer):
    """Decodes the buffer contents as UTF-8 text."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return to_bytes(buf).decode('utf8')
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesVarcharType(DesUTF8Type):
    """Varchar values are deserialized exactly like DesUTF8Type."""
|
||||||
|
|
||||||
|
|
||||||
|
cdef class _DesParameterizedType(Deserializer):
    """
    Base class for deserializers of types that carry subtypes; it stores
    the subtypes and one deserializer per subtype.
    """

    # The subtypes of the cqltype, as given by `cqltype.subtypes`
    cdef object subtypes
    # Result of make_deserializers(cqltype.subtypes), indexed per subtype
    cdef Deserializer[::1] deserializers
    # Number of subtypes
    cdef Py_ssize_t subtypes_len

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.subtypes = cqltype.subtypes
        self.deserializers = make_deserializers(cqltype.subtypes)
        self.subtypes_len = len(self.subtypes)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class _DesSingleParamType(_DesParameterizedType):
    """Parameterized deserializer for types with exactly one subtype."""

    # Deserializer of the single subtype
    cdef Deserializer deserializer

    def __init__(self, cqltype):
        # Checked before the base initializer builds the deserializer array.
        assert cqltype.subtypes and len(cqltype.subtypes) == 1, cqltype.subtypes
        super().__init__(cqltype)
        self.deserializer = self.deserializers[0]
|
||||||
|
|
||||||
|
|
||||||
|
#--------------------------------------------------------------------------
|
||||||
|
# List and set deserialization
|
||||||
|
|
||||||
|
cdef class DesListType(_DesSingleParamType):
    """
    Deserializes a list, choosing the length encoding by protocol version:
    int32_t lengths for version 3 and above, uint16_t lengths below that.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # These values only select the fused-type specialization.
        cdef uint16_t v2_and_below = 2
        cdef int32_t v3_and_above = 3

        if protocol_version < 3:
            return _deserialize_list_or_set[uint16_t](
                v2_and_below, buf, protocol_version, self.deserializer)

        return _deserialize_list_or_set[int32_t](
            v3_and_above, buf, protocol_version, self.deserializer)
|
||||||
|
|
||||||
|
cdef class DesSetType(DesListType):
    """Deserializes like a list and wraps the items in util.sortedset."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        items = DesListType.deserialize(self, buf, protocol_version)
        return util.sortedset(items)
|
||||||
|
|
||||||
|
|
||||||
|
# Integer type used for element counts and element lengths in collections;
# which one applies depends on the protocol version.
ctypedef fused itemlen_t:
    uint16_t  # protocol <= v2
    int32_t  # protocol >= v3
|
||||||
|
|
||||||
|
cdef list _deserialize_list_or_set(itemlen_t dummy_version,
                                   Buffer *buf, int protocol_version,
                                   Deserializer deserializer):
    """
    Deserialize a list or set.

    The buffer starts with the element count; each element follows as a
    length and that many bytes. Every element is converted with
    `from_binary` using `deserializer`, and the values are returned as a
    list in buffer order.

    The 'dummy' parameter is needed to make fused types work, so that
    we can specialize on the protocol version.
    """
    cdef Buffer elem_buf
    cdef itemlen_t numelements
    cdef itemlen_t idx
    cdef list result = []

    _unpack_len[itemlen_t](0, &numelements, buf)
    # The first element starts right after the element count.
    idx = sizeof(itemlen_t)

    for _ in range(numelements):
        # subelem advances idx past the element it just sliced out.
        subelem(buf, &elem_buf, &idx)
        result.append(from_binary(deserializer, &elem_buf, protocol_version))

    return result
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline int subelem(
        Buffer *buf, Buffer *elem_buf, itemlen_t *idx_p) except -1:
    """
    Slice the next element out of `buf`.

    The element's byte length is read at offset `idx_p[0]`; `elem_buf` is
    then pointed at the bytes that follow it, and `idx_p[0]` is moved past
    the element. A negative length makes slice_buffer raise ValueError.

    NOTE: The offset is passed by pointer because of a Cython bug with the
    combination fused types + 'except' clause, which rules out taking and
    returning 'idx' by value.
    """
    cdef itemlen_t nbytes

    _unpack_len[itemlen_t](idx_p[0], &nbytes, buf)
    # Step over the length field itself.
    idx_p[0] += sizeof(itemlen_t)

    slice_buffer(buf, elem_buf, idx_p[0], nbytes)
    # Step over the element's payload.
    idx_p[0] += nbytes
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _unpack_len(itemlen_t idx, itemlen_t *elemlen, Buffer *buf) except -1:
    """
    Read the length field stored at offset `idx` of `buf` into `elemlen[0]`.

    The field is unpacked as uint16_t or int32_t depending on the
    specialization of `itemlen_t`. Bounds errors raised by slice_buffer
    propagate. Returns 0 on success.
    """
    cdef Buffer itemlen_buf

    slice_buffer(buf, &itemlen_buf, idx, sizeof(itemlen_t))

    if itemlen_t is uint16_t:
        elemlen[0] = uint16_unpack(&itemlen_buf)
    else:
        elemlen[0] = int32_unpack(&itemlen_buf)

    return 0
|
||||||
|
|
||||||
|
#--------------------------------------------------------------------------
|
||||||
|
# Map deserialization
|
||||||
|
|
||||||
|
cdef class DesMapType(_DesParameterizedType):
    """
    Deserializes a map, choosing the length encoding by protocol version:
    int32_t lengths for version 3 and above, uint16_t lengths below that.
    """

    # Deserializers for the first (key) and second (value) subtype
    cdef Deserializer key_deserializer, val_deserializer

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.key_deserializer = self.deserializers[0]
        self.val_deserializer = self.deserializers[1]

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # These values only select the fused-type specialization.
        cdef uint16_t v2_and_below = 0
        cdef int32_t v3_and_above = 0
        key_type, val_type = self.cqltype.subtypes

        if protocol_version < 3:
            return _deserialize_map[uint16_t](
                v2_and_below, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)

        return _deserialize_map[int32_t](
            v3_and_above, buf, protocol_version,
            self.key_deserializer, self.val_deserializer,
            key_type, val_type)
|
||||||
|
|
||||||
|
|
||||||
|
cdef _deserialize_map(itemlen_t dummy_version,
                      Buffer *buf, int protocol_version,
                      Deserializer key_deserializer, Deserializer val_deserializer,
                      key_type, val_type):
    """
    Decode a serialized map from 'buf' into a util.OrderedMapSerializedKey.

    'dummy_version' is never read; its type selects the width of the length
    fields (itemlen_t).  The buffer starts with the number of entries,
    followed by that many length-prefixed key/value pairs.

    Each entry is inserted together with the raw serialized bytes of its key.
    'val_type' is accepted for signature symmetry but not used here.
    """
    cdef Buffer key_buf, val_buf
    cdef itemlen_t numelements
    # The first length field holds the entry count; entries start after it
    cdef itemlen_t idx = sizeof(itemlen_t)

    _unpack_len[itemlen_t](0, &numelements, buf)
    themap = util.OrderedMapSerializedKey(key_type, protocol_version)
    for _ in range(numelements):
        # Each call advances 'idx' past the element it sliced
        subelem(buf, &key_buf, &idx)
        subelem(buf, &val_buf, &idx)
        key = from_binary(key_deserializer, &key_buf, protocol_version)
        val = from_binary(val_deserializer, &val_buf, protocol_version)
        themap._insert_unchecked(key, to_bytes(&key_buf), val)

    return themap
|
||||||
|
|
||||||
|
#--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
cdef class DesTupleType(_DesParameterizedType):
    """
    Deserializer for tuple values: a sequence of int32-length-prefixed items,
    one per subtype.
    """

    # TODO: Use TupleRowParser to parse these tuples

    cdef deserialize(self, Buffer *buf, int protocol_version):
        """
        Decode 'buf' into a tuple with one entry per subtype.

        An item is None when its length prefix is negative or when the buffer
        ends before the item is reached.
        """
        cdef Py_ssize_t i, p
        cdef int32_t itemlen
        cdef tuple res = tuple_new(self.subtypes_len)
        cdef Buffer item_buf
        cdef Buffer itemlen_buf
        cdef Deserializer deserializer

        # collections inside UDTs are always encoded with at least the
        # version 3 format
        protocol_version = max(3, protocol_version)

        p = 0
        for i in range(self.subtypes_len):
            item = None
            if p < buf.size:
                # Read the 4-byte length prefix of the next item
                slice_buffer(buf, &itemlen_buf, p, 4)
                itemlen = int32_unpack(&itemlen_buf)
                p += 4
                if itemlen >= 0:
                    slice_buffer(buf, &item_buf, p, itemlen)
                    p += itemlen

                    deserializer = self.deserializers[i]
                    item = from_binary(deserializer, &item_buf, protocol_version)

            tuple_set(res, i, item)

        return res
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesUserType(DesTupleType):
    """
    Deserializer for user-defined types, built on the tuple decoding of
    DesTupleType.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        """
        Decode the fields as a tuple, then wrap them in the mapped class
        (keyword arguments by field name) if one is set, otherwise in the
        cqltype's tuple_type (positional arguments).
        """
        udt = self.cqltype
        fields = DesTupleType.deserialize(self, buf, protocol_version)
        if not udt.mapped_class:
            return udt.tuple_type(*fields)
        return udt.mapped_class(**dict(zip(udt.fieldnames, fields)))
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesCompositeType(_DesParameterizedType):
    """
    Deserializer for composite values: a sequence of elements, each encoded
    as a uint16 length, the element bytes and one end-of-component byte.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        """
        Decode 'buf' into a tuple of at most subtypes_len items.

        Trailing elements may be missing, in which case a shorter tuple is
        returned.  NOTE: 'buf' itself is re-sliced as elements are consumed,
        so the caller's Buffer is modified.
        """
        cdef Py_ssize_t i, j, start
        cdef Buffer elem_buf
        # Must be unsigned: the length is decoded with uint16_unpack, and a
        # signed 16-bit variable would turn lengths above 32767 negative.
        cdef uint16_t element_length
        cdef Deserializer deserializer
        cdef tuple res = tuple_new(self.subtypes_len)

        for i in range(self.subtypes_len):
            if not buf.size:
                # CompositeType can have missing elements at the end

                # Fill the tuple with None values and slice it
                #
                # (I'm not sure a tuple needs to be fully initialized before
                #  it can be destroyed, so play it safe)
                for j in range(i, self.subtypes_len):
                    tuple_set(res, j, None)
                res = res[:i]
                break

            element_length = uint16_unpack(buf)
            slice_buffer(buf, &elem_buf, 2, element_length)

            deserializer = self.deserializers[i]
            item = from_binary(deserializer, &elem_buf, protocol_version)
            tuple_set(res, i, item)

            # skip element length, element, and the EOC (one byte)
            start = 2 + element_length + 1
            slice_buffer(buf, buf, start, buf.size - start)

        return res
|
||||||
|
|
||||||
|
|
||||||
|
# Dynamic composite values are decoded by the same class as composite values
DesDynamicCompositeType = DesCompositeType
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesReversedType(_DesSingleParamType):
    """Deserializer for reversed types: delegates to the wrapped deserializer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # Decoding is exactly that of the inner type
        decoded = from_binary(self.deserializer, buf, protocol_version)
        return decoded
|
||||||
|
|
||||||
|
|
||||||
|
cdef class DesFrozenType(_DesSingleParamType):
    """Deserializer for frozen types: delegates to the wrapped deserializer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # Decoding is exactly that of the inner type
        decoded = from_binary(self.deserializer, buf, protocol_version)
        return decoded
|
||||||
|
|
||||||
|
#--------------------------------------------------------------------------
|
||||||
|
|
||||||
|
cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size):
    """
    Pick the result for a buffer whose size is zero or negative.

    Returns cqltypes.EMPTY only when the size is not negative and the
    deserializer's cqltype supports empty values; returns None in every other
    case.  This is used by from_binary in deserializers.pxd.
    """
    if buf_size >= 0 and deserializer.cqltype.support_empty_values:
        return cqltypes.EMPTY
    return None
|
||||||
|
|
||||||
|
#--------------------------------------------------------------------------
|
||||||
|
# Generic deserialization
|
||||||
|
|
||||||
|
cdef class GenericDeserializer(Deserializer):
    """
    Fallback deserializer that delegates to the wrapped cqltype's own
    'deserialize' method.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # Convert the buffer to a bytes object, then let the cqltype decode it
        raw = to_bytes(buf)
        return self.cqltype.deserialize(raw, protocol_version)

    def __repr__(self):
        return "GenericDeserializer(%s)" % (self.cqltype,)
|
||||||
|
|
||||||
|
#--------------------------------------------------------------------------
|
||||||
|
# Helper utilities
|
||||||
|
|
||||||
|
def make_deserializers(cqltypes):
    """
    Create an array of Deserializers, one for each cqltype in 'cqltypes'.

    Each deserializer is chosen by find_deserializer and the result is packed
    into an object array by obj_array, in the same order as the input.
    """
    return obj_array([find_deserializer(ct) for ct in cqltypes])
|
||||||
|
|
||||||
|
|
||||||
|
# Module namespace captured once; find_deserializer looks up 'Des<Name>' here
cdef dict classes = globals()
|
||||||
|
|
||||||
|
cpdef Deserializer find_deserializer(cqltype):
    """
    Find a deserializer for a cqltype.

    A module-level class named 'Des' + cqltype.__name__ is preferred.
    Otherwise the cqltype is matched against the known parameterized base
    types, and GenericDeserializer is used when nothing matches.

    Returns a new deserializer instance constructed with 'cqltype'.
    """
    name = 'Des' + cqltype.__name__

    # Use the cached module namespace instead of calling globals() again
    if name in classes:
        cls = classes[name]
    elif issubclass(cqltype, cqltypes.ListType):
        cls = DesListType
    elif issubclass(cqltype, cqltypes.SetType):
        cls = DesSetType
    elif issubclass(cqltype, cqltypes.MapType):
        cls = DesMapType
    elif issubclass(cqltype, cqltypes.UserType):
        # UserType is a subclass of TupleType, so should precede it
        cls = DesUserType
    elif issubclass(cqltype, cqltypes.TupleType):
        cls = DesTupleType
    elif issubclass(cqltype, cqltypes.DynamicCompositeType):
        # DynamicCompositeType is a subclass of CompositeType, so should precede it
        cls = DesDynamicCompositeType
    elif issubclass(cqltype, cqltypes.CompositeType):
        cls = DesCompositeType
    elif issubclass(cqltype, cqltypes.ReversedType):
        cls = DesReversedType
    elif issubclass(cqltype, cqltypes.FrozenType):
        cls = DesFrozenType
    else:
        cls = GenericDeserializer

    return cls(cqltype)
|
||||||
|
|
||||||
|
|
||||||
|
def obj_array(list objs):
    """
    Copy a list of Python objects into a newly allocated Cython array of
    object pointers (format "O") and return it.
    """
    cdef object[:] slots
    cdef Py_ssize_t pos
    slots = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")
    # Slice assignment (slots[:] = objs) segfaults, so copy element by element
    pos = 0
    for item in objs:
        slots[pos] = item
        pos += 1
    return slots
|
||||||
47
cassandra/ioutils.pyx
Normal file
47
cassandra/ioutils.pyx
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
include 'cython_marshal.pyx'
|
||||||
|
from cassandra.buffer cimport Buffer, from_ptr_and_size
|
||||||
|
|
||||||
|
from libc.stdint cimport int32_t
|
||||||
|
from cassandra.bytesio cimport BytesIOReader
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline int get_buf(BytesIOReader reader, Buffer *buf_out) except -1:
    """
    Point 'buf_out' at the next data item in the stream of values held by
    the BytesIOReader.

    The item is preceded by its size, read with read_int.

    BEWARE:
        If the next item has a zero or negative size, the pointer will be set
        to NULL.  A negative size happens when the value is NULL in the
        database, whereas a zero size may happen either for legacy reasons,
        or for data types such as strings (which may be empty).

    Returns 0; -1 is reserved for exception propagation.
    """
    cdef Py_ssize_t raw_val_size = read_int(reader)
    cdef char *ptr = NULL

    # Only consume bytes from the reader when there is a payload
    if raw_val_size > 0:
        ptr = reader.read(raw_val_size)

    from_ptr_and_size(ptr, raw_val_size, buf_out)
    return 0
|
||||||
|
|
||||||
|
cdef inline int32_t read_int(BytesIOReader reader) except ?0xDEAD:
    """
    Consume four bytes from 'reader' and decode them with int32_unpack.

    0xDEAD is the ambiguous error sentinel: when it is returned, the caller
    checks whether an exception is actually set.
    """
    cdef Buffer int_buf
    int_buf.size = 4
    int_buf.ptr = reader.read(4)
    return int32_unpack(&int_buf)
|
||||||
1
cassandra/numpyFlags.h
Normal file
1
cassandra/numpyFlags.h
Normal file
@@ -0,0 +1 @@
|
|||||||
|
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||||
185
cassandra/numpy_parser.pyx
Normal file
185
cassandra/numpy_parser.pyx
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This module provides an optional protocol parser that returns
|
||||||
|
NumPy arrays.
|
||||||
|
|
||||||
|
=============================================================================
|
||||||
|
This module should not be imported by any of the main python-driver modules,
|
||||||
|
as numpy is an optional dependency.
|
||||||
|
=============================================================================
|
||||||
|
"""
|
||||||
|
|
||||||
|
include "ioutils.pyx"
|
||||||
|
|
||||||
|
cimport cython
|
||||||
|
from libc.stdint cimport uint64_t
|
||||||
|
from cpython.ref cimport Py_INCREF, PyObject
|
||||||
|
|
||||||
|
from cassandra.bytesio cimport BytesIOReader
|
||||||
|
from cassandra.deserializers cimport Deserializer, from_binary
|
||||||
|
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
|
||||||
|
from cassandra import cqltypes
|
||||||
|
from cassandra.util import is_little_endian
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
# import pandas as pd
|
||||||
|
|
||||||
|
cdef extern from "numpyFlags.h":
|
||||||
|
# Include 'numpyFlags.h' into the generated C code to disable the
|
||||||
|
# deprecated NumPy API
|
||||||
|
pass
|
||||||
|
|
||||||
|
cdef extern from "Python.h":
|
||||||
|
# An integer type large enough to hold a pointer
|
||||||
|
ctypedef uint64_t Py_uintptr_t
|
||||||
|
|
||||||
|
|
||||||
|
# Simple array descriptor, useful to parse rows into a NumPy array
|
||||||
|
ctypedef struct ArrDesc:
    # Address of the slot the next value is written to; unpack_row advances
    # it by 'stride' after every value
    Py_uintptr_t buf_ptr
    int stride # should be large enough as we allocate contiguous arrays
    # Non-zero when the column has no entry in _cqltype_to_numpy and is
    # therefore stored as Python objects
    int is_object
|
||||||
|
|
||||||
|
# NumPy record layout with the same three fields as the ArrDesc struct, so
# an array of this dtype can be viewed as ArrDesc[::1]
arrDescDtype = np.dtype([
    ('buf_ptr', np.uintp),
    ('stride', np.dtype('i')),
    ('is_object', np.dtype('i')),
])

# Column types that are copied verbatim into a typed (big-endian) array;
# every other type is stored in an object array
_cqltype_to_numpy = {
    cqltypes.LongType:          np.dtype('>i8'),
    cqltypes.CounterColumnType: np.dtype('>i8'),
    cqltypes.Int32Type:         np.dtype('>i4'),
    cqltypes.ShortType:         np.dtype('>i2'),
    cqltypes.FloatType:         np.dtype('>f4'),
    cqltypes.DoubleType:        np.dtype('>f8'),
}

# dtype used for columns without an entry in _cqltype_to_numpy
obj_dtype = np.dtype('O')
|
||||||
|
|
||||||
|
|
||||||
|
cdef class NumpyParser(ColumnParser):
    """Decode a ResultMessage into one NumPy array per column"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        """
        Read the row count, then all rows, from 'reader'.

        Returns a dict mapping each column name in 'desc' to its array,
        converted to native byte order.
        """
        cdef Py_ssize_t rowcount
        cdef ArrDesc[::1] array_descs
        cdef ArrDesc *arrs

        rowcount = read_int(reader)

        # Allocate the output columns and their write descriptors
        array_descs, columns = make_arrays(desc, rowcount)
        arrs = &array_descs[0]

        # Fill the columns row by row
        _parse_rows(reader, desc, arrs, rowcount)

        native_columns = [make_native_byteorder(column) for column in columns]
        return dict(zip(desc.colnames, native_columns))
|
||||||
|
|
||||||
|
|
||||||
|
cdef _parse_rows(BytesIOReader reader, ParseDesc desc,
                 ArrDesc *arrs, Py_ssize_t rowcount):
    """Unpack 'rowcount' consecutive rows from 'reader' into the arrays."""
    cdef Py_ssize_t remaining = rowcount

    while remaining > 0:
        unpack_row(reader, desc, arrs)
        remaining -= 1
|
||||||
|
|
||||||
|
|
||||||
|
### Helper functions to create NumPy arrays and array descriptors
|
||||||
|
|
||||||
|
def make_arrays(ParseDesc desc, array_size):
    """
    Allocate one array of 'array_size' entries for each result column.

    Returns a tuple (array_descs, arrays), where
        'array_descs' is a record array (dtype arrDescDtype) holding, for each
            column, the data address, the stride and whether the column is an
            object column, and
        'arrays' is the list of allocated arrays, in column order.
    """
    array_descs = np.empty((desc.rowsize,), arrDescDtype)
    arrays = [make_array(coltype, array_size) for coltype in desc.coltypes]

    for col, (coltype, column) in enumerate(zip(desc.coltypes, arrays)):
        descriptor = array_descs[col]
        descriptor['buf_ptr'] = column.ctypes.data
        descriptor['stride'] = column.strides[0]
        # Columns without a native NumPy dtype hold Python objects
        descriptor['is_object'] = coltype not in _cqltype_to_numpy

    return array_descs, arrays
|
||||||
|
|
||||||
|
|
||||||
|
def make_array(coltype, array_size):
    """
    Allocate a new, uninitialized one-dimensional NumPy array of
    'array_size' entries for the given column type.

    The dtype comes from _cqltype_to_numpy; column types without an entry
    there get an object array.
    """
    if coltype in _cqltype_to_numpy:
        column_dtype = _cqltype_to_numpy[coltype]
    else:
        column_dtype = obj_dtype
    return np.empty((array_size,), dtype=column_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
#### Parse rows into NumPy arrays
|
||||||
|
|
||||||
|
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline int unpack_row(
        BytesIOReader reader, ParseDesc desc, ArrDesc *arrays) except -1:
    """
    Read one row from 'reader' and store each value in the current slot of
    the matching column array, then advance that column's write pointer.

    Object columns are deserialized with from_binary, which also handles
    zero-sized and NULL (negative-sized) values.  Typed columns are copied
    byte-for-byte and cannot represent a missing value.

    Raises ValueError when a typed column holds a NULL or empty value.
    Returns 0; -1 is reserved for exception propagation.
    """
    cdef Buffer buf
    cdef Py_ssize_t i, rowsize = desc.rowsize
    cdef ArrDesc arr
    cdef Deserializer deserializer

    for i in range(rowsize):
        get_buf(reader, &buf)
        arr = arrays[i]

        if arr.is_object:
            deserializer = desc.deserializers[i]
            val = from_binary(deserializer, &buf, desc.protocol_version)
            # The array slot holds its own reference to the object
            Py_INCREF(val)
            (<PyObject **> arr.buf_ptr)[0] = <PyObject *> val
        elif buf.size <= 0:
            # get_buf reports NULL as a negative size (and sets ptr to NULL
            # for any non-positive size); copying nothing would silently
            # leave the slot uninitialized
            raise ValueError("Cannot handle NULL value")
        else:
            memcopy(buf.ptr, <char *> arr.buf_ptr, buf.size)

        # Update the pointer into the array for the next time
        arrays[i].buf_ptr += arr.stride

    return 0
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline void memcopy(char *src, char *dst, Py_ssize_t size):
    """
    Copy 'size' bytes from 'src' to 'dst', one byte at a time.

    A hand-written loop is used so the copy can be inlined, which is useful
    because the values copied here are only a few bytes long.  Nothing is
    copied when 'size' is zero or negative.
    """
    cdef Py_ssize_t offset = 0
    while offset < size:
        dst[offset] = src[offset]
        offset += 1
|
||||||
|
|
||||||
|
|
||||||
|
def make_native_byteorder(arr):
    """
    Return 'arr' with its values in native byte order.

    Object arrays, and any array on a machine that is not little endian, are
    returned unchanged.
    """
    if not is_little_endian or arr.dtype.kind == 'O':
        return arr

    # The data is big-endian: swap the bytes of every value, then flip the
    # dtype's byte order to match (e.g. from '>i8' to '<i8')
    swapped = arr.byteswap()
    return swapped.newbyteorder()
|
||||||
75
cassandra/obj_parser.pyx
Normal file
75
cassandra/obj_parser.pyx
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
include "ioutils.pyx"
|
||||||
|
|
||||||
|
from cassandra.bytesio cimport BytesIOReader
|
||||||
|
from cassandra.deserializers cimport Deserializer, from_binary
|
||||||
|
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
|
||||||
|
from cassandra.tuple cimport tuple_new, tuple_set
|
||||||
|
|
||||||
|
|
||||||
|
cdef class ListParser(ColumnParser):
    """Decode a ResultMessage into a list of tuples (or other objects)"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        """Read the row count, then eagerly unpack that many rows."""
        cdef Py_ssize_t remaining
        remaining = read_int(reader)
        cdef RowParser row_parser = TupleRowParser()
        cdef list rows = []
        while remaining > 0:
            rows.append(row_parser.unpack_row(reader, desc))
            remaining -= 1
        return rows
|
||||||
|
|
||||||
|
|
||||||
|
cdef class LazyParser(ColumnParser):
    """Decode a ResultMessage lazily, yielding rows on demand"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        # Closures (generators) are not supported in cpdef methods, so the
        # generator is built by a plain module-level function
        lazy_rows = parse_rows_lazy(reader, desc)
        return lazy_rows
|
||||||
|
|
||||||
|
|
||||||
|
def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
    """
    Return a generator that unpacks one row per iteration.

    The row count is read from 'reader' immediately; the rows themselves are
    only read as the generator is consumed.
    """
    cdef Py_ssize_t row_idx, total_rows
    total_rows = read_int(reader)
    cdef RowParser row_parser = TupleRowParser()
    return (row_parser.unpack_row(reader, desc) for row_idx in range(total_rows))
|
||||||
|
|
||||||
|
|
||||||
|
cdef class TupleRowParser(RowParser):
    """
    Parse a single returned row into a tuple of objects:

        (obj1, ..., objN)
    """

    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
        """
        Read desc.rowsize values from 'reader' and return them, deserialized,
        as a tuple in column order.
        """
        assert desc.rowsize >= 0

        cdef Buffer cell_buf
        cdef Py_ssize_t col, ncols = desc.rowsize
        cdef Deserializer column_deserializer
        cdef tuple row = tuple_new(desc.rowsize)

        for col in range(ncols):
            # Locate the raw bytes of the next value
            get_buf(reader, &cell_buf)

            # Turn them into a Python object with this column's deserializer
            column_deserializer = desc.deserializers[col]
            cell = from_binary(column_deserializer, &cell_buf, desc.protocol_version)

            # Store the object in its slot of the result tuple
            tuple_set(row, col, cell)

        return row
|
||||||
30
cassandra/parsing.pxd
Normal file
30
cassandra/parsing.pxd
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from cassandra.bytesio cimport BytesIOReader
|
||||||
|
from cassandra.deserializers cimport Deserializer
|
||||||
|
|
||||||
|
cdef class ParseDesc:
    # Column names and column types of the result being parsed
    cdef public object colnames
    cdef public object coltypes
    # One deserializer per column (contiguous object array)
    cdef Deserializer[::1] deserializers
    cdef public int protocol_version
    # Number of columns in a row
    cdef Py_ssize_t rowsize
|
||||||
|
|
||||||
|
cdef class ColumnParser:
    # Parse all rows available from 'reader' as described by 'desc'
    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc)
|
||||||
|
|
||||||
|
cdef class RowParser:
    # Parse a single row from 'reader' as described by 'desc'
    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc)
|
||||||
|
|
||||||
44
cassandra/parsing.pyx
Normal file
44
cassandra/parsing.pyx
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Module containing the definitions and declarations (parsing.pxd) for parsers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cdef class ParseDesc:
    """
    Description of what structure to parse: column names and types, the
    deserializer for each column, and the protocol version in use.
    """

    def __init__(self, colnames, coltypes, deserializers, protocol_version):
        # The number of columns per row is derived from the column names
        self.rowsize = len(colnames)
        self.protocol_version = protocol_version
        self.deserializers = deserializers
        self.coltypes = coltypes
        self.colnames = colnames
|
||||||
|
|
||||||
|
|
||||||
|
cdef class ColumnParser:
    """Decode a ResultMessage into a set of columns"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        """
        Parse all rows from 'reader' as described by 'desc'.

        Base implementation: always raises NotImplementedError.
        """
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
cdef class RowParser:
    """Parser for a single row"""

    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
        """
        Unpack a single row of data in a ResultMessage.

        Base implementation: always raises NotImplementedError.
        """
        raise NotImplementedError
|
||||||
@@ -22,6 +22,7 @@ import six
|
|||||||
from six.moves import range
|
from six.moves import range
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
from cassandra import type_codes
|
||||||
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
|
from cassandra import (Unavailable, WriteTimeout, ReadTimeout,
|
||||||
WriteFailure, ReadFailure, FunctionFailure,
|
WriteFailure, ReadFailure, FunctionFailure,
|
||||||
AlreadyExists, InvalidRequest, Unauthorized,
|
AlreadyExists, InvalidRequest, Unauthorized,
|
||||||
@@ -35,10 +36,11 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
|
|||||||
DoubleType, FloatType, Int32Type,
|
DoubleType, FloatType, Int32Type,
|
||||||
InetAddressType, IntegerType, ListType,
|
InetAddressType, IntegerType, ListType,
|
||||||
LongType, MapType, SetType, TimeUUIDType,
|
LongType, MapType, SetType, TimeUUIDType,
|
||||||
UTF8Type, UUIDType, UserType,
|
UTF8Type, VarcharType, UUIDType, UserType,
|
||||||
TupleType, lookup_casstype, SimpleDateType,
|
TupleType, lookup_casstype, SimpleDateType,
|
||||||
TimeType, ByteType, ShortType)
|
TimeType, ByteType, ShortType)
|
||||||
from cassandra.policies import WriteType
|
from cassandra.policies import WriteType
|
||||||
|
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
|
||||||
from cassandra import util
|
from cassandra import util
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -68,10 +70,16 @@ _message_types_by_opcode = {}
|
|||||||
|
|
||||||
_UNSET_VALUE = object()
|
_UNSET_VALUE = object()
|
||||||
|
|
||||||
|
def register_class(cls):
    """Record 'cls' as the message type for its opcode, replacing any previous entry."""
    opcode = cls.opcode
    _message_types_by_opcode[opcode] = cls
|
||||||
|
|
||||||
|
def get_registered_classes():
    """Return a shallow copy of the opcode -> message class registry."""
    return dict(_message_types_by_opcode)
|
||||||
|
|
||||||
class _RegisterMessageType(type):
|
class _RegisterMessageType(type):
|
||||||
def __init__(cls, name, bases, dct):
|
def __init__(cls, name, bases, dct):
|
||||||
if not name.startswith('_'):
|
if not name.startswith('_'):
|
||||||
_message_types_by_opcode[cls.opcode] = cls
|
register_class(cls)
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(_RegisterMessageType)
|
@six.add_metaclass(_RegisterMessageType)
|
||||||
@@ -531,35 +539,6 @@ RESULT_KIND_SET_KEYSPACE = 0x0003
|
|||||||
RESULT_KIND_PREPARED = 0x0004
|
RESULT_KIND_PREPARED = 0x0004
|
||||||
RESULT_KIND_SCHEMA_CHANGE = 0x0005
|
RESULT_KIND_SCHEMA_CHANGE = 0x0005
|
||||||
|
|
||||||
class CassandraTypeCodes(object):
|
|
||||||
CUSTOM_TYPE = 0x0000
|
|
||||||
AsciiType = 0x0001
|
|
||||||
LongType = 0x0002
|
|
||||||
BytesType = 0x0003
|
|
||||||
BooleanType = 0x0004
|
|
||||||
CounterColumnType = 0x0005
|
|
||||||
DecimalType = 0x0006
|
|
||||||
DoubleType = 0x0007
|
|
||||||
FloatType = 0x0008
|
|
||||||
Int32Type = 0x0009
|
|
||||||
UTF8Type = 0x000A
|
|
||||||
DateType = 0x000B
|
|
||||||
UUIDType = 0x000C
|
|
||||||
UTF8Type = 0x000D
|
|
||||||
IntegerType = 0x000E
|
|
||||||
TimeUUIDType = 0x000F
|
|
||||||
InetAddressType = 0x0010
|
|
||||||
SimpleDateType = 0x0011
|
|
||||||
TimeType = 0x0012
|
|
||||||
ShortType = 0x0013
|
|
||||||
ByteType = 0x0014
|
|
||||||
ListType = 0x0020
|
|
||||||
MapType = 0x0021
|
|
||||||
SetType = 0x0022
|
|
||||||
UserType = 0x0030
|
|
||||||
TupleType = 0x0031
|
|
||||||
|
|
||||||
|
|
||||||
class ResultMessage(_MessageType):
|
class ResultMessage(_MessageType):
|
||||||
opcode = 0x08
|
opcode = 0x08
|
||||||
name = 'RESULT'
|
name = 'RESULT'
|
||||||
@@ -569,7 +548,7 @@ class ResultMessage(_MessageType):
|
|||||||
paging_state = None
|
paging_state = None
|
||||||
|
|
||||||
# Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
|
# Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
|
||||||
type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in CassandraTypeCodes.__dict__.items() if not k.startswith('_'))
|
type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))
|
||||||
|
|
||||||
_FLAGS_GLOBAL_TABLES_SPEC = 0x0001
|
_FLAGS_GLOBAL_TABLES_SPEC = 0x0001
|
||||||
_HAS_MORE_PAGES_FLAG = 0x0002
|
_HAS_MORE_PAGES_FLAG = 0x0002
|
||||||
@@ -1017,6 +996,63 @@ class ProtocolHandler(object):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def cython_protocol_handler(colparser):
    """
    Given a column parser to deserialize ResultMessages, return a suitable
    Cython-based protocol handler.

    There are three Cython-based protocol handlers (least to most performant):

        1. obj_parser.ListParser
            this parser decodes result messages into a list of tuples

        2. obj_parser.LazyParser
            this parser decodes result messages lazily by returning an iterator

        3. numpy_parser.NumpyParser
            this parser decodes result messages into NumPy arrays

    The default is to use obj_parser.ListParser
    """
    # Imported here so that this module can be loaded without the compiled
    # extension
    from cassandra.row_parser import make_recv_results_rows

    # NOTE(review): if ResultMessage uses the _RegisterMessageType metaclass,
    # defining this subclass also re-registers its opcode in the module-wide
    # registry -- confirm that this is intended.
    class FastResultMessage(ResultMessage):
        """
        Cython version of Result Message that has a faster implementation of
        recv_results_rows.
        """
        # type_codes = ResultMessage.type_codes.copy()
        # Reverse mapping of ResultMessage.type_codes (type -> code)
        code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items())
        recv_results_rows = classmethod(make_recv_results_rows(colparser))

    class CythonProtocolHandler(ProtocolHandler):
        """
        Use FastResultMessage to decode query result messages.
        """

        # Copy the opcode table so the base handler's table is left untouched
        my_opcodes = ProtocolHandler.message_types_by_opcode.copy()
        my_opcodes[FastResultMessage.opcode] = FastResultMessage
        message_types_by_opcode = my_opcodes

    return CythonProtocolHandler
|
||||||
|
|
||||||
|
|
||||||
|
# When the compiled extensions are available, replace the default handler
# with the list-based Cython handler and also expose a lazy variant
if HAVE_CYTHON:
    from cassandra.obj_parser import ListParser, LazyParser
    ProtocolHandler = cython_protocol_handler(ListParser())
    LazyProtocolHandler = cython_protocol_handler(LazyParser())
else:
    # Use Python-based ProtocolHandler
    LazyProtocolHandler = None


# The NumPy-based handler additionally requires NumPy
if HAVE_CYTHON and HAVE_NUMPY:
    from cassandra.numpy_parser import NumpyParser
    NumpyProtocolHandler = cython_protocol_handler(NumpyParser())
else:
    NumpyProtocolHandler = None
|
||||||
|
|
||||||
|
|
||||||
def read_byte(f):
|
def read_byte(f):
|
||||||
return int8_unpack(f.read(1))
|
return int8_unpack(f.read(1))
|
||||||
|
|
||||||
|
|||||||
38
cassandra/row_parser.pyx
Normal file
38
cassandra/row_parser.pyx
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from cassandra.parsing cimport ParseDesc, ColumnParser
|
||||||
|
from cassandra.deserializers import make_deserializers
|
||||||
|
|
||||||
|
include "ioutils.pyx"
|
||||||
|
|
||||||
|
def make_recv_results_rows(ColumnParser colparser):
    """
    Build a ``recv_results_rows`` function that decodes rows with ``colparser``.

    The returned function takes ``cls`` as its first argument so that it can
    be wrapped in ``classmethod`` and installed on a ResultMessage subclass.
    """
    def recv_results_rows(cls, f, int protocol_version, user_type_map):
        """
        Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples)
        This is used as the recv_results_rows method of (Fast)ResultMessage
        """
        paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)

        # Every metadata entry holds the column name at index 2 and the
        # column type at index 3.
        colnames = []
        coltypes = []
        for column in column_metadata:
            colnames.append(column[2])
            coltypes.append(column[3])

        deserializers = make_deserializers(coltypes)
        desc = ParseDesc(colnames, coltypes, deserializers, protocol_version)

        # Everything left in f after the metadata is handed to the parser.
        rows = colparser.parse_rows(BytesIOReader(f.read()), desc)
        return (paging_state, (colnames, rows))

    return recv_results_rows
|
||||||
41
cassandra/tuple.pxd
Normal file
41
cassandra/tuple.pxd
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from cpython.tuple cimport (
|
||||||
|
PyTuple_New,
|
||||||
|
# Return value: New reference.
|
||||||
|
# Return a new tuple object of size len, or NULL on failure.
|
||||||
|
PyTuple_SET_ITEM,
|
||||||
|
# Like PyTuple_SetItem(), but does no error checking, and should
|
||||||
|
# only be used to fill in brand new tuples. Note: This function
|
||||||
|
# ``steals'' a reference to o.
|
||||||
|
)
|
||||||
|
|
||||||
|
from cpython.ref cimport (
|
||||||
|
Py_INCREF
|
||||||
|
# void Py_INCREF(object o)
|
||||||
|
# Increment the reference count for object o. The object must not
|
||||||
|
# be NULL; if you aren't sure that it isn't NULL, use
|
||||||
|
# Py_XINCREF().
|
||||||
|
)
|
||||||
|
|
||||||
|
cdef inline tuple tuple_new(Py_ssize_t n):
    """
    Allocate a new tuple object of size n.

    Thin wrapper around PyTuple_New. The slots are expected to be filled
    afterwards with tuple_set().
    """
    return PyTuple_New(n)
|
||||||
|
|
||||||
|
cdef inline void tuple_set(tuple tup, Py_ssize_t idx, object item):
    """
    Insert new object into tuple. No item must have been set yet.

    Intended only for filling a brand new tuple (e.g. one from tuple_new):
    PyTuple_SET_ITEM performs no error checking on tup or idx.
    """
    # PyTuple_SET_ITEM steals a reference, so we need to INCREF
    # first; the tuple then holds its own reference to item.
    Py_INCREF(item)
    PyTuple_SET_ITEM(tup, idx, item)
|
||||||
42
cassandra/type_codes.pxd
Normal file
42
cassandra/type_codes.pxd
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Compile-time integer constants naming the Cassandra column types, mirroring
# the names defined in cassandra/type_codes.py.
#
# NOTE(review): no explicit values are given, so the enumerators are numbered
# consecutively from 0 (CUSTOM_TYPE == 0 ... TupleType == 25). That agrees with
# cassandra/type_codes.py up to ByteType (0x0014 == 20), but type_codes.py sets
# ListType/MapType/SetType to 0x0020-0x0022 and UserType/TupleType to
# 0x0030/0x0031 -- confirm whether these constants are meant to equal the
# protocol codes before comparing them with values read off the wire.
cdef enum:
    CUSTOM_TYPE
    AsciiType
    LongType
    BytesType
    BooleanType
    CounterColumnType
    DecimalType
    DoubleType
    FloatType
    Int32Type
    UTF8Type
    DateType
    UUIDType
    VarcharType
    IntegerType
    TimeUUIDType
    InetAddressType
    SimpleDateType
    TimeType
    ShortType
    ByteType
    ListType
    MapType
    SetType
    UserType
    TupleType
|
||||||
|
|
||||||
62
cassandra/type_codes.py
Normal file
62
cassandra/type_codes.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
Module with constants for Cassandra type codes.

These constants are useful for

    a) mapping messages to cqltypes (cassandra/cqltypes.py)
    b) optimized dispatching for (de)serialization (cassandra/encoding.py)

Type codes are repeated here from the Cassandra binary protocol specification:

    0x0000    Custom: the value is a [string], see above.
    0x0001    Ascii
    0x0002    Bigint
    0x0003    Blob
    0x0004    Boolean
    0x0005    Counter
    0x0006    Decimal
    0x0007    Double
    0x0008    Float
    0x0009    Int
    0x000A    Text
    0x000B    Timestamp
    0x000C    Uuid
    0x000D    Varchar
    0x000E    Varint
    0x000F    Timeuuid
    0x0010    Inet
    0x0011    Date
    0x0012    Time
    0x0013    Smallint
    0x0014    Tinyint
    0x0020    List: the value is an [option], representing the type
                of the elements of the list.
    0x0021    Map: the value is two [option], representing the types of the
               keys and values of the map
    0x0022    Set: the value is an [option], representing the type
                of the elements of the set
    0x0030    UDT
    0x0031    Tuple
"""

# Primitive types (0x0000 - 0x0014). The constant names follow the
# corresponding cqltypes class names rather than the CQL type names.
CUSTOM_TYPE = 0x0000
AsciiType = 0x0001
LongType = 0x0002
BytesType = 0x0003
BooleanType = 0x0004
CounterColumnType = 0x0005
DecimalType = 0x0006
DoubleType = 0x0007
FloatType = 0x0008
Int32Type = 0x0009
UTF8Type = 0x000A
DateType = 0x000B
UUIDType = 0x000C
VarcharType = 0x000D
IntegerType = 0x000E
TimeUUIDType = 0x000F
InetAddressType = 0x0010
SimpleDateType = 0x0011
TimeType = 0x0012
ShortType = 0x0013
ByteType = 0x0014
# Collection types.
ListType = 0x0020
MapType = 0x0021
SetType = 0x0022
# User-defined and tuple types.
UserType = 0x0030
TupleType = 0x0031
|
||||||
|
|
||||||
@@ -1,12 +1,29 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import with_statement
|
from __future__ import with_statement
|
||||||
import calendar
|
import calendar
|
||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
import six
|
import six
|
||||||
import uuid
|
import uuid
|
||||||
|
import sys
|
||||||
|
|
||||||
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
|
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
|
||||||
|
|
||||||
|
assert sys.byteorder in ('little', 'big')
|
||||||
|
is_little_endian = sys.byteorder == 'little'
|
||||||
|
|
||||||
def datetime_from_timestamp(timestamp):
|
def datetime_from_timestamp(timestamp):
|
||||||
"""
|
"""
|
||||||
@@ -490,8 +507,7 @@ except ImportError:
|
|||||||
|
|
||||||
def __init__(self, iterable=()):
|
def __init__(self, iterable=()):
|
||||||
self._items = []
|
self._items = []
|
||||||
for i in iterable:
|
self.update(iterable)
|
||||||
self.add(i)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._items)
|
return len(self._items)
|
||||||
@@ -564,6 +580,10 @@ except ImportError:
|
|||||||
else:
|
else:
|
||||||
self._items.append(item)
|
self._items.append(item)
|
||||||
|
|
||||||
|
def update(self, iterable):
|
||||||
|
for i in iterable:
|
||||||
|
self.add(i)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
del self._items[:]
|
del self._items[:]
|
||||||
|
|
||||||
|
|||||||
@@ -24,3 +24,17 @@ See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseF
|
|||||||
.. automethod:: encode_message
|
.. automethod:: encode_message
|
||||||
|
|
||||||
.. automethod:: decode_message
|
.. automethod:: decode_message
|
||||||
|
|
||||||
|
Faster Deserialization
|
||||||
|
----------------------
|
||||||
|
When python-driver is compiled with Cython, it uses a Cython-based deserialization path
|
||||||
|
to deserialize messages. There are two additional ProtocolHandler classes that can be
|
||||||
|
used to deserialize response messages: the first is ``LazyProtocolHandler`` and the
|
||||||
|
second is ``NumpyProtocolHandler``. They can be used as follows:
|
||||||
|
|
||||||
|
.. code:: python
|
||||||
|
|
||||||
|
from cassandra.protocol import NumpyProtocolHandler, LazyProtocolHandler
|
||||||
|
s.client_protocol_handler = LazyProtocolHandler # for a result iterator
|
||||||
|
s.client_protocol_handler = NumpyProtocolHandler # for a dict of NumPy arrays as result
|
||||||
|
|
||||||
|
|||||||
11
setup.py
11
setup.py
@@ -37,7 +37,6 @@ from distutils.errors import (CCompilerError, DistutilsPlatformError,
|
|||||||
DistutilsExecError)
|
DistutilsExecError)
|
||||||
from distutils.cmd import Command
|
from distutils.cmd import Command
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import subprocess
|
import subprocess
|
||||||
has_subprocess = True
|
has_subprocess = True
|
||||||
@@ -71,6 +70,7 @@ if __name__ == '__main__' and sys.argv[1] == "install":
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
PROFILING = False
|
||||||
|
|
||||||
class DocCommand(Command):
|
class DocCommand(Command):
|
||||||
|
|
||||||
@@ -263,11 +263,16 @@ if "--no-libev" not in sys.argv and not is_windows:
|
|||||||
if "--no-cython" not in sys.argv:
|
if "--no-cython" not in sys.argv:
|
||||||
try:
|
try:
|
||||||
from Cython.Build import cythonize
|
from Cython.Build import cythonize
|
||||||
cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'marshal', 'metadata', 'pool', 'protocol', 'query', 'util']
|
cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata',
|
||||||
|
'pool', 'protocol', 'query', 'util']
|
||||||
compile_args = [] if is_windows else ['-Wno-unused-function']
|
compile_args = [] if is_windows else ['-Wno-unused-function']
|
||||||
extensions.extend(cythonize(
|
extensions.extend(cythonize(
|
||||||
[Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], extra_compile_args=compile_args) for m in cython_candidates],
|
[Extension('cassandra.%s' % m, ['cassandra/%s.py' % m],
|
||||||
|
extra_compile_args=compile_args)
|
||||||
|
for m in cython_candidates],
|
||||||
exclude_failures=True))
|
exclude_failures=True))
|
||||||
|
extensions.extend(cythonize("cassandra/*.pyx"))
|
||||||
|
extensions.extend(cythonize("tests/unit/cython/*.pyx"))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
sys.stderr.write("Cython is not installed. Not compiling core driver files as extensions (optional).")
|
sys.stderr.write("Cython is not installed. Not compiling core driver files as extensions (optional).")
|
||||||
|
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ def remove_cluster():
|
|||||||
CCM_CLUSTER.remove()
|
CCM_CLUSTER.remove()
|
||||||
CCM_CLUSTER = None
|
CCM_CLUSTER = None
|
||||||
return
|
return
|
||||||
except WindowsError:
|
except OSError:
|
||||||
ex_type, ex, tb = sys.exc_info()
|
ex_type, ex, tb = sys.exc_info()
|
||||||
log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb)))
|
log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb)))
|
||||||
del tb
|
del tb
|
||||||
|
|||||||
@@ -386,7 +386,8 @@ class TestMapColumn(BaseCassEngTestCase):
|
|||||||
k2 = uuid4()
|
k2 = uuid4()
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
then = now + timedelta(days=1)
|
then = now + timedelta(days=1)
|
||||||
m1 = TestMapModel.create(int_map={1: k1, 2: k2}, text_map={'now': now, 'then': then})
|
m1 = TestMapModel.create(int_map={1: k1, 2: k2},
|
||||||
|
text_map={'now': now, 'then': then})
|
||||||
m2 = TestMapModel.get(partition=m1.partition)
|
m2 = TestMapModel.get(partition=m1.partition)
|
||||||
|
|
||||||
self.assertTrue(isinstance(m2.int_map, dict))
|
self.assertTrue(isinstance(m2.int_map, dict))
|
||||||
|
|||||||
@@ -629,7 +629,7 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
|
|||||||
# test kwarg filtering
|
# test kwarg filtering
|
||||||
q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
|
q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
|
||||||
q = [d for d in q]
|
q = [d for d in q]
|
||||||
assert len(q) == 2
|
self.assertEqual(len(q), 2, msg="Got: %s" % q)
|
||||||
datas = [d.data for d in q]
|
datas = [d.data for d in q]
|
||||||
assert '1' in datas
|
assert '1' in datas
|
||||||
assert '2' in datas
|
assert '2' in datas
|
||||||
|
|||||||
@@ -52,17 +52,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||||
assert row.cluster == i
|
self.assertEqual(row.cluster, i)
|
||||||
assert row.count == i
|
self.assertEqual(row.count, i)
|
||||||
assert row.text == str(i)
|
self.assertEqual(row.text, str(i))
|
||||||
|
|
||||||
# perform update
|
# perform update
|
||||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)
|
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)
|
||||||
|
|
||||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||||
assert row.cluster == i
|
self.assertEqual(row.cluster, i)
|
||||||
assert row.count == (6 if i == 3 else i)
|
self.assertEqual(row.count, 6 if i == 3 else i)
|
||||||
assert row.text == str(i)
|
self.assertEqual(row.text, str(i))
|
||||||
|
|
||||||
def test_update_values_validation(self):
|
def test_update_values_validation(self):
|
||||||
""" tests calling udpate on models with values passed in """
|
""" tests calling udpate on models with values passed in """
|
||||||
@@ -72,9 +72,9 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||||
assert row.cluster == i
|
self.assertEqual(row.cluster, i)
|
||||||
assert row.count == i
|
self.assertEqual(row.count, i)
|
||||||
assert row.text == str(i)
|
self.assertEqual(row.text, str(i))
|
||||||
|
|
||||||
# perform update
|
# perform update
|
||||||
with self.assertRaises(ValidationError):
|
with self.assertRaises(ValidationError):
|
||||||
@@ -98,17 +98,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||||
assert row.cluster == i
|
self.assertEqual(row.cluster, i)
|
||||||
assert row.count == i
|
self.assertEqual(row.count, i)
|
||||||
assert row.text == str(i)
|
self.assertEqual(row.text, str(i))
|
||||||
|
|
||||||
# perform update
|
# perform update
|
||||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None)
|
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None)
|
||||||
|
|
||||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||||
assert row.cluster == i
|
self.assertEqual(row.cluster, i)
|
||||||
assert row.count == i
|
self.assertEqual(row.count, i)
|
||||||
assert row.text == (None if i == 3 else str(i))
|
self.assertEqual(row.text, None if i == 3 else str(i))
|
||||||
|
|
||||||
def test_mixed_value_and_null_update(self):
|
def test_mixed_value_and_null_update(self):
|
||||||
""" tests that updating a columns value, and removing another works properly """
|
""" tests that updating a columns value, and removing another works properly """
|
||||||
@@ -118,17 +118,17 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||||
assert row.cluster == i
|
self.assertEqual(row.cluster, i)
|
||||||
assert row.count == i
|
self.assertEqual(row.count, i)
|
||||||
assert row.text == str(i)
|
self.assertEqual(row.text, str(i))
|
||||||
|
|
||||||
# perform update
|
# perform update
|
||||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None)
|
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None)
|
||||||
|
|
||||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||||
assert row.cluster == i
|
self.assertEqual(row.cluster, i)
|
||||||
assert row.count == (6 if i == 3 else i)
|
self.assertEqual(row.count, 6 if i == 3 else i)
|
||||||
assert row.text == (None if i == 3 else str(i))
|
self.assertEqual(row.text, None if i == 3 else str(i))
|
||||||
|
|
||||||
def test_counter_updates(self):
|
def test_counter_updates(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -85,11 +85,11 @@ class SchemaTests(unittest.TestCase):
|
|||||||
|
|
||||||
session = self.session
|
session = self.session
|
||||||
|
|
||||||
for i in xrange(30):
|
for i in range(30):
|
||||||
execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i))
|
execute_until_pass(session, "CREATE KEYSPACE test_{0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(i))
|
||||||
execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i))
|
execute_until_pass(session, "CREATE TABLE test_{0}.cf (key int PRIMARY KEY, value int)".format(i))
|
||||||
|
|
||||||
for j in xrange(100):
|
for j in range(100):
|
||||||
execute_until_pass(session, "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format(i, j))
|
execute_until_pass(session, "INSERT INTO test_{0}.cf (key, value) VALUES ({1}, {1})".format(i, j))
|
||||||
|
|
||||||
execute_until_pass(session, "DROP KEYSPACE test_{0}".format(i))
|
execute_until_pass(session, "DROP KEYSPACE test_{0}".format(i))
|
||||||
@@ -102,7 +102,7 @@ class SchemaTests(unittest.TestCase):
|
|||||||
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
session = cluster.connect()
|
session = cluster.connect()
|
||||||
|
|
||||||
for i in xrange(30):
|
for i in range(30):
|
||||||
try:
|
try:
|
||||||
execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}")
|
execute_until_pass(session, "CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}")
|
||||||
except AlreadyExists:
|
except AlreadyExists:
|
||||||
@@ -111,7 +111,7 @@ class SchemaTests(unittest.TestCase):
|
|||||||
|
|
||||||
execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)")
|
execute_until_pass(session, "CREATE TABLE test.cf (key int PRIMARY KEY, value int)")
|
||||||
|
|
||||||
for j in xrange(100):
|
for j in range(100):
|
||||||
execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j))
|
execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j))
|
||||||
|
|
||||||
execute_until_pass(session, "DROP KEYSPACE test")
|
execute_until_pass(session, "DROP KEYSPACE test")
|
||||||
|
|||||||
@@ -11,6 +11,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
try:
|
||||||
|
import unittest2 as unittest
|
||||||
|
except ImportError:
|
||||||
|
import unittest # noqa
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ccmlib import common
|
from ccmlib import common
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from cassandra.query import tuple_factory, SimpleStatement
|
|||||||
|
|
||||||
from tests.integration import use_singledc, PROTOCOL_VERSION
|
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||||
|
|
||||||
|
from six import next
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import unittest2 as unittest
|
import unittest2 as unittest
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ from cassandra.protocol import ProtocolHandler, ResultMessage, UUIDType, read_in
|
|||||||
from cassandra.query import tuple_factory
|
from cassandra.query import tuple_factory
|
||||||
from cassandra.cluster import Cluster
|
from cassandra.cluster import Cluster
|
||||||
from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass
|
from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass
|
||||||
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, get_sample
|
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES
|
||||||
|
from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params
|
||||||
from six import binary_type
|
from six import binary_type
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
@@ -62,7 +63,7 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Ensure that we get normal uuid back first
|
# Ensure that we get normal uuid back first
|
||||||
session = Cluster().connect()
|
session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes")
|
||||||
session.row_factory = tuple_factory
|
session.row_factory = tuple_factory
|
||||||
result_set = session.execute("SELECT schema_version FROM system.local")
|
result_set = session.execute("SELECT schema_version FROM system.local")
|
||||||
result = result_set.pop()
|
result = result_set.pop()
|
||||||
@@ -102,15 +103,16 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
|||||||
@test_category data_types:serialization
|
@test_category data_types:serialization
|
||||||
"""
|
"""
|
||||||
# Connect using a custom protocol handler that tracks the various types the result message is used with.
|
# Connect using a custom protocol handler that tracks the various types the result message is used with.
|
||||||
session = Cluster().connect(keyspace="custserdes")
|
session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes")
|
||||||
session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked
|
session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked
|
||||||
session.row_factory = tuple_factory
|
session.row_factory = tuple_factory
|
||||||
|
|
||||||
columns_string = create_table_with_all_types("test_table", session)
|
colnames = create_table_with_all_types("alltypes", session, 1)
|
||||||
|
columns_string = ", ".join(colnames)
|
||||||
|
|
||||||
# verify data
|
# verify data
|
||||||
params = get_all_primitive_params()
|
params = get_all_primitive_params(0)
|
||||||
results = session.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string))[0]
|
results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string))[0]
|
||||||
for expected, actual in zip(params, results):
|
for expected, actual in zip(params, results):
|
||||||
self.assertEqual(actual, expected)
|
self.assertEqual(actual, expected)
|
||||||
# Ensure we have covered the various primitive types
|
# Ensure we have covered the various primitive types
|
||||||
@@ -118,43 +120,6 @@ class CustomProtocolHandlerTest(unittest.TestCase):
|
|||||||
session.shutdown()
|
session.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def create_table_with_all_types(table_name, session):
|
|
||||||
"""
|
|
||||||
Method that given a table_name and session construct a table that contains all possible primitive types
|
|
||||||
:param table_name: Name of table to create
|
|
||||||
:param session: session to use for table creation
|
|
||||||
:return: a string containing and columns. This can be used to query the table.
|
|
||||||
"""
|
|
||||||
# create table
|
|
||||||
alpha_type_list = ["zz int PRIMARY KEY"]
|
|
||||||
col_names = ["zz"]
|
|
||||||
start_index = ord('a')
|
|
||||||
for i, datatype in enumerate(PRIMITIVE_DATATYPES):
|
|
||||||
alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype))
|
|
||||||
col_names.append(chr(start_index + i))
|
|
||||||
|
|
||||||
session.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list)), timeout=120)
|
|
||||||
|
|
||||||
# create the input
|
|
||||||
params = get_all_primitive_params()
|
|
||||||
|
|
||||||
# insert into table as a simple statement
|
|
||||||
columns_string = ', '.join(col_names)
|
|
||||||
placeholders = ', '.join(["%s"] * len(col_names))
|
|
||||||
session.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params, timeout=120)
|
|
||||||
return columns_string
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_primitive_params():
|
|
||||||
"""
|
|
||||||
Simple utility method used to give back a list of all possible primitive data sample types.
|
|
||||||
"""
|
|
||||||
params = [0]
|
|
||||||
for datatype in PRIMITIVE_DATATYPES:
|
|
||||||
params.append((get_sample(datatype)))
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
class CustomResultMessageRaw(ResultMessage):
|
class CustomResultMessageRaw(ResultMessage):
|
||||||
"""
|
"""
|
||||||
This is a custom Result Message that is used to return raw results, rather then
|
This is a custom Result Message that is used to return raw results, rather then
|
||||||
|
|||||||
129
tests/integration/standard/test_cython_protocol_handlers.py
Normal file
129
tests/integration/standard/test_cython_protocol_handlers.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""Test the various Cython-based message deserializers"""
|
||||||
|
|
||||||
|
# Based on test_custom_protocol_handler.py
|
||||||
|
|
||||||
|
try:
|
||||||
|
import unittest2 as unittest
|
||||||
|
except ImportError:
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from cassandra.query import tuple_factory
|
||||||
|
from cassandra.cluster import Cluster
|
||||||
|
from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler
|
||||||
|
|
||||||
|
from tests.integration import use_singledc, PROTOCOL_VERSION
|
||||||
|
from tests.integration.datatype_utils import update_datatypes
|
||||||
|
from tests.integration.standard.utils import (
|
||||||
|
create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes)
|
||||||
|
|
||||||
|
from tests.unit.cython.utils import cythontest, numpytest
|
||||||
|
|
||||||
|
def setup_module():
    """Module-level fixture: prepare the test cluster and the datatype list."""
    # Both helpers come from tests.integration; they are run once before any
    # test class in this module.
    use_singledc()
    update_datatypes()
|
||||||
|
|
||||||
|
|
||||||
|
class CythonProtocolHandlerTest(unittest.TestCase):
|
||||||
|
|
||||||
|
N_ITEMS = 10
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION)
|
||||||
|
cls.session = cls.cluster.connect()
|
||||||
|
cls.session.execute("CREATE KEYSPACE testspace WITH replication = "
|
||||||
|
"{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}")
|
||||||
|
cls.session.set_keyspace("testspace")
|
||||||
|
cls.colnames = create_table_with_all_types("test_table", cls.session, cls.N_ITEMS)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
cls.session.execute("DROP KEYSPACE testspace")
|
||||||
|
cls.cluster.shutdown()
|
||||||
|
|
||||||
|
@cythontest
|
||||||
|
def test_cython_parser(self):
|
||||||
|
"""
|
||||||
|
Test Cython-based parser that returns a list of tuples
|
||||||
|
"""
|
||||||
|
verify_iterator_data(self.assertEqual, get_data(ProtocolHandler))
|
||||||
|
|
||||||
|
@cythontest
|
||||||
|
def test_cython_lazy_parser(self):
|
||||||
|
"""
|
||||||
|
Test Cython-based parser that returns an iterator of tuples
|
||||||
|
"""
|
||||||
|
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
|
||||||
|
|
||||||
|
@numpytest
|
||||||
|
def test_numpy_parser(self):
|
||||||
|
"""
|
||||||
|
Test Numpy-based parser that returns a NumPy array
|
||||||
|
"""
|
||||||
|
# arrays = { 'a': arr1, 'b': arr2, ... }
|
||||||
|
arrays = get_data(NumpyProtocolHandler)
|
||||||
|
|
||||||
|
colnames = self.colnames
|
||||||
|
datatypes = get_primitive_datatypes()
|
||||||
|
for colname, datatype in zip(colnames, datatypes):
|
||||||
|
arr = arrays[colname]
|
||||||
|
self.match_dtype(datatype, arr.dtype)
|
||||||
|
|
||||||
|
verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames))
|
||||||
|
|
||||||
|
def match_dtype(self, datatype, dtype):
|
||||||
|
"""Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
|
||||||
|
if datatype == 'smallint':
|
||||||
|
self.match_dtype_props(dtype, 'i', 2)
|
||||||
|
elif datatype == 'int':
|
||||||
|
self.match_dtype_props(dtype, 'i', 4)
|
||||||
|
elif datatype in ('bigint', 'counter'):
|
||||||
|
self.match_dtype_props(dtype, 'i', 8)
|
||||||
|
elif datatype == 'float':
|
||||||
|
self.match_dtype_props(dtype, 'f', 4)
|
||||||
|
elif datatype == 'double':
|
||||||
|
self.match_dtype_props(dtype, 'f', 8)
|
||||||
|
else:
|
||||||
|
self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype))
|
||||||
|
|
||||||
|
def match_dtype_props(self, dtype, kind, size, signed=None):
    """
    Assert that ``dtype.kind`` equals ``kind`` and ``dtype.itemsize`` equals ``size``.

    ``signed`` is accepted but not used by this method.
    """
    for attribute, expected in (('kind', kind), ('itemsize', size)):
        self.assertEqual(getattr(dtype, attribute), expected, msg=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def arrays_to_list_of_tuples(arrays, colnames):
    """
    Convert a dict of arrays (as given by the numpy protocol handler) to a list of tuples.

    The number of rows is taken from the array of the first column; each row
    holds one value per column, in the order given by ``colnames``.
    """
    row_count = len(arrays[colnames[0]])
    rows = []
    for row_index in range(row_count):
        rows.append(tuple(arrays[name][row_index] for name in colnames))
    return rows
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(protocol_handler):
    """
    Get all rows of ``test_table`` from the ``testspace`` keyspace.

    A new cluster connection is opened for the query and shut down again
    before the result is returned.

    :param protocol_handler: protocol handler to install on the session as
        ``client_protocol_handler``
    :return: the result of ``session.execute`` for the query, with
        ``tuple_factory`` installed as the session's row factory
    """
    cluster = Cluster(protocol_version=PROTOCOL_VERSION)
    try:
        session = cluster.connect(keyspace="testspace")

        # use our custom protocol handler
        session.client_protocol_handler = protocol_handler
        session.row_factory = tuple_factory

        return session.execute("SELECT * FROM test_table")
    finally:
        # Shut down the whole cluster, not only the session: per the driver
        # documentation this closes every session and connection it owns.
        # Doing it in ``finally`` also covers a failing connect or query.
        cluster.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_iterator_data(assertEqual, results):
    """
    Check the result of get_data() when it is a list or iterator of tuples.

    The first element of each row is passed to get_all_primitive_params() to
    build the expected row, which is then compared value by value.

    :param assertEqual: assertion callable, e.g. ``TestCase.assertEqual``
    :param results: iterable of row tuples
    """
    for row in results:
        expected_row = get_all_primitive_params(row[0])
        assertEqual(len(expected_row), len(row),
                    msg="Not the right number of columns?")
        for want, got in zip(expected_row, row):
            assertEqual(got, want)
|
||||||
@@ -301,8 +301,8 @@ class BatchStatementTests(unittest.TestCase):
|
|||||||
keys.add(result.k)
|
keys.add(result.k)
|
||||||
values.add(result.v)
|
values.add(result.v)
|
||||||
|
|
||||||
self.assertEqual(set(range(10)), keys)
|
self.assertEqual(set(range(10)), keys, msg=results)
|
||||||
self.assertEqual(set(range(10)), values)
|
self.assertEqual(set(range(10)), values, msg=results)
|
||||||
|
|
||||||
def test_string_statements(self):
|
def test_string_statements(self):
|
||||||
batch = BatchStatement(BatchType.LOGGED)
|
batch = BatchStatement(BatchType.LOGGED)
|
||||||
@@ -370,6 +370,11 @@ class BatchStatementTests(unittest.TestCase):
|
|||||||
self.session.execute(batch)
|
self.session.execute(batch)
|
||||||
self.confirm_results()
|
self.confirm_results()
|
||||||
|
|
||||||
|
def test_no_parameters_many_times(self):
|
||||||
|
for i in range(1000):
|
||||||
|
self.test_no_parameters()
|
||||||
|
self.session.execute("TRUNCATE test3rf.test")
|
||||||
|
|
||||||
|
|
||||||
class SerialConsistencyTests(unittest.TestCase):
|
class SerialConsistencyTests(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
53
tests/integration/standard/utils.py
Normal file
53
tests/integration/standard/utils.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
Helper module to populate dummy Cassandra tables with data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample
|
||||||
|
|
||||||
|
def create_table_with_all_types(table_name, session, N):
    """
    Create a table with one column per primitive type and fill it with N rows.

    The table has an ``int`` primary key column named ``primkey`` followed by
    one column per entry of ``PRIMITIVE_DATATYPES``, named 'a', 'b', 'c', ...
    in iteration order.

    :param table_name: Name of table to create
    :param session: session to use for table creation
    :param N: the number of items to insert into the table

    :return: a list of column names
    """
    # create table
    alpha_type_list = ["primkey int PRIMARY KEY"]
    col_names = ["primkey"]
    start_index = ord('a')
    for i, datatype in enumerate(PRIMITIVE_DATATYPES):
        col_name = chr(start_index + i)
        alpha_type_list.append("{0} {1}".format(col_name, datatype))
        col_names.append(col_name)

    session.execute("CREATE TABLE {0} ({1})".format(
        table_name, ', '.join(alpha_type_list)), timeout=120)

    # The INSERT text is the same for every row, so build it once instead of
    # re-joining the column and placeholder lists on each iteration.
    columns_string = ', '.join(col_names)
    placeholders = ', '.join(["%s"] * len(col_names))
    insert_query = "INSERT INTO {0} ({1}) VALUES ({2})".format(
        table_name, columns_string, placeholders)

    # insert each row as a simple statement
    for key in range(N):
        params = get_all_primitive_params(key)
        session.execute(insert_query, params, timeout=120)

    return col_names
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_primitive_params(key):
    """
    Build one row of sample values: ``key`` followed by ``get_sample(datatype)``
    for every entry of ``PRIMITIVE_DATATYPES``.
    """
    return [key] + [get_sample(datatype) for datatype in PRIMITIVE_DATATYPES]
|
||||||
|
|
||||||
|
|
||||||
|
def get_primitive_datatypes():
    """Return 'int' followed by every entry of ``PRIMITIVE_DATATYPES``, as a new list."""
    datatypes = ['int']
    datatypes.extend(PRIMITIVE_DATATYPES)
    return datatypes
|
||||||
@@ -75,7 +75,7 @@ class StressInsertsTests(unittest.TestCase):
|
|||||||
break
|
break
|
||||||
for conn in pool.get_connections():
|
for conn in pool.get_connections():
|
||||||
if conn.in_flight > 1:
|
if conn.in_flight > 1:
|
||||||
print self.session.get_pool_state()
|
print(self.session.get_pool_state())
|
||||||
leaking_connections = True
|
leaking_connections = True
|
||||||
break
|
break
|
||||||
i = i + 1
|
i = i + 1
|
||||||
|
|||||||
14
tests/unit/cython/__init__.py
Normal file
14
tests/unit/cython/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
44
tests/unit/cython/bytesio_testhelper.pyx
Normal file
44
tests/unit/cython/bytesio_testhelper.pyx
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from cassandra.bytesio cimport BytesIOReader
|
||||||
|
|
||||||
|
def test_read1(assert_equal, assert_raises):
    """Read b'abcdef' in chunks of 2, 2, 0 and 2 bytes and check each chunk."""
    cdef BytesIOReader reader
    reader = BytesIOReader(b'abcdef')

    # Each read result is sliced to the requested length before comparing.
    chunk = reader.read(2)[:2]
    assert_equal(chunk, b'ab')
    chunk = reader.read(2)[:2]
    assert_equal(chunk, b'cd')
    chunk = reader.read(0)[:0]
    assert_equal(chunk, b'')
    chunk = reader.read(2)[:2]
    assert_equal(chunk, b'ef')
|
||||||
|
|
||||||
|
def test_read2(assert_equal, assert_raises):
    """Read all six bytes of b'abcdef' as a 5-byte read followed by a 1-byte read."""
    cdef BytesIOReader reader
    reader = BytesIOReader(b'abcdef')
    reader.read(5)
    reader.read(1)
|
||||||
|
|
||||||
|
def test_read3(assert_equal, assert_raises):
    """Read all six bytes of b'abcdef' in a single call."""
    cdef BytesIOReader reader
    reader = BytesIOReader(b'abcdef')
    reader.read(6)
|
||||||
|
|
||||||
|
def test_read_eof(assert_equal, assert_raises):
    """Check that reading past the end raises EOFError and that a later in-range read still works."""
    cdef BytesIOReader reader
    reader = BytesIOReader(b'abcdef')
    reader.read(5)

    # cannot convert reader.read to an object, do it manually
    # assert_raises(EOFError, reader.read, 2)
    got_eof = False
    try:
        reader.read(2)
    except EOFError:
        got_eof = True
    if not got_eof:
        raise Exception("Expected an EOFError")

    reader.read(1)  # see that we can still read this
|
||||||
35
tests/unit/cython/test_bytesio.py
Normal file
35
tests/unit/cython/test_bytesio.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from tests.unit.cython.utils import cyimport, cythontest
|
||||||
|
bytesio_testhelper = cyimport('tests.unit.cython.bytesio_testhelper')
|
||||||
|
|
||||||
|
try:
|
||||||
|
import unittest2 as unittest
|
||||||
|
except ImportError:
|
||||||
|
import unittest # noqa
|
||||||
|
|
||||||
|
|
||||||
|
class BytesIOTest(unittest.TestCase):
    """Run the Cython BytesIOReader checks implemented in bytesio_testhelper."""

    @cythontest
    def test_reading(self):
        """Run the three helpers that read within the bounds of the buffer."""
        check_equal, check_raises = self.assertEqual, self.assertRaises
        bytesio_testhelper.test_read1(check_equal, check_raises)
        bytesio_testhelper.test_read2(check_equal, check_raises)
        bytesio_testhelper.test_read3(check_equal, check_raises)

    @cythontest
    def test_reading_error(self):
        """Run the helper that reads past the end of the buffer."""
        check_equal, check_raises = self.assertEqual, self.assertRaises
        bytesio_testhelper.test_read_eof(check_equal, check_raises)
|
||||||
38
tests/unit/cython/utils.py
Normal file
38
tests/unit/cython/utils.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# Copyright 2013-2015 DataStax, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
|
||||||
|
|
||||||
|
try:
|
||||||
|
import unittest2 as unittest
|
||||||
|
except ImportError:
|
||||||
|
import unittest # noqa
|
||||||
|
|
||||||
|
def cyimport(import_path):
    """
    Import a Cython module if available, otherwise return None
    (and skip any relevant tests).

    :param import_path: dotted path of the module to import
    :return: the imported module, or None when the import fails and
        ``HAVE_CYTHON`` is false
    :raises ImportError: when the import fails although ``HAVE_CYTHON`` is true
    """
    # Imported here so the helper does not depend on a module-level import.
    import importlib

    try:
        # import_module returns the module named by the full dotted path and
        # works for packages as well as plain modules. A non-string item in
        # __import__'s fromlist raises TypeError when the target is a package.
        return importlib.import_module(import_path)
    except ImportError:
        if HAVE_CYTHON:
            raise
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# @cythontest
|
||||||
|
# def test_something(self): ...
|
||||||
|
cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available')
|
||||||
|
numpytest = unittest.skipUnless(HAVE_CYTHON and HAVE_NUMPY, 'NumPy is not available')
|
||||||
Reference in New Issue
Block a user