Merge pull request #689 from datastax/python-655
PYTHON-655 Duration Type
This commit is contained in:
@@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack,
|
||||
uint16_pack, uint16_unpack, uint32_pack, uint32_unpack,
|
||||
int32_pack, int32_unpack, int64_pack, int64_unpack,
|
||||
float_pack, float_unpack, double_pack, double_unpack,
|
||||
varint_pack, varint_unpack)
|
||||
varint_pack, varint_unpack, vints_pack, vints_unpack)
|
||||
from cassandra import util
|
||||
|
||||
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
|
||||
@@ -660,6 +660,23 @@ class TimeType(_CassandraType):
|
||||
return int64_pack(nano)
|
||||
|
||||
|
||||
class DurationType(_CassandraType):
|
||||
typename = 'duration'
|
||||
|
||||
@staticmethod
|
||||
def deserialize(byts, protocol_version):
|
||||
months, days, nanoseconds = vints_unpack(byts)
|
||||
return util.Duration(months, days, nanoseconds)
|
||||
|
||||
@staticmethod
|
||||
def serialize(duration, protocol_version):
|
||||
try:
|
||||
m, d, n = duration.months, duration.days, duration.nanoseconds
|
||||
except AttributeError:
|
||||
raise TypeError('DurationType arguments must be a Duration.')
|
||||
return vints_pack([m, d, n])
|
||||
|
||||
|
||||
class UTF8Type(_CassandraType):
|
||||
typename = 'text'
|
||||
empty_binary_ok = True
|
||||
|
||||
@@ -45,6 +45,10 @@ v3_header_unpack = v3_header_struct.unpack
|
||||
|
||||
|
||||
if six.PY3:
|
||||
def byte2int(b):
|
||||
return b
|
||||
|
||||
|
||||
def varint_unpack(term):
|
||||
val = int(''.join("%02x" % i for i in term), 16)
|
||||
if (term[0] & 128) != 0:
|
||||
@@ -52,6 +56,10 @@ if six.PY3:
|
||||
val -= 1 << (len_term * 8)
|
||||
return val
|
||||
else:
|
||||
def byte2int(b):
|
||||
return ord(b)
|
||||
|
||||
|
||||
def varint_unpack(term): # noqa
|
||||
val = int(term.encode('hex'), 16)
|
||||
if (ord(term[0]) & 128) != 0:
|
||||
@@ -84,3 +92,61 @@ def varint_pack(big):
|
||||
revbytes.append(0)
|
||||
revbytes.reverse()
|
||||
return six.binary_type(revbytes)
|
||||
|
||||
|
||||
def encode_zig_zag(n):
|
||||
return (n << 1) ^ (n >> 63)
|
||||
|
||||
|
||||
def decode_zig_zag(n):
|
||||
return (n >> 1) ^ -(n & 1)
|
||||
|
||||
|
||||
def vints_unpack(term): # noqa
|
||||
values = []
|
||||
n = 0
|
||||
while n < len(term):
|
||||
first_byte = byte2int(term[n])
|
||||
|
||||
if (first_byte & 128) == 0:
|
||||
val = first_byte
|
||||
else:
|
||||
num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
|
||||
val = first_byte & (0xff >> num_extra_bytes)
|
||||
end = n + num_extra_bytes
|
||||
while n < end:
|
||||
n += 1
|
||||
val <<= 8
|
||||
val |= byte2int(term[n]) & 0xff
|
||||
|
||||
n += 1
|
||||
values.append(decode_zig_zag(val))
|
||||
|
||||
return tuple(values)
|
||||
|
||||
|
||||
def vints_pack(values):
|
||||
revbytes = bytearray()
|
||||
values = [int(v) for v in values[::-1]]
|
||||
for v in values:
|
||||
v = encode_zig_zag(v)
|
||||
if v < 128:
|
||||
revbytes.append(v)
|
||||
else:
|
||||
num_extra_bytes = 0
|
||||
num_bits = v.bit_length()
|
||||
# We need to reserve (num_extra_bytes+1) bits in the first byte
|
||||
# ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX'
|
||||
while num_bits > (8-(num_extra_bytes+1)):
|
||||
num_extra_bytes += 1
|
||||
num_bits -= 8
|
||||
revbytes.append(v & 0xff)
|
||||
v >>= 8
|
||||
|
||||
# We can now store the last bits in the first byte
|
||||
n = 8 - num_extra_bytes
|
||||
v |= (0xff >> n << n)
|
||||
revbytes.append(abs(v))
|
||||
|
||||
revbytes.reverse()
|
||||
return six.binary_type(revbytes)
|
||||
@@ -38,7 +38,7 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
|
||||
LongType, MapType, SetType, TimeUUIDType,
|
||||
UTF8Type, VarcharType, UUIDType, UserType,
|
||||
TupleType, lookup_casstype, SimpleDateType,
|
||||
TimeType, ByteType, ShortType)
|
||||
TimeType, ByteType, ShortType, DurationType)
|
||||
from cassandra.policies import WriteType
|
||||
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
|
||||
from cassandra import util
|
||||
|
||||
@@ -25,6 +25,11 @@ Type codes are repeated here from the Cassandra binary protocol specification:
|
||||
0x000E Varint
|
||||
0x000F Timeuuid
|
||||
0x0010 Inet
|
||||
0x0011 SimpleDateType
|
||||
0x0012 TimeType
|
||||
0x0013 ShortType
|
||||
0x0014 ByteType
|
||||
0x0015 DurationType
|
||||
0x0020 List: the value is an [option], representing the type
|
||||
of the elements of the list.
|
||||
0x0021 Map: the value is two [option], representing the types of the
|
||||
@@ -54,6 +59,7 @@ SimpleDateType = 0x0011
|
||||
TimeType = 0x0012
|
||||
ShortType = 0x0013
|
||||
ByteType = 0x0014
|
||||
DurationType = 0x0015
|
||||
ListType = 0x0020
|
||||
MapType = 0x0021
|
||||
SetType = 0x0022
|
||||
|
||||
@@ -1193,3 +1193,33 @@ def _sanitize_identifiers(field_names):
|
||||
names_out[index] = "%s_" % (names_out[index],)
|
||||
observed_names.add(names_out[index])
|
||||
return names_out
|
||||
|
||||
|
||||
class Duration(object):
|
||||
"""
|
||||
Cassandra Duration Type
|
||||
"""
|
||||
|
||||
months = 0
|
||||
days = 0
|
||||
nanoseconds = 0
|
||||
|
||||
def __init__(self, months=0, days=0, nanoseconds=0):
|
||||
self.months = months
|
||||
self.days = days
|
||||
self.nanoseconds = nanoseconds
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds
|
||||
|
||||
def __repr__(self):
|
||||
return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds)
|
||||
|
||||
def __str__(self):
|
||||
has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0
|
||||
return '%s%dmo%dd%dns' % (
|
||||
'-' if has_negative_values else '',
|
||||
abs(self.months),
|
||||
abs(self.days),
|
||||
abs(self.nanoseconds)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user