Add util.Duration and its string encoding

This commit is contained in:
Alan Boudreault
2017-01-16 14:53:28 -05:00
parent f2039b751a
commit 18aa88112e
4 changed files with 65 additions and 9 deletions

View File

@@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack,
uint16_pack, uint16_unpack, uint32_pack, uint32_unpack,
int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack, vints_unpack)
varint_pack, varint_unpack, vints_pack, vints_unpack)
from cassandra import util
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
@@ -665,13 +665,17 @@ class DurationType(_CassandraType):
@staticmethod
def deserialize(byts, protocol_version):
print vints_unpack(byts)
varint_unpack(byts)
return varint_unpack(byts)
months, days, nanoseconds = vints_unpack(byts)
return util.Duration(months, days, nanoseconds)
@staticmethod
def serialize(byts, protocol_version):
return # ...
def serialize(duration, protocol_version):
try:
duration.validate()
m, d, n = duration.months, duration.days, duration.nanoseconds
except AttributeError:
raise TypeError('DurationType arguments must be a Duration.')
return vints_pack([m, d, n])
class UTF8Type(_CassandraType):

View File

@@ -30,7 +30,7 @@ from uuid import UUID
import six
from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey,
sortedset, Time, Date)
sortedset, Time, Date, Duration)
if six.PY3:
long = int
@@ -88,7 +88,8 @@ class Encoder(object):
sortedset: self.cql_encode_set_collection,
frozenset: self.cql_encode_set_collection,
types.GeneratorType: self.cql_encode_list_collection,
ValueSequence: self.cql_encode_sequence
ValueSequence: self.cql_encode_sequence,
Duration: self.cql_encode_duration
}
if six.PY2:
@@ -225,3 +226,9 @@ class Encoder(object):
if :attr:`~Encoder.mapping` does not contain an entry for the type.
"""
return self.mapping.get(type(val), self.cql_encode_object)(val)
def cql_encode_duration(self, val):
"""
Encodes a :class:`cassandra.util.Duration` object as a string.
"""
return str(val)

View File

@@ -127,7 +127,7 @@ def vints_unpack(term): # noqa
def vints_pack(values):
revbytes = bytearray()
values.reverse()
values = [int(v) for v in values[::-1]]
for v in values:
v = encode_zig_zag(v)
if v < 128:

View File

@@ -1193,3 +1193,48 @@ def _sanitize_identifiers(field_names):
names_out[index] = "%s_" % (names_out[index],)
observed_names.add(names_out[index])
return names_out
class Duration(object):
"""
Cassandra Duration Type
"""
months = 0
days = 0
nanoseconds = 0
def __init__(self, months=0, days=0, nanoseconds=0):
self.months = months
self.days = days
self.nanoseconds = nanoseconds
self.validate()
def validate(self):
"""
A Duration is valid if its values are all positive or all negative. It cannot has mixed signs.
"""
if self._has_negative_values and self._has_positive_values:
raise ValueError('Duration values cannot have mixed signs.')
@property
def _has_negative_values(self):
return self.months < 0 or self.days < 0 or self.nanoseconds < 0
@property
def _has_positive_values(self):
return self.months > 0 or self.days > 0 or self.nanoseconds > 0
def __eq__(self, other):
return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds
def __repr__(self):
return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds)
def __str__(self):
return '{0}{1}mo{2}d{3}ns'.format(
'-' if self._has_negative_values else '',
abs(self.months),
abs(self.days),
abs(self.nanoseconds)
)