diff --git a/cqlengine/functions.py b/cqlengine/functions.py index f2f61661..15e93d46 100644 --- a/cqlengine/functions.py +++ b/cqlengine/functions.py @@ -20,6 +20,9 @@ class BaseQueryFunction(object): """ return self._cql_string.format(value_id) + def get_value(self): + raise NotImplementedError + class MinTimeUUID(BaseQueryFunction): _cql_string = 'MinTimeUUID(:{})' @@ -33,6 +36,10 @@ class MinTimeUUID(BaseQueryFunction): raise ValidationError('datetime instance is required') super(MinTimeUUID, self).__init__(value) + def get_value(self): + epoch = datetime(1970, 1, 1) + return long((self.value - epoch).total_seconds() * 1000) + class MaxTimeUUID(BaseQueryFunction): _cql_string = 'MaxTimeUUID(:{})' @@ -46,3 +53,7 @@ class MaxTimeUUID(BaseQueryFunction): raise ValidationError('datetime instance is required') super(MaxTimeUUID, self).__init__(value) + def get_value(self): + epoch = datetime(1970, 1, 1) + return long((self.value - epoch).total_seconds() * 1000) + diff --git a/cqlengine/query.py b/cqlengine/query.py index 94c4544d..7a9a74e7 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -78,7 +78,10 @@ class QueryOperator(object): this should return the dict: {'colval':} SELECT * FROM column_family WHERE colname=:colval """ - return {self.identifier: self.column.to_database(self.value)} + if isinstance(self.value, BaseQueryFunction): + return {self.identifier: self.column.to_database(self.value.get_value())} + else: + return {self.identifier: self.column.to_database(self.value)} @classmethod def get_operator(cls, symbol): diff --git a/cqlengine/tests/query/test_queryoperators.py b/cqlengine/tests/query/test_queryoperators.py index 62ffea9c..5db4a299 100644 --- a/cqlengine/tests/query/test_queryoperators.py +++ b/cqlengine/tests/query/test_queryoperators.py @@ -1,7 +1,8 @@ from datetime import datetime +import time from cqlengine.tests.base import BaseCassEngTestCase -from cqlengine import columns +from cqlengine import columns, Model from cqlengine import functions from cqlengine import query @@ -30,3 +31,4 @@ class TestQuerySetOperation(BaseCassEngTestCase): assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.identifier) + diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py index cf4eb9d1..99f45afa 100644 --- a/cqlengine/tests/query/test_queryset.py +++ b/cqlengine/tests/query/test_queryset.py @@ -1,6 +1,11 @@ +from datetime import datetime +import time +from uuid import uuid1, uuid4 + from cqlengine.tests.base import BaseCassEngTestCase from cqlengine.exceptions import ModelException +from cqlengine import functions from cqlengine.management import create_table from cqlengine.management import delete_table from cqlengine.models import Model @@ -386,6 +391,51 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage): del q assert ConnectionPool._queue.qsize() == 1 +class TimeUUIDQueryModel(Model): + partition = columns.UUID(primary_key=True) + time = columns.TimeUUID(primary_key=True) + data = columns.Text(required=False) + +class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestMinMaxTimeUUIDFunctions, cls).setUpClass() + create_table(TimeUUIDQueryModel) + + @classmethod + def tearDownClass(cls): + super(TestMinMaxTimeUUIDFunctions, cls).tearDownClass() + delete_table(TimeUUIDQueryModel) + + def test_success_case(self): + """ Test that the min and max time uuid functions work as expected """ + pk = uuid4() + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='1') + time.sleep(0.2) + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='2') + time.sleep(0.2) + midpoint = datetime.utcnow() + time.sleep(0.2) + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='3') + time.sleep(0.2) + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4') + time.sleep(0.2) + + q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) + q = [d for d in q] + assert len(q) == 2 + datas = [d.data for d in q] + assert '1' in datas + assert '2' in datas + + q = TimeUUIDQueryModel.filter(partition=pk, time__gte=functions.MinTimeUUID(midpoint)) + assert len(q) == 2 + datas = [d.data for d in q] + assert '3' in datas + assert '4' in datas + +