Merge pull request #689 from datastax/python-655
PYTHON-655 Duration Type
This commit is contained in:
		@@ -48,7 +48,7 @@ from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack,
 | 
			
		||||
                               uint16_pack, uint16_unpack, uint32_pack, uint32_unpack,
 | 
			
		||||
                               int32_pack, int32_unpack, int64_pack, int64_unpack,
 | 
			
		||||
                               float_pack, float_unpack, double_pack, double_unpack,
 | 
			
		||||
                               varint_pack, varint_unpack)
 | 
			
		||||
                               varint_pack, varint_unpack, vints_pack, vints_unpack)
 | 
			
		||||
from cassandra import util
 | 
			
		||||
 | 
			
		||||
apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
 | 
			
		||||
@@ -660,6 +660,23 @@ class TimeType(_CassandraType):
 | 
			
		||||
        return int64_pack(nano)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DurationType(_CassandraType):
 | 
			
		||||
    typename = 'duration'
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def deserialize(byts, protocol_version):
 | 
			
		||||
        months, days, nanoseconds = vints_unpack(byts)
 | 
			
		||||
        return util.Duration(months, days, nanoseconds)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def serialize(duration, protocol_version):
 | 
			
		||||
        try:
 | 
			
		||||
            m, d, n = duration.months, duration.days, duration.nanoseconds
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            raise TypeError('DurationType arguments must be a Duration.')
 | 
			
		||||
        return vints_pack([m, d, n])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UTF8Type(_CassandraType):
 | 
			
		||||
    typename = 'text'
 | 
			
		||||
    empty_binary_ok = True
 | 
			
		||||
 
 | 
			
		||||
@@ -45,6 +45,10 @@ v3_header_unpack = v3_header_struct.unpack
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if six.PY3:
 | 
			
		||||
    def byte2int(b):
 | 
			
		||||
        return b
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def varint_unpack(term):
 | 
			
		||||
        val = int(''.join("%02x" % i for i in term), 16)
 | 
			
		||||
        if (term[0] & 128) != 0:
 | 
			
		||||
@@ -52,6 +56,10 @@ if six.PY3:
 | 
			
		||||
            val -= 1 << (len_term * 8)
 | 
			
		||||
        return val
 | 
			
		||||
else:
 | 
			
		||||
    def byte2int(b):
 | 
			
		||||
        return ord(b)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def varint_unpack(term):  # noqa
 | 
			
		||||
        val = int(term.encode('hex'), 16)
 | 
			
		||||
        if (ord(term[0]) & 128) != 0:
 | 
			
		||||
@@ -84,3 +92,61 @@ def varint_pack(big):
 | 
			
		||||
        revbytes.append(0)
 | 
			
		||||
    revbytes.reverse()
 | 
			
		||||
    return six.binary_type(revbytes)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def encode_zig_zag(n):
 | 
			
		||||
    return (n << 1) ^ (n >> 63)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def decode_zig_zag(n):
 | 
			
		||||
    return (n >> 1) ^ -(n & 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def vints_unpack(term):  # noqa
 | 
			
		||||
    values = []
 | 
			
		||||
    n = 0
 | 
			
		||||
    while n < len(term):
 | 
			
		||||
        first_byte = byte2int(term[n])
 | 
			
		||||
 | 
			
		||||
        if (first_byte & 128) == 0:
 | 
			
		||||
            val = first_byte
 | 
			
		||||
        else:
 | 
			
		||||
            num_extra_bytes = 8 - (~first_byte & 0xff).bit_length()
 | 
			
		||||
            val = first_byte & (0xff >> num_extra_bytes)
 | 
			
		||||
            end = n + num_extra_bytes
 | 
			
		||||
            while n < end:
 | 
			
		||||
                n += 1
 | 
			
		||||
                val <<= 8
 | 
			
		||||
                val |= byte2int(term[n]) & 0xff
 | 
			
		||||
 | 
			
		||||
        n += 1
 | 
			
		||||
        values.append(decode_zig_zag(val))
 | 
			
		||||
 | 
			
		||||
    return tuple(values)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def vints_pack(values):
 | 
			
		||||
    revbytes = bytearray()
 | 
			
		||||
    values = [int(v) for v in values[::-1]]
 | 
			
		||||
    for v in values:
 | 
			
		||||
        v = encode_zig_zag(v)
 | 
			
		||||
        if v < 128:
 | 
			
		||||
            revbytes.append(v)
 | 
			
		||||
        else:
 | 
			
		||||
            num_extra_bytes = 0
 | 
			
		||||
            num_bits = v.bit_length()
 | 
			
		||||
            # We need to reserve (num_extra_bytes+1) bits in the first byte
 | 
			
		||||
            # ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX'
 | 
			
		||||
            while num_bits > (8-(num_extra_bytes+1)):
 | 
			
		||||
                num_extra_bytes += 1
 | 
			
		||||
                num_bits -= 8
 | 
			
		||||
                revbytes.append(v & 0xff)
 | 
			
		||||
                v >>= 8
 | 
			
		||||
 | 
			
		||||
            # We can now store the last bits in the first byte
 | 
			
		||||
            n = 8 - num_extra_bytes
 | 
			
		||||
            v |= (0xff >> n << n)
 | 
			
		||||
            revbytes.append(abs(v))
 | 
			
		||||
 | 
			
		||||
    revbytes.reverse()
 | 
			
		||||
    return six.binary_type(revbytes)
 | 
			
		||||
@@ -38,7 +38,7 @@ from cassandra.cqltypes import (AsciiType, BytesType, BooleanType,
 | 
			
		||||
                                LongType, MapType, SetType, TimeUUIDType,
 | 
			
		||||
                                UTF8Type, VarcharType, UUIDType, UserType,
 | 
			
		||||
                                TupleType, lookup_casstype, SimpleDateType,
 | 
			
		||||
                                TimeType, ByteType, ShortType)
 | 
			
		||||
                                TimeType, ByteType, ShortType, DurationType)
 | 
			
		||||
from cassandra.policies import WriteType
 | 
			
		||||
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
 | 
			
		||||
from cassandra import util
 | 
			
		||||
 
 | 
			
		||||
@@ -25,6 +25,11 @@ Type codes are repeated here from the Cassandra binary protocol specification:
 | 
			
		||||
            0x000E    Varint
 | 
			
		||||
            0x000F    Timeuuid
 | 
			
		||||
            0x0010    Inet
 | 
			
		||||
            0x0011    SimpleDateType
 | 
			
		||||
            0x0012    TimeType
 | 
			
		||||
            0x0013    ShortType
 | 
			
		||||
            0x0014    ByteType
 | 
			
		||||
            0x0015    DurationType
 | 
			
		||||
            0x0020    List: the value is an [option], representing the type
 | 
			
		||||
                            of the elements of the list.
 | 
			
		||||
            0x0021    Map: the value is two [option], representing the types of the
 | 
			
		||||
@@ -54,6 +59,7 @@ SimpleDateType = 0x0011
 | 
			
		||||
TimeType = 0x0012
 | 
			
		||||
ShortType = 0x0013
 | 
			
		||||
ByteType = 0x0014
 | 
			
		||||
DurationType = 0x0015
 | 
			
		||||
ListType = 0x0020
 | 
			
		||||
MapType = 0x0021
 | 
			
		||||
SetType = 0x0022
 | 
			
		||||
 
 | 
			
		||||
@@ -1193,3 +1193,33 @@ def _sanitize_identifiers(field_names):
 | 
			
		||||
                names_out[index] = "%s_" % (names_out[index],)
 | 
			
		||||
            observed_names.add(names_out[index])
 | 
			
		||||
    return names_out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Duration(object):
 | 
			
		||||
    """
 | 
			
		||||
    Cassandra Duration Type
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    months = 0
 | 
			
		||||
    days = 0
 | 
			
		||||
    nanoseconds = 0
 | 
			
		||||
 | 
			
		||||
    def __init__(self, months=0, days=0, nanoseconds=0):
 | 
			
		||||
        self.months = months
 | 
			
		||||
        self.days = days
 | 
			
		||||
        self.nanoseconds = nanoseconds
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0
 | 
			
		||||
        return '%s%dmo%dd%dns' % (
 | 
			
		||||
            '-' if has_negative_values else '',
 | 
			
		||||
            abs(self.months),
 | 
			
		||||
            abs(self.days),
 | 
			
		||||
            abs(self.nanoseconds)
 | 
			
		||||
        )
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user