More Cython-based object deserializers

This commit is contained in:
Mark Florisson
2015-08-06 13:45:23 +01:00
parent ddeb753662
commit 2d8ad6ad3a
9 changed files with 275 additions and 74 deletions

View File

@@ -1,8 +1,32 @@
"""
Simple buffer data structure that provides a view on existing memory
(e.g. from a bytes object). This memory must stay alive while the
buffer is in use.
"""
from cpython.bytes cimport PyBytes_AS_STRING
# char* PyBytes_AS_STRING(object string)
# Macro form of PyBytes_AsString() but without error
# checking. Only string objects are supported; no Unicode objects
# should be passed.
from cassandra.buffer cimport Buffer
# View on existing memory: a raw pointer plus a byte count. The struct does
# not own the memory, which must stay alive while the Buffer is in use.
cdef struct Buffer:
# Start of the viewed memory.
char *ptr
# Number of bytes in the view.
Py_ssize_t size
# Signature-only declarations of the Buffer helper functions.
cdef inline Buffer from_bytes(bytes byts)
cdef inline bytes to_bytes(Buffer *buf)
cdef inline char *buf_ptr(Buffer *buf)
cdef inline Buffer from_ptr_and_size(char *ptr, Py_ssize_t size)
cdef inline Buffer from_bytes(bytes byts):
    """
    Return a Buffer that views the contents of `byts`.

    The Buffer only stores a pointer into `byts`; the bytes object must
    outlive the returned Buffer.
    """
    cdef char *data = PyBytes_AS_STRING(byts)
    cdef Py_ssize_t nbytes = len(byts)
    return from_ptr_and_size(data, nbytes)
cdef inline bytes to_bytes(Buffer *buf):
    """Return a new bytes object holding the `buf.size` bytes starting at `buf.ptr`."""
    cdef char *start = buf.ptr
    return start[:buf.size]
cdef inline char *buf_ptr(Buffer *buf):
    """Return the raw pointer stored in `buf`."""
    return buf[0].ptr
cdef inline Buffer from_ptr_and_size(char *ptr, Py_ssize_t size):
    """Build a Buffer value from a raw pointer and a byte count."""
    cdef Buffer view
    view.size = size
    view.ptr = ptr
    return view

View File

@@ -1,38 +0,0 @@
"""
Simple buffer data structure. This buffer can be included:
include "buffer.pyx"
or imported:
from cassandra cimport buffer
but this prevents inlining of the functions below.
"""
from cpython.bytes cimport PyBytes_AS_STRING
# char* PyBytes_AS_STRING(object string)
# Macro form of PyBytes_AsString() but without error
# checking. Only string objects are supported; no Unicode objects
# should be passed.
from cassandra.buffer cimport Buffer
# Plain pointer-plus-length record describing a block of memory.
cdef struct Buffer:
# Start of the memory block.
char *ptr
# Length of the block in bytes.
Py_ssize_t size
cdef inline Buffer from_bytes(bytes byts):
    """
    Wrap the contents of `byts` in a Buffer without copying.

    Only a pointer into `byts` is stored, so `byts` has to stay alive for
    as long as the Buffer is used.
    """
    cdef Py_ssize_t length = len(byts)
    cdef char *raw = PyBytes_AS_STRING(byts)
    return from_ptr_and_size(raw, length)
cdef inline bytes to_bytes(Buffer *buf):
    """Copy the bytes described by `buf` into a new bytes object."""
    cdef Py_ssize_t length = buf.size
    return buf.ptr[:length]
cdef inline char *buf_ptr(Buffer *buf):
    """Return the start pointer recorded in `buf`."""
    cdef char *start = buf.ptr
    return start
cdef inline Buffer from_ptr_and_size(char *ptr, Py_ssize_t size):
    """Return a Buffer whose fields are set to the given pointer and size."""
    cdef Buffer out
    out.ptr, out.size = ptr, size
    return out

View File

@@ -11,7 +11,7 @@ include "ioutils.pyx"
def make_recv_results_rows(ColumnParser colparser):
def recv_results_rows(cls, f, protocol_version, user_type_map):
def recv_results_rows(cls, f, int protocol_version, user_type_map):
"""
Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples)
This is used as the recv_results_rows method of (Fast)ResultMessage

View File

@@ -0,0 +1,27 @@
"""
Duplicate module of util.py, with some accelerated functions
used for deserialization.
"""
# from __future__ import with_statement
from cpython.datetime cimport timedelta_new
# cdef inline object timedelta_new(int days, int seconds, int useconds)
# Create timedelta object using DateTime CAPI factory function.
# Note, there are no range checks for any of the arguments.
import calendar
import datetime
import random
import six
import uuid
import sys
DATETIME_EPOC = datetime.datetime(1970, 1, 1)
assert sys.byteorder in ('little', 'big')
is_little_endian = sys.byteorder == 'little'
cdef datetime_from_timestamp(timestamp):
    """
    Convert seconds since the Unix epoch into a naive datetime.

    `timestamp` may be an int or a float; fractional seconds are kept.
    The result is DATETIME_EPOC plus the corresponding timedelta.
    """
    # datetime.timedelta is used rather than the C-API timedelta_new():
    # timedelta_new() takes C ints with no range checks, which drops the
    # sub-second part of a float timestamp and overflows for large second
    # counts. It also needs import_datetime() to have been called first.
    return DATETIME_EPOC + datetime.timedelta(seconds=timestamp)

View File

@@ -3,5 +3,5 @@
from cassandra.buffer cimport Buffer
# Declaration of the Deserializer extension type and its C-level
# deserialize() method, which receives a Buffer pointer and the protocol
# version.
cdef class Deserializer:
cdef deserialize(self, Buffer *buf, protocol_version)
cdef deserialize(self, Buffer *buf, int protocol_version)
# cdef deserialize(self, CString byts, protocol_version)

View File

@@ -1,69 +1,246 @@
# -- cython: profile=True
from libc.stdint cimport int32_t, uint16_t
include 'marshal.pyx'
include 'buffer.pyx'
include 'cython_utils.pyx'
from cassandra.buffer cimport Buffer, to_bytes
from cython.view cimport array as cython_array
import socket
import inspect
from decimal import Decimal
from uuid import UUID
import inspect
from cassandra import util
# Base class of the deserializers in this module. Subclasses override
# deserialize(), which turns the bytes described by `buf` into a Python value.
cdef class Deserializer:
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
# The base implementation always raises.
raise NotImplementedError
# Deserializer that returns int64_unpack() applied to the start of the buffer.
cdef class DesLongType(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return int64_unpack(buf.ptr)
# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html
cdef class DesDecimalType(Deserializer):
    """
    Deserializer producing decimal.Decimal values.

    The buffer starts with a 32-bit scale (read with int32_unpack) followed
    by a variable-length integer holding the unscaled value; the result is
    unscaled * 10 ** -scale.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        scale = int32_unpack(buf.ptr)
        # Slice with an explicit end. A bare char* handed to a Python-level
        # function is converted to bytes by NUL-terminated length, which
        # would truncate the varint at its first zero byte or read past the
        # end of the buffer.
        unscaled = varint_unpack(buf.ptr[4:buf.size])
        return Decimal('%de%d' % (unscaled, -scale))
# Deserializer that builds a uuid.UUID from a bytes copy of the buffer.
cdef class DesUUIDType(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return UUID(bytes=to_bytes(buf))
# Deserializer that returns the truth value of int8_unpack() applied to the
# start of the buffer.
cdef class DesBooleanType(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return bool(int8_unpack(buf.ptr))
# Deserializer that returns int8_unpack() applied to the start of the buffer.
cdef class DesByteType(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return int8_unpack(buf.ptr)
# Deserializer for ASCII text: returns the raw bytes when six.PY2 is true,
# and the bytes decoded as 'ascii' otherwise.
cdef class DesAsciiType(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
if six.PY2:
return to_bytes(buf)
return to_bytes(buf).decode('ascii')
# Deserializer that returns float_unpack() applied to the start of the buffer.
cdef class DesFloatType(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return float_unpack(buf.ptr)
# Deserializer that returns double_unpack() applied to the start of the buffer.
cdef class DesDoubleType(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return double_unpack(buf.ptr)
# Deserializer that returns int32_unpack() applied to the start of the buffer.
cdef class DesInt32Type(Deserializer):
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return int32_unpack(buf.ptr)
cdef class DesIntegerType(Deserializer):
    """Deserializer for variable-length integers; the whole buffer is the varint."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return varint_unpack(raw)
cdef class DesInetAddressType(Deserializer):
    """
    Deserializer that turns a packed IP address into its string form.

    A 16-byte buffer is formatted as an IPv6 address; any other length is
    handed to socket.inet_ntoa as an IPv4 address.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes byts = to_bytes(buf)

        # TODO: optimize inet_ntop, inet_ntoa
        # buf.size is already an integer length; calling len() on it would
        # raise TypeError.
        if buf.size == 16:
            return util.inet_ntop(socket.AF_INET6, byts)
        else:
            # util.inet_ntop could also handle this, but inet_ntoa is faster
            # since we've already determined the address family
            return socket.inet_ntoa(byts)
# Counter deserializer: inherits DesLongType's deserialize() unchanged.
cdef class DesCounterColumnType(DesLongType):
pass
cdef class DesDateType(Deserializer):
    """
    Deserializer producing datetime values.

    The 64-bit integer read from the buffer is divided by 1000.0 and passed
    to datetime_from_timestamp().
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        millis = int64_unpack(buf.ptr)
        return datetime_from_timestamp(millis / 1000.0)
# Timestamp deserializer: inherits DesDateType's deserialize() unchanged.
# NOTE(review): every other deserializer class in this file is named with a
# 'Des' prefix; confirm whether code that looks deserializers up by name can
# find this class.
cdef class TimestampType(DesDateType):
pass
# Deserializer that builds a uuid.UUID from a bytes copy of the buffer.
# NOTE(review): this class derives from DesDateType but overrides its only
# method with the same body as DesUUIDType, and it lacks the 'Des' prefix
# used by the sibling classes; confirm both are intended.
cdef class TimeUUIDType(DesDateType):
cdef deserialize(self, Buffer *buf, int protocol_version):
return UUID(bytes=to_bytes(buf))
# Values of the 'date' type are encoded as 32-bit unsigned integers
# representing a number of days with epoch (January 1st, 1970) at the center of the
# range (2^31).
EPOCH_OFFSET_DAYS = 2 ** 31
cdef class DesSimpleDateType(Deserializer):
    """
    Deserializer producing util.Date values.

    The unsigned 32-bit value read from the buffer is shifted by
    EPOCH_OFFSET_DAYS to get a signed day count.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        encoded = uint32_unpack(buf.ptr)
        return util.Date(encoded - EPOCH_OFFSET_DAYS)
cdef class DesShortType(Deserializer):
    """Deserializer that returns int16_unpack() applied to the start of the buffer."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        value = int16_unpack(buf.ptr)
        return value
cdef class DesTimeType(Deserializer):
    """
    Deserializer producing util.Time values from the 64-bit integer at the
    start of the buffer.
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        # Read straight from the buffer pointer, as the other fixed-width
        # deserializers in this file do, rather than copying the buffer
        # into a temporary bytes object first.
        return util.Time(int64_unpack(buf.ptr))
cdef class DesUTF8Type(Deserializer):
    """Deserializer that decodes the buffer contents as UTF-8 text."""

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes raw = to_bytes(buf)
        return raw.decode('utf8')
# Varchar deserializer: inherits DesUTF8Type's deserialize() unchanged.
cdef class DesVarcharType(DesUTF8Type):
pass
cdef class _DesParameterizedType(Deserializer):
    """
    Base class for deserializers of parameterized types.

    Stores the cql type, its adapter, its subtypes and the deserializers
    built for those subtypes by make_deserializers(). Exactly one subtype
    is required.
    """

    cdef object cqltype
    cdef object adapter
    cdef object subtypes
    cdef Deserializer[::1] deserializers

    def __init__(self, cqltype):
        assert cqltype.subtypes and len(cqltype.subtypes) == 1
        self.deserializers = make_deserializers(cqltype.subtypes)
        self.subtypes = cqltype.subtypes
        self.adapter = cqltype.adapter
        self.cqltype = cqltype
# Deserializer for parameterized types with a single subtype: decodes every
# item with the sole sub-deserializer and passes the resulting list to
# self.adapter.
cdef class _DesSimpleParameterizedType(_DesParameterizedType):
cdef deserialize(self, Buffer *buf, int protocol_version):
# Zero-valued dummies; they are passed only so that the matching
# fused-type specialization of _deserialize_parameterized is called.
cdef uint16_t v2_and_below = 0
cdef int32_t v3_and_above = 0
# Protocol versions 3 and up use the int32_t specialization, older
# versions the uint16_t one.
if protocol_version >= 3:
result = _deserialize_parameterized[int32_t](
v3_and_above, self.deserializers[0], buf, protocol_version)
else:
result = _deserialize_parameterized[uint16_t](
v2_and_below, self.deserializers[0], buf, protocol_version)
return self.adapter(result)
# Fused integer type for the length prefixes read by _unpack() and
# _deserialize_parameterized(); the width depends on the protocol version.
ctypedef fused itemlen_t:
uint16_t # protocol <= v2
int32_t # protocol >= v3
cdef itemlen_t _unpack(itemlen_t dummy, const char *buf):
    """
    Read one length prefix from `buf`, with the width given by itemlen_t.

    `dummy` is unused at runtime; it only selects the fused-type
    specialization.
    """
    if itemlen_t is uint16_t:
        return uint16_unpack(buf)
    else:
        return int32_unpack(buf)
cdef list _deserialize_parameterized(
        itemlen_t dummy, Deserializer deserializer,
        Buffer *buf, int protocol_version):
    """
    Decode a length-prefixed sequence of items from `buf` into a list.

    The buffer starts with an element count, followed by that many items,
    each preceded by its own length. Counts and lengths have the width of
    itemlen_t. `dummy` is unused at runtime; it only selects the fused-type
    specialization. Every item is decoded with `deserializer`.
    """
    cdef itemlen_t itemlen
    cdef Buffer sub_buf

    cdef itemlen_t numelements = _unpack[itemlen_t](dummy, buf.ptr)
    # Running byte offset into buf. It must not be typed itemlen_t: for the
    # uint16_t specialization the offset would wrap once the total payload
    # exceeds 65535 bytes, even though each individual length fits.
    cdef Py_ssize_t p = sizeof(itemlen_t)
    cdef list result = []

    for _ in range(numelements):
        itemlen = _unpack[itemlen_t](dummy, buf.ptr + p)
        p += sizeof(itemlen_t)
        sub_buf.ptr = buf.ptr + p
        sub_buf.size = itemlen
        p += itemlen
        result.append(deserializer.deserialize(&sub_buf, protocol_version))

    return result
# cdef deserialize_v3_and_above(
# Deserializer deserializer, Buffer *buf, int protocol_version):
# cdef Py_ssize_t itemlen
# cdef Buffer sub_buf
#
# cdef Py_ssize_t numelements = int32_unpack(buf.ptr)
# cdef Py_ssize_t p = 4
# cdef list result = []
#
# for _ in range(numelements):
# itemlen = int32_unpack(buf.ptr + p)
# p += 4
# sub_buf.ptr = buf.ptr + p
# sub_buf.size = itemlen
# p += itemlen
# result.append(deserializer.deserialize(&sub_buf, protocol_version))
#
# return result
#
#
# cdef deserialize_v2_and_below(
# Deserializer deserializer, Buffer *buf, int protocol_version):
# cdef Py_ssize_t itemlen
# cdef Buffer sub_buf
#
# cdef Py_ssize_t numelements = uint16_unpack(buf.ptr)
# cdef Py_ssize_t p = 2
# cdef list result = []
#
# for _ in range(numelements):
# itemlen = uint16_unpack(buf.ptr + p)
# p += 2
# sub_buf.ptr = buf.ptr + p
# sub_buf.size = itemlen
# p += itemlen
# result.append(deserializer.deserialize(&sub_buf, protocol_version))
#
# return result
cdef class GenericDeserializer(Deserializer):
"""
Wrap a generic datatype for deserialization
@@ -74,7 +251,7 @@ cdef class GenericDeserializer(Deserializer):
def __init__(self, cqltype):
self.cqltype = cqltype
cdef deserialize(self, Buffer *buf, protocol_version):
cdef deserialize(self, Buffer *buf, int protocol_version):
return self.cqltype.deserialize(to_bytes(buf), protocol_version)
#--------------------------------------------------------------------------

View File

@@ -1,5 +1,5 @@
include 'marshal.pyx'
include 'buffer.pyx'
from cassandra.buffer cimport Buffer
from libc.stdint cimport int32_t
from cassandra.bytesio cimport BytesIOReader

View File

@@ -25,6 +25,8 @@ from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t,
cdef bint is_little_endian
from cassandra.util import is_little_endian
cdef bint PY3 = six.PY3
# cdef extern from "marshal.h":
# cdef str c_string_to_python(char *p, Py_ssize_t len)
@@ -165,21 +167,30 @@ v3_header_pack = v3_header_struct.pack
v3_header_unpack = v3_header_struct.unpack
# Define varint_unpack once at import time, picking the implementation that
# matches the running Python major version.
if six.PY3:
def varint_unpack(term):
val = int(''.join("%02x" % i for i in term), 16)
if (term[0] & 128) != 0:
# There is a bug in Cython (0.20 - 0.22), where if we do
# '1 << (len(term) * 8)' Cython generates '1' directly into the
# C code, causing integer overflows. Treat it as an object for now
val -= (<object> 1L) << (len(term) * 8)
return val
else:
def varint_unpack(term): # noqa
val = int(term.encode('hex'), 16)
if (ord(term[0]) & 128) != 0:
val = val - (1 << (len(term) * 8))
return val
cpdef varint_unpack(term):
    """Unpack a variable-sized integer"""
    unpack = varint_unpack_py3 if PY3 else varint_unpack_py2
    return unpack(term)
# TODO: Optimize these two functions
def varint_unpack_py3(term):
    """
    Unpack a big-endian two's-complement integer of arbitrary length.

    `term` is a bytes-like object (Python 3). A set top bit in the first
    byte makes the result negative.
    """
    # int.from_bytes uses arbitrary-precision arithmetic. Shifting a C
    # int64_t by len(term) * 8 overflows for terms of 8 or more bytes.
    return int.from_bytes(term, 'big', signed=True)
def varint_unpack_py2(term):  # noqa
    """
    Unpack a big-endian two's-complement integer of arbitrary length.

    `term` is a Python 2 byte string. A set top bit in the first byte makes
    the result negative.
    """
    val = int(term.encode('hex'), 16)
    if (ord(term[0]) & 128) != 0:
        # Shift a Python object, not a C integer: a native shift by
        # len(term) * 8 bits overflows for terms of 8 or more bytes.
        val -= (<object> 1) << (len(term) * 8)
    return val
def bitlength(n):

View File

@@ -5,7 +5,7 @@ cdef class ParseDesc:
cdef public object colnames
cdef public object coltypes
cdef Deserializer[::1] deserializers
cdef public object protocol_version
cdef public int protocol_version
cdef Py_ssize_t rowsize
cdef class ColumnParser: