From de8ae577a22643d429a2a129601ec323ca434cb3 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Thu, 15 Jan 2015 20:19:14 -0600 Subject: [PATCH] Updates to 'date' and 'time' type handling. datetime.date --> 'date' datetime.time --> 'time' datetime.time seemed to be the best native type for time, but it doesn't have nanosecond resolution. Clients must use int types for that. --- cassandra/cqltypes.py | 140 ++++++++++++----------- cassandra/encoder.py | 12 +- tests/integration/standard/test_types.py | 16 ++- tests/unit/test_marshalling.py | 7 +- tests/unit/test_types.py | 93 +++++++++------ 5 files changed, 151 insertions(+), 117 deletions(-) diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 9d375d4a..abcb4a10 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -36,7 +36,7 @@ import io import re import socket import time -from datetime import datetime, timedelta +import datetime from uuid import UUID import warnings @@ -55,8 +55,8 @@ apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' if six.PY3: _number_types = frozenset((int, float)) - _time_types = frozenset((int)) - _date_types = frozenset((int)) + _time_types = frozenset((int,)) + _date_types = frozenset((int,)) long = int else: _number_types = frozenset((int, long, float)) @@ -74,6 +74,15 @@ def unix_time_from_uuid1(u): return (u.time - 0x01B21DD213814000) / 10000000.0 +def datetime_from_timestamp(timestamp): + if timestamp >= 0: + dt = datetime.datetime.utcfromtimestamp(timestamp) + else: + # PYTHON-119: workaround for Windows + dt = datetime.datetime(1970, 1, 1) + datetime.timedelta(seconds=timestamp) + return dt + + _casstypes = {} @@ -543,26 +552,26 @@ class DateType(_CassandraType): typename = 'timestamp' @classmethod - def validate(cls, date): - if isinstance(date, six.string_types): - date = cls.interpret_datestring(date) - return date + def validate(cls, val): + if isinstance(val, six.string_types): + val = cls.interpret_datestring(val) + return val @staticmethod - def interpret_datestring(date): - if date[-5] in ('+', '-'): - offset = (int(date[-4:-2]) * 3600 + int(date[-2:]) * 60) * int(date[-5] + '1') - date = date[:-5] + def interpret_datestring(val): + if val[-5] in ('+', '-'): + offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + '1') + val = val[:-5] else: offset = -time.timezone for tformat in cql_timestamp_formats: try: - tval = time.strptime(date, tformat) + tval = time.strptime(val, tformat) except ValueError: continue return calendar.timegm(tval) + offset else: - raise ValueError("can't interpret %r as a date" % (date,)) + raise ValueError("can't interpret %r as a date" % (val,)) def my_timestamp(self): return self.val @@ -570,12 +579,7 @@ class DateType(_CassandraType): @staticmethod def deserialize(byts, protocol_version): timestamp = int64_unpack(byts) / 1000.0 - if timestamp >= 0: - dt = datetime.utcfromtimestamp(timestamp) - else: - # PYTHON-119: workaround for Windows - dt = datetime(1970, 1, 1) + timedelta(seconds=timestamp) - return dt + return datetime_from_timestamp(timestamp) @staticmethod def serialize(v, protocol_version): @@ -635,85 +639,87 @@ class SimpleDateType(_CassandraType): date_format = "%Y-%m-%d" @classmethod - def validate(cls, date): - if isinstance(date, basestring): - date = cls.interpret_simpledate_string(date) - return date + def validate(cls, val): + if isinstance(val, six.string_types): + val = cls.interpret_simpledate_string(val) + elif (not isinstance(val, datetime.date)) and (type(val) not in _date_types): + raise TypeError('SimpleDateType arg must be a datetime.date, unsigned integer, or string in the format YYYY-MM-DD') + return val @staticmethod def interpret_simpledate_string(v): - try: - tval = time.strptime(v, SimpleDateType.date_format) - # shift upward w/epoch at 2**31 - return (calendar.timegm(tval) / SimpleDateType.seconds_per_day) + 2**31 - except TypeError: - # Ints are valid dates too - if type(v) not in _date_types: - raise TypeError('Date arguments must be an unsigned integer or string in the format YYYY-MM-DD') - return v + date_time = datetime.datetime.strptime(v, SimpleDateType.date_format) + return datetime.date(date_time.year, date_time.month, date_time.day) @staticmethod def serialize(val, protocol_version): - date_val = SimpleDateType.interpret_simpledate_string(val) - return uint32_pack(date_val) + # Values of the 'date'` type are encoded as 32-bit unsigned integers + # representing a number of days with "the epoch" at the center of the + # range (2^31). Epoch is January 1st, 1970 + try: + shifted = (calendar.timegm(val.timetuple()) // SimpleDateType.seconds_per_day) + 2 ** 31 + except AttributeError: + shifted = val + return uint32_pack(shifted) @staticmethod def deserialize(byts, protocol_version): - Result = namedtuple('SimpleDate', 'value') - return Result(value=uint32_unpack(byts)) + timestamp = SimpleDateType.seconds_per_day * (uint32_unpack(byts) - 2 ** 31) + dt = datetime.datetime.utcfromtimestamp(timestamp) + return datetime.date(dt.year, dt.month, dt.day) class TimeType(_CassandraType): typename = 'time' - ONE_MICRO=1000 - ONE_MILLI=1000*ONE_MICRO - ONE_SECOND=1000*ONE_MILLI - ONE_MINUTE=60*ONE_SECOND - ONE_HOUR=60*ONE_MINUTE + ONE_MICRO = 1000 + ONE_MILLI = 1000 * ONE_MICRO + ONE_SECOND = 1000 * ONE_MILLI + ONE_MINUTE = 60 * ONE_SECOND + ONE_HOUR = 60 * ONE_MINUTE @classmethod def validate(cls, val): - if isinstance(val, basestring): - time = cls.interpret_timestring(val) - return time + if isinstance(val, six.string_types): + val = cls.interpret_timestring(val) + elif (not isinstance(val, datetime.time)) and (type(val) not in _time_types): + raise TypeError('TimeType arguments must be a string or whole number') + return val @staticmethod def interpret_timestring(val): try: nano = 0 - try: - base_time_str = val - if '.' in base_time_str: - base_time_str = val[0:val.find('.')] - base_time = time.strptime(base_time_str, "%H:%M:%S") - nano = base_time.tm_hour * TimeType.ONE_HOUR - nano += base_time.tm_min * TimeType.ONE_MINUTE - nano += base_time.tm_sec * TimeType.ONE_SECOND + parts = val.split('.') + base_time = time.strptime(parts[0], "%H:%M:%S") + nano = (base_time.tm_hour * TimeType.ONE_HOUR + + base_time.tm_min * TimeType.ONE_MINUTE + + base_time.tm_sec * TimeType.ONE_SECOND) - if '.' in val: - nano_time_str = val[val.find('.')+1:] - # right pad to 9 digits - while len(nano_time_str) < 9: - nano_time_str += "0" - nano += int(nano_time_str) + if len(parts) > 1: + # right pad to 9 digits + nano_time_str = parts[1] + "0" * (9 - len(parts[1])) + nano += int(nano_time_str) - except AttributeError as e: - if type(val) not in _time_types: - raise TypeError('TimeType arguments must be a string or whole number') - # long / int values passed in are acceptable too - nano = val return nano - except ValueError as e: + except ValueError: raise ValueError("can't interpret %r as a time" % (val,)) @staticmethod def serialize(val, protocol_version): - return int64_pack(TimeType.interpret_timestring(val)) + # Values of the @time@ type are encoded as 64-bit signed integers + # representing the number of nanoseconds since midnight. + try: + nano = (val.hour * TimeType.ONE_HOUR + + val.minute * TimeType.ONE_MINUTE + + val.second * TimeType.ONE_SECOND + + val.microsecond * TimeType.ONE_MICRO) + except AttributeError: + nano = val + return int64_pack(nano) @staticmethod def deserialize(byts, protocol_version): - Result = namedtuple('Time', 'value') - return Result(value=int64_unpack(byts)) + return int64_unpack(byts) class UTF8Type(_CassandraType): diff --git a/cassandra/encoder.py b/cassandra/encoder.py index 8b1bb2fe..2f3cfa92 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -74,6 +74,7 @@ class Encoder(object): UUID: self.cql_encode_object, datetime.datetime: self.cql_encode_datetime, datetime.date: self.cql_encode_date, + datetime.time: self.cql_encode_time, dict: self.cql_encode_map_collection, OrderedDict: self.cql_encode_map_collection, list: self.cql_encode_list_collection, @@ -146,9 +147,16 @@ class Encoder(object): def cql_encode_date(self, val): """ Converts a :class:`datetime.date` object to a string with format - ``YYYY-MM-DD-0000``. + ``YYYY-MM-DD``. """ - return "'%s'" % val.strftime('%Y-%m-%d-0000') + return "'%s'" % val.strftime('%Y-%m-%d') + + def cql_encode_time(self, val): + """ + Converts a :class:`datetime.date` object to a string with format + ``HH:MM:SS.mmmuuunnn``. + """ + return "'%s'" % val def cql_encode_sequence(self, val): """ diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 44978e67..4944725e 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -22,7 +22,7 @@ import logging log = logging.getLogger(__name__) from decimal import Decimal -from datetime import datetime +from datetime import datetime, date, time import six from uuid import uuid1, uuid4 @@ -31,7 +31,6 @@ from cassandra.cluster import Cluster from cassandra.cqltypes import Int32Type, EMPTY from cassandra.query import dict_factory from cassandra.util import OrderedDict, sortedset -from collections import namedtuple from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION @@ -171,6 +170,8 @@ class TypeTests(unittest.TestCase): v1_uuid = uuid1() v4_uuid = uuid4() mydatetime = datetime(2013, 12, 31, 23, 59, 59, 999000) + mydate = date(2015, 1, 15) + mytime = time(16, 47, 25, 7) params = [ "sometext", @@ -192,13 +193,10 @@ class TypeTests(unittest.TestCase): v1_uuid, # timeuuid u"sometext\u1234", # varchar 123456789123456789123456789, # varint - '2014-01-01', # date - '01:02:03.456789012' # time + mydate, # date + mytime ] - SimpleDate = namedtuple('SimpleDate', 'value') - Time = namedtuple('Time', 'value') - expected_vals = ( "sometext", "sometext", @@ -219,8 +217,8 @@ class TypeTests(unittest.TestCase): v1_uuid, # timeuuid u"sometext\u1234", # varchar 123456789123456789123456789, # varint - SimpleDate(2147499719), # date - Time(3723456789012) # time + mydate, # date + 60445000007000 # time ) s.execute(""" diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index 9eb4de69..a329b4f5 100644 --- a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -19,7 +19,7 @@ except ImportError: import unittest # noqa import platform -from datetime import datetime +from datetime import datetime, date from decimal import Decimal from uuid import UUID @@ -79,8 +79,9 @@ marshalled_value_pairs = ( (b'\x00\x00', 'ListType(FloatType)', []), (b'\x00\x00', 'SetType(IntegerType)', sortedset()), (b'\x00\x01\x00\x10\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0', 'ListType(TimeUUIDType)', [UUID(bytes=b'\xafYC\xa3\xea<\x11\xe1\xabc\xc4,\x03"y\xf0')]), - (b'\x00\x00>\xc7', 'SimpleDateType', '2014-01-01'), - (b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', '00:00:00.000000001') + (b'\x80\x00\x00\x01', 'SimpleDateType', date(1970,1,2)), + (b'\x7f\xff\xff\xff', 'SimpleDateType', date(1969,12,31)), + (b'\x00\x00\x00\x00\x00\x00\x00\x01', 'TimeType', 1) ) ordered_dict_value = OrderedDict() diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 28f7b63a..42e2c992 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -128,56 +128,77 @@ class TypeTests(unittest.TestCase): """ Test cassandra.cqltypes.SimpleDateType() construction """ + # from string + expected_date = datetime.date(1492, 10, 12) + sd = SimpleDateType('1492-10-12') + self.assertEqual(sd.val, expected_date) - nd = SimpleDateType.interpret_simpledate_string('2014-01-01') - tval = time.strptime('2014-01-01', SimpleDateType.date_format) - manual = calendar.timegm(tval) / SimpleDateType.seconds_per_day - self.assertEqual(nd, manual) + # date + sd = SimpleDateType(expected_date) + self.assertEqual(sd.val, expected_date) - nd = SimpleDateType.interpret_simpledate_string('1970-01-01') - self.assertEqual(nd, 0) + # int + expected_timestamp = calendar.timegm(expected_date.timetuple()) + sd = SimpleDateType(expected_timestamp) + self.assertEqual(sd.val, expected_timestamp) + + # no contruct + self.assertRaises(ValueError, SimpleDateType, '1999-10-10-bad-time') + self.assertRaises(TypeError, SimpleDateType, 1.234) def test_time(self): """ Test cassandra.cqltypes.TimeType() construction """ - one_micro = 1000 - one_milli = 1000L*one_micro - one_second = 1000L*one_milli - one_minute = 60L*one_second - one_hour = 60L*one_minute + one_milli = 1000 * one_micro + one_second = 1000 * one_milli + one_minute = 60 * one_second + one_hour = 60 * one_minute - nd = TimeType.interpret_timestring('00:00:00.000000001') - self.assertEqual(nd, 1) - nd = TimeType.interpret_timestring('00:00:00.000001') - self.assertEqual(nd, one_micro) - nd = TimeType.interpret_timestring('00:00:00.001') - self.assertEqual(nd, one_milli) - nd = TimeType.interpret_timestring('00:00:01') - self.assertEqual(nd, one_second) - nd = TimeType.interpret_timestring('00:01:00') - self.assertEqual(nd, one_minute) - nd = TimeType.interpret_timestring('01:00:00') - self.assertEqual(nd, one_hour) + # from strings + tt = TimeType('00:00:00.000000001') + self.assertEqual(tt.val, 1) + tt = TimeType('00:00:00.000001') + self.assertEqual(tt.val, one_micro) + tt = TimeType('00:00:00.001') + self.assertEqual(tt.val, one_milli) + tt = TimeType('00:00:01') + self.assertEqual(tt.val, one_second) + tt = TimeType('00:01:00') + self.assertEqual(tt.val, one_minute) + tt = TimeType('01:00:00') + self.assertEqual(tt.val, one_hour) + tt = TimeType('01:00:00.') + self.assertEqual(tt.val, one_hour) - nd = TimeType('23:59:59.1') - nd = TimeType('23:59:59.12') - nd = TimeType('23:59:59.123') - nd = TimeType('23:59:59.1234') - nd = TimeType('23:59:59.12345') + tt = TimeType('23:59:59.1') + tt = TimeType('23:59:59.12') + tt = TimeType('23:59:59.123') + tt = TimeType('23:59:59.1234') + tt = TimeType('23:59:59.12345') - nd = TimeType.interpret_timestring('23:59:59.123456') - self.assertEquals(nd, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro) + tt = TimeType('23:59:59.123456') + self.assertEqual(tt.val, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro) - nd = TimeType.interpret_timestring('23:59:59.1234567') - self.assertEquals(nd, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro + 700) + tt = TimeType('23:59:59.1234567') + self.assertEqual(tt.val, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro + 700) - nd = TimeType.interpret_timestring('23:59:59.12345678') - self.assertEquals(nd, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro + 780) + tt = TimeType('23:59:59.12345678') + self.assertEqual(tt.val, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro + 780) + + tt = TimeType('23:59:59.123456789') + self.assertEqual(tt.val, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro + 789) + + # from int + tt = TimeType(12345678) + self.assertEqual(tt.val, 12345678) + + # no construct + self.assertRaises(ValueError, TimeType, '1999-10-10 11:11:11.1234') + self.assertRaises(TypeError, TimeType, 1.234) + self.assertRaises(TypeError, TimeType, datetime.datetime(2004, 12, 23, 11, 11, 1)) - nd = TimeType.interpret_timestring('23:59:59.123456789') - self.assertEquals(nd, 23*one_hour + 59*one_minute + 59*one_second + 123*one_milli + 456*one_micro + 789) def test_cql_typename(self): """