diff --git a/cqlengine/columns.py b/cqlengine/columns.py index f63afff1..5cbe6a09 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -1,6 +1,7 @@ #column field types from copy import copy from datetime import datetime +from datetime import date import re from uuid import uuid1, uuid4 from cql.query import cql_quote @@ -216,6 +217,31 @@ class DateTime(Column): epoch = datetime(1970, 1, 1) return long((value - epoch).total_seconds() * 1000) + +class Date(Column): + db_type = 'timestamp' + + def __init__(self, **kwargs): + super(Date, self).__init__(**kwargs) + + def to_python(self, value): + if isinstance(value, datetime): + return value.date() + elif isinstance(value, date): + return value + + return date.fromtimestamp(value) + + def to_database(self, value): + value = super(Date, self).to_database(value) + if isinstance(value, datetime): + value = value.date() + if not isinstance(value, date): + raise ValidationError("'{}' is not a date object".format(repr(value))) + + return long((value - date(1970, 1, 1)).total_seconds() * 1000) + + class UUID(Column): """ Type 1 or 4 UUID @@ -235,14 +261,14 @@ class UUID(Column): if not self.re_uuid.match(val): raise ValidationError("{} is not a valid uuid".format(value)) return _UUID(val) - + class TimeUUID(UUID): """ UUID containing timestamp """ - + db_type = 'timeuuid' - + def __init__(self, **kwargs): kwargs.setdefault('default', lambda: uuid1()) super(TimeUUID, self).__init__(**kwargs) diff --git a/cqlengine/tests/columns/test_validation.py b/cqlengine/tests/columns/test_validation.py index 4921b672..44903a4c 100644 --- a/cqlengine/tests/columns/test_validation.py +++ b/cqlengine/tests/columns/test_validation.py @@ -1,5 +1,6 @@ #tests the behavior of the column classes from datetime import datetime +from datetime import date from decimal import Decimal as D from cqlengine import ValidationError @@ -11,6 +12,7 @@ from cqlengine.columns import Ascii from cqlengine.columns import Text from cqlengine.columns import Integer from cqlengine.columns import DateTime +from cqlengine.columns import Date from cqlengine.columns import UUID from cqlengine.columns import Boolean from cqlengine.columns import Float @@ -40,6 +42,37 @@ class TestDatetime(BaseCassEngTestCase): dt2 = self.DatetimeTest.objects(test_id=0).first() assert dt2.created_at.timetuple()[:6] == now.timetuple()[:6] + +class TestDate(BaseCassEngTestCase): + class DateTest(Model): + test_id = Integer(primary_key=True) + created_at = Date() + + @classmethod + def setUpClass(cls): + super(TestDate, cls).setUpClass() + create_table(cls.DateTest) + + @classmethod + def tearDownClass(cls): + super(TestDate, cls).tearDownClass() + delete_table(cls.DateTest) + + def test_date_io(self): + today = date.today() + self.DateTest.objects.create(test_id=0, created_at=today) + dt2 = self.DateTest.objects(test_id=0).first() + assert dt2.created_at.isoformat() == today.isoformat() + + def test_date_io_using_datetime(self): + now = datetime.utcnow() + self.DateTest.objects.create(test_id=0, created_at=now) + dt2 = self.DateTest.objects(test_id=0).first() + assert not isinstance(dt2.created_at, datetime) + assert isinstance(dt2.created_at, date) + assert dt2.created_at.isoformat() == now.date().isoformat() + + class TestDecimal(BaseCassEngTestCase): class DecimalTest(Model): test_id = Integer(primary_key=True) @@ -63,26 +96,26 @@ class TestDecimal(BaseCassEngTestCase): dt = self.DecimalTest.objects.create(test_id=0, dec_val=5) dt2 = self.DecimalTest.objects(test_id=0).first() assert dt2.dec_val == D('5') - + class TestTimeUUID(BaseCassEngTestCase): class TimeUUIDTest(Model): test_id = Integer(primary_key=True) timeuuid = TimeUUID() - + @classmethod def setUpClass(cls): super(TestTimeUUID, cls).setUpClass() create_table(cls.TimeUUIDTest) - + @classmethod def tearDownClass(cls): super(TestTimeUUID, cls).tearDownClass() delete_table(cls.TimeUUIDTest) - + def test_timeuuid_io(self): t0 = self.TimeUUIDTest.create(test_id=0) t1 = self.TimeUUIDTest.get(test_id=0) - + assert t1.timeuuid.time == t1.timeuuid.time class TestInteger(BaseCassEngTestCase):