diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index e333c117..76045247 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack, uint16_pack, uint16_unpack, uint32_pack, uint32_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack, float_pack, float_unpack, double_pack, double_unpack, - varint_pack, varint_unpack) + varint_pack, varint_unpack, vints_pack, vints_unpack) from cassandra import util apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' @@ -660,6 +660,23 @@ class TimeType(_CassandraType): return int64_pack(nano) +class DurationType(_CassandraType): + typename = 'duration' + + @staticmethod + def deserialize(byts, protocol_version): + months, days, nanoseconds = vints_unpack(byts) + return util.Duration(months, days, nanoseconds) + + @staticmethod + def serialize(duration, protocol_version): + try: + m, d, n = duration.months, duration.days, duration.nanoseconds + except AttributeError: + raise TypeError('DurationType arguments must be a Duration.') + return vints_pack([m, d, n]) + + class UTF8Type(_CassandraType): typename = 'text' empty_binary_ok = True diff --git a/cassandra/marshal.py b/cassandra/marshal.py index 5a523d63..9647e22b 100644 --- a/cassandra/marshal.py +++ b/cassandra/marshal.py @@ -45,6 +45,10 @@ v3_header_unpack = v3_header_struct.unpack if six.PY3: + def byte2int(b): + return b + + def varint_unpack(term): val = int(''.join("%02x" % i for i in term), 16) if (term[0] & 128) != 0: @@ -52,6 +56,10 @@ if six.PY3: val -= 1 << (len_term * 8) return val else: + def byte2int(b): + return ord(b) + + def varint_unpack(term): # noqa val = int(term.encode('hex'), 16) if (ord(term[0]) & 128) != 0: @@ -84,3 +92,61 @@ def varint_pack(big): revbytes.append(0) revbytes.reverse() return six.binary_type(revbytes) + + +def encode_zig_zag(n): + return (n << 1) ^ (n >> 63) + + +def decode_zig_zag(n): + return (n >> 1) ^ -(n & 1) + + +def vints_unpack(term): # noqa + values = [] + n = 0 + while n < len(term): + first_byte = byte2int(term[n]) + + if (first_byte & 128) == 0: + val = first_byte + else: + num_extra_bytes = 8 - (~first_byte & 0xff).bit_length() + val = first_byte & (0xff >> num_extra_bytes) + end = n + num_extra_bytes + while n < end: + n += 1 + val <<= 8 + val |= byte2int(term[n]) & 0xff + + n += 1 + values.append(decode_zig_zag(val)) + + return tuple(values) + + +def vints_pack(values): + revbytes = bytearray() + values = [int(v) for v in values[::-1]] + for v in values: + v = encode_zig_zag(v) + if v < 128: + revbytes.append(v) + else: + num_extra_bytes = 0 + num_bits = v.bit_length() + # We need to reserve (num_extra_bytes+1) bits in the first byte + # ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' + while num_bits > (8-(num_extra_bytes+1)): + num_extra_bytes += 1 + num_bits -= 8 + revbytes.append(v & 0xff) + v >>= 8 + + # We can now store the last bits in the first byte + n = 8 - num_extra_bytes + v |= (0xff >> n << n) + revbytes.append(abs(v)) + + revbytes.reverse() + return six.binary_type(revbytes) \ No newline at end of file diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 83328364..313c1aa9 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -38,7 +38,7 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, LongType, MapType, SetType, TimeUUIDType, UTF8Type, VarcharType, UUIDType, UserType, TupleType, lookup_casstype, SimpleDateType, - TimeType, ByteType, ShortType) + TimeType, ByteType, ShortType, DurationType) from cassandra.policies import WriteType from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY from cassandra import util diff --git a/cassandra/type_codes.py b/cassandra/type_codes.py index daf882e4..eab9a334 100644 --- a/cassandra/type_codes.py +++ b/cassandra/type_codes.py @@ -25,6 +25,11 @@ Type codes are repeated here from the Cassandra binary protocol specification: 0x000E Varint 0x000F Timeuuid 0x0010 Inet + 0x0011 SimpleDateType + 0x0012 TimeType + 0x0013 ShortType + 0x0014 ByteType + 0x0015 DurationType 0x0020 List: the value is an [option], representing the type of the elements of the list. 0x0021 Map: the value is two [option], representing the types of the @@ -54,6 +59,7 @@ SimpleDateType = 0x0011 TimeType = 0x0012 ShortType = 0x0013 ByteType = 0x0014 +DurationType = 0x0015 ListType = 0x0020 MapType = 0x0021 SetType = 0x0022 diff --git a/cassandra/util.py b/cassandra/util.py index 7f17e85d..924b5c79 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -1193,3 +1193,33 @@ def _sanitize_identifiers(field_names): names_out[index] = "%s_" % (names_out[index],) observed_names.add(names_out[index]) return names_out + + +class Duration(object): + """ + Cassandra Duration Type + """ + + months = 0 + days = 0 + nanoseconds = 0 + + def __init__(self, months=0, days=0, nanoseconds=0): + self.months = months + self.days = days + self.nanoseconds = nanoseconds + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds + + def __repr__(self): + return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds) + + def __str__(self): + has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0 + return '%s%dmo%dd%dns' % ( + '-' if has_negative_values else '', + abs(self.months), + abs(self.days), + abs(self.nanoseconds) + )