From bd940eff49e1877ffbb4c9487a67d451db2c8349 Mon Sep 17 00:00:00 2001 From: Blake Eggleston Date: Sun, 25 Nov 2012 18:28:35 -0800 Subject: [PATCH] adding datetime and decimal columns and tests --- cqlengine/columns.py | 17 +++++--- cqlengine/models.py | 4 +- cqlengine/query.py | 7 +-- cqlengine/tests/columns/test_validation.py | 51 ++++++++++++++++++++++ 4 files changed, 70 insertions(+), 9 deletions(-) diff --git a/cqlengine/columns.py b/cqlengine/columns.py index df7a3a34..617b0d4b 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -1,4 +1,5 @@ #column field types +from datetime import datetime import re from uuid import uuid1, uuid4 @@ -165,7 +166,17 @@ class DateTime(Column): db_type = 'timestamp' def __init__(self, **kwargs): super(DateTime, self).__init__(**kwargs) - raise NotImplementedError + + def to_python(self, value): + if isinstance(value, datetime): + return value + return datetime.fromtimestamp(value) + + def to_database(self, value): + value = super(DateTime, self).to_database(value) + if not isinstance(value, datetime): + raise ValidationError("'{}' is not a datetime object".format(value)) + return value.strftime('%Y-%m-%d %H:%M:%S') class UUID(Column): """ @@ -216,10 +227,6 @@ class Float(Column): class Decimal(Column): db_type = 'decimal' - #TODO: decimal field - def __init__(self, **kwargs): - super(DateTime, self).__init__(**kwargs) - raise NotImplementedError class Counter(Column): #TODO: counter field diff --git a/cqlengine/models.py b/cqlengine/models.py index d977742f..c629d204 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -21,7 +21,9 @@ class BaseModel(object): def __init__(self, **values): self._values = {} for name, column in self._columns.items(): - value_mngr = column.value_manager(self, column, values.get(name, None)) + value = values.get(name, None) + if value is not None: value = column.to_python(value) + value_mngr = column.value_manager(self, column, value) self._values[name] = value_mngr @classmethod diff --git a/cqlengine/query.py b/cqlengine/query.py index 921e8d19..bb7a5e51 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -117,9 +117,8 @@ class LessThanOrEqualOperator(QueryOperator): cql_symbol = '<=' class QuerySet(object): - #TODO: delete empty columns on save - #TODO: support specifying columns to exclude or select only #TODO: cache results in this instance, but don't copy them on deepcopy + #TODO: support multiple iterators def __init__(self, model): super(QuerySet, self).__init__() @@ -224,7 +223,9 @@ class QuerySet(object): with connection_manager() as con: self._cursor = con.execute(self._select_query(), self._where_values()) self._rowcount = self._cursor.rowcount - return self + return self + else: + raise QueryException("QuerySet only supports a single iterator at a time, though this will be fixed shortly") def __getitem__(self, s): diff --git a/cqlengine/tests/columns/test_validation.py b/cqlengine/tests/columns/test_validation.py index 4043dfe8..3f70ab12 100644 --- a/cqlengine/tests/columns/test_validation.py +++ b/cqlengine/tests/columns/test_validation.py @@ -1,4 +1,6 @@ #tests the behavior of the column classes +from datetime import datetime +from decimal import Decimal as D from cqlengine.tests.base import BaseCassEngTestCase @@ -13,4 +15,53 @@ from cqlengine.columns import Boolean from cqlengine.columns import Float from cqlengine.columns import Decimal +from cqlengine.management import create_column_family, delete_column_family +from cqlengine.models import Model + +class TestDatetime(BaseCassEngTestCase): + class DatetimeTest(Model): + test_id = Integer(primary_key=True) + created_at = DateTime() + + @classmethod + def setUpClass(cls): + super(TestDatetime, cls).setUpClass() + create_column_family(cls.DatetimeTest) + + @classmethod + def tearDownClass(cls): + super(TestDatetime, cls).tearDownClass() + delete_column_family(cls.DatetimeTest) + + def test_datetime_io(self): + now = datetime.now() + dt = self.DatetimeTest.objects.create(test_id=0, created_at=now) + dt2 = self.DatetimeTest.objects(test_id=0).first() + assert dt2.created_at.timetuple()[:6] == now.timetuple()[:6] + +class TestDecimal(BaseCassEngTestCase): + class DecimalTest(Model): + test_id = Integer(primary_key=True) + dec_val = Decimal() + + @classmethod + def setUpClass(cls): + super(TestDecimal, cls).setUpClass() + create_column_family(cls.DecimalTest) + + @classmethod + def tearDownClass(cls): + super(TestDecimal, cls).tearDownClass() + delete_column_family(cls.DecimalTest) + + def test_datetime_io(self): + dt = self.DecimalTest.objects.create(test_id=0, dec_val=D('0.00')) + dt2 = self.DecimalTest.objects(test_id=0).first() + assert dt2.dec_val == dt.dec_val + + dt = self.DecimalTest.objects.create(test_id=0, dec_val=5) + dt2 = self.DecimalTest.objects(test_id=0).first() + assert dt2.dec_val == D('5') + +