Merge pull request #689 from datastax/python-655

PYTHON-655 Duration Type
This commit is contained in:
Alan Boudreault
2017-01-17 16:28:20 -05:00
committed by GitHub
5 changed files with 121 additions and 2 deletions

View File

@@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack,
uint16_pack, uint16_unpack, uint32_pack, uint32_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack)
varint_pack, varint_unpack, vints_pack, vints_unpack)
from cassandra import util
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
@@ -660,6 +660,23 @@ class TimeType(_CassandraType):
return int64_pack(nano)
class DurationType(_CassandraType):
typename = 'duration'
@staticmethod
def deserialize(byts, protocol_version):
months, days, nanoseconds = vints_unpack(byts)
return util.Duration(months, days, nanoseconds)
@staticmethod
def serialize(duration, protocol_version):
try:
m, d, n = duration.months, duration.days, duration.nanoseconds
except AttributeError:
raise TypeError('DurationType arguments must be a Duration.')
return vints_pack([m, d, n])
class UTF8Type(_CassandraType):
typename = 'text'
empty_binary_ok = True

View File

@@ -45,6 +45,10 @@ v3_header_unpack = v3_header_struct.unpack
if six.PY3:
def byte2int(b):
return b
def varint_unpack(term):
val = int(''.join("%02x" % i for i in term), 16)
if (term[0] & 128) != 0:
@@ -52,6 +56,10 @@ if six.PY3:
val -= 1 << (len_term * 8)
return val
else:
def byte2int(b):
return ord(b)
def varint_unpack(term): # noqa
val = int(term.encode('hex'), 16)
if (ord(term[0]) & 128) != 0:
@@ -84,3 +92,61 @@ def varint_pack(big):
revbytes.append(0)
revbytes.reverse()
return six.binary_type(revbytes)
def encode_zig_zag(n):
return (n << 1) ^ (n >> 63)
def decode_zig_zag(n):
return (n >> 1) ^ -(n & 1)
def vints_unpack(term): # noqa
values = []
n = 0
while n < len(term):
first_byte = byte2int(term[n])
if (first_byte & 128) == 0:
val = first_byte
else:
num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
val = first_byte & (0xff >> num_extra_bytes)
end = n + num_extra_bytes
while n < end:
n += 1
val <<= 8
val |= byte2int(term[n]) & 0xff
n += 1
values.append(decode_zig_zag(val))
return tuple(values)
def vints_pack(values):
revbytes = bytearray()
values = [int(v) for v in values[::-1]]
for v in values:
v = encode_zig_zag(v)
if v < 128:
revbytes.append(v)
else:
num_extra_bytes = 0
num_bits = v.bit_length()
# We need to reserve (num_extra_bytes+1) bits in the first byte
# ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX'
while num_bits > (8-(num_extra_bytes+1)):
num_extra_bytes += 1
num_bits -= 8
revbytes.append(v & 0xff)
v >>= 8
# We can now store the last bits in the first byte
n = 8 - num_extra_bytes
v |= (0xff >> n << n)
revbytes.append(abs(v))
revbytes.reverse()
return six.binary_type(revbytes)

View File

@@ -38,7 +38,7 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
LongType, MapType, SetType, TimeUUIDType,
UTF8Type, VarcharType, UUIDType, UserType,
TupleType, lookup_casstype, SimpleDateType,
TimeType, ByteType, ShortType)
TimeType, ByteType, ShortType, DurationType)
from cassandra.policies import WriteType
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
from cassandra import util

View File

@@ -25,6 +25,11 @@ Type codes are repeated here from the Cassandra binary protocol specification:
0x000E Varint
0x000F Timeuuid
0x0010 Inet
0x0011 SimpleDateType
0x0012 TimeType
0x0013 ShortType
0x0014 ByteType
0x0015 DurationType
0x0020 List: the value is an [option], representing the type
of the elements of the list.
0x0021 Map: the value is two [option], representing the types of the
@@ -54,6 +59,7 @@ SimpleDateType = 0x0011
TimeType = 0x0012
ShortType = 0x0013
ByteType = 0x0014
DurationType = 0x0015
ListType = 0x0020
MapType = 0x0021
SetType = 0x0022

View File

@@ -1193,3 +1193,33 @@ def _sanitize_identifiers(field_names):
names_out[index] = "%s_" % (names_out[index],)
observed_names.add(names_out[index])
return names_out
class Duration(object):
"""
Cassandra Duration Type
"""
months = 0
days = 0
nanoseconds = 0
def __init__(self, months=0, days=0, nanoseconds=0):
self.months = months
self.days = days
self.nanoseconds = nanoseconds
def __eq__(self, other):
return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds
def __repr__(self):
return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds)
def __str__(self):
has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0
return '%s%dmo%dd%dns' % (
'-' if has_negative_values else '',
abs(self.months),
abs(self.days),
abs(self.nanoseconds)
)