diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 3cb88172..e00196d8 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -553,10 +553,13 @@ class DateType(_CassandraType): timestamp_seconds = calendar.timegm(v.utctimetuple()) timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3 except AttributeError: - # Ints and floats are valid timestamps too - if type(v) not in _number_types: - raise TypeError('DateType arguments must be a datetime or timestamp') - timestamp = v + try: + timestamp = calendar.timegm(v.timetuple()) * 1e3 + except AttributeError: + # Ints and floats are valid timestamps too + if type(v) not in _number_types: + raise TypeError('DateType arguments must be a datetime, date, or timestamp') + timestamp = v return int64_pack(long(timestamp)) diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index 60ae33fd..3b98280c 100644 --- a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -23,7 +23,7 @@ from datetime import datetime, date from decimal import Decimal from uuid import UUID -from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type +from cassandra.cqltypes import lookup_casstype, DecimalType, UTF8Type, DateType from cassandra.util import OrderedMap, OrderedMapSerializedKey, sortedset, Time, Date marshalled_value_pairs = ( @@ -39,6 +39,7 @@ marshalled_value_pairs = ( (b'\x80\x00\x00\x00\x00\x00\x00\x00', 'CounterColumnType', -9223372036854775808), (b'', 'CounterColumnType', None), (b'\x00\x00\x013\x7fb\xeey', 'DateType', datetime(2011, 11, 7, 18, 55, 49, 881000)), + (b'\x00\x00\x01P\xc5~L\x00', 'DateType', datetime(2015, 11, 2)), (b'', 'DateType', None), (b'\x00\x00\x00\r\nJ\x04"^\x91\x04\x8a\xb1\x18\xfe', 'DecimalType', Decimal('1243878957943.1234124191998')), (b'\x00\x00\x00\x06\xe5\xde]\x98Y', 'DecimalType', Decimal('-112233.441191')), @@ -134,3 +135,7 @@ class UnmarshalTest(unittest.TestCase): self.assertEqual(bitlength(9), 4) self.assertEqual(bitlength(-10), 0) self.assertEqual(bitlength(0), 0) + + def test_date(self): + # separate test because it will deserialize as datetime + self.assertEqual(DateType.from_binary(DateType.to_binary(date(2015, 11, 2), 1), 1), datetime(2015, 11, 2))