diff --git a/cqlengine/columns.py b/cqlengine/columns.py index 79b90e25..05ace2bb 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -865,8 +865,15 @@ class _PartitionKeysToken(Column): self.partition_columns = model._partition_keys.values() super(_PartitionKeysToken, self).__init__(partition_key=True) + @property + def db_field_name(self): + return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns])) + def to_database(self, value): - raise NotImplementedError + from cqlengine.functions import Token + assert isinstance(value, Token) + value.set_columns(self.partition_columns) + return value def get_cql(self): return "token({})".format(", ".join(c.cql for c in self.partition_columns)) diff --git a/cqlengine/functions.py b/cqlengine/functions.py index 13661845..ceb6a78b 100644 --- a/cqlengine/functions.py +++ b/cqlengine/functions.py @@ -9,24 +9,24 @@ class QueryValue(object): be passed into .filter() keyword args """ - _cql_string = ':{}' + format_string = ':{}' - def __init__(self, value, identifier=None): + def __init__(self, value): self.value = value - self.identifier = uuid1().hex if identifier is None else identifier + self.context_id = None - def get_cql(self): - return self._cql_string.format(self.identifier) + def __unicode__(self): + return self.format_string.format(self.context_id) - def get_value(self): - return self.value + def set_context_id(self, ctx_id): + self.context_id = ctx_id - def get_dict(self, column): - return {self.identifier: column.to_database(self.get_value())} + def get_context_size(self): + return 1 + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.value - @property - def cql(self): - return self.get_cql() class BaseQueryFunction(QueryValue): """ @@ -42,7 +42,7 @@ class MinTimeUUID(BaseQueryFunction): http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun """ - _cql_string = 'MinTimeUUID(:{})' + format_string = 'MinTimeUUID(:{})' def __init__(self, value): """ @@ -53,14 +53,11 @@ class MinTimeUUID(BaseQueryFunction): raise ValidationError('datetime instance is required') super(MinTimeUUID, self).__init__(value) - def get_value(self): + def update_context(self, ctx): epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000) - return long(((self.value - epoch).total_seconds() - offset) * 1000) - - def get_dict(self, column): - return {self.identifier: self.get_value()} class MaxTimeUUID(BaseQueryFunction): """ @@ -69,7 +66,7 @@ class MaxTimeUUID(BaseQueryFunction): http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun """ - _cql_string = 'MaxTimeUUID(:{})' + format_string = 'MaxTimeUUID(:{})' def __init__(self, value): """ @@ -80,14 +77,11 @@ class MaxTimeUUID(BaseQueryFunction): raise ValidationError('datetime instance is required') super(MaxTimeUUID, self).__init__(value) - def get_value(self): + def update_context(self, ctx): epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000) - return long(((self.value - epoch).total_seconds() - offset) * 1000) - - def get_dict(self, column): - return {self.identifier: self.get_value()} class Token(BaseQueryFunction): """ @@ -99,15 +93,20 @@ class Token(BaseQueryFunction): def __init__(self, *values): if len(values) == 1 and isinstance(values[0], (list, tuple)): values = values[0] - super(Token, self).__init__(values, [uuid1().hex for i in values]) + super(Token, self).__init__(values) + self._columns = None - def get_dict(self, column): - items = zip(self.identifier, self.value, column.partition_columns) - return dict( - (id, col.to_database(val)) for id, val, col in items - ) + def set_columns(self, columns): + self._columns = columns - def get_cql(self): - token_args = ', '.join(':{}'.format(id) for id in self.identifier) + def get_context_size(self): + return len(self.value) + + def __unicode__(self): + token_args = ', '.join(':{}'.format(self.context_id + i) for i in range(self.get_context_size())) return "token({})".format(token_args) + def update_context(self, ctx): + for i, (col, val) in enumerate(zip(self._columns, self.value)): + ctx[str(self.context_id + i)] = col.to_database(val) + diff --git a/cqlengine/named.py b/cqlengine/named.py index 2b754431..c6ba3ac9 100644 --- a/cqlengine/named.py +++ b/cqlengine/named.py @@ -40,6 +40,10 @@ class NamedColumn(AbstractQueryableColumn): """ :rtype: NamedColumn """ return self + @property + def db_field_name(self): + return self.name + @property def cql(self): return self.get_cql() diff --git a/cqlengine/query.py b/cqlengine/query.py index 545d5810..0418883f 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -10,7 +10,7 @@ from cqlengine.columns import Counter from cqlengine.connection import connection_manager, execute, RowResult from cqlengine.exceptions import CQLEngineException, ValidationError -from cqlengine.functions import QueryValue, Token +from cqlengine.functions import QueryValue, Token, BaseQueryFunction #CQL 3 reference: #http://www.datastax.com/docs/1.1/references/cql/index @@ -480,12 +480,14 @@ class AbstractQuerySet(object): for arg, val in kwargs.items(): col_name, col_op = self._parse_filter_arg(arg) + quote_field = True #resolve column and operator try: column = self.model._get_column(col_name) except KeyError: if col_name == 'pk__token': column = columns._PartitionKeysToken(self.model) + quote_field = False else: raise QueryException("Can't resolve column name: '{}'".format(col_name)) @@ -497,10 +499,12 @@ class AbstractQuerySet(object): if not isinstance(val, (list, tuple)): raise QueryException('IN queries must use a list/tuple value') query_val = [column.to_database(v) for v in val] + elif isinstance(val, BaseQueryFunction): + query_val = val else: query_val = column.to_database(val) - clone._where.append(WhereClause(col_name, operator, query_val)) + clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) return clone diff --git a/cqlengine/statements.py b/cqlengine/statements.py index f974a189..834e6479 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -1,3 +1,4 @@ +from cqlengine.functions import QueryValue from cqlengine.operators import BaseWhereOperator, InOperator @@ -78,16 +79,27 @@ class BaseClause(object): class WhereClause(BaseClause): """ a single where statement used in queries """ - def __init__(self, field, operator, value): + def __init__(self, field, operator, value, quote_field=True): + """ + + :param field: + :param operator: + :param value: + :param quote_field: hack to get the token function rendering properly + :return: + """ if not isinstance(operator, BaseWhereOperator): raise StatementException( "operator must be of type {}, got {}".format(BaseWhereOperator, type(operator)) ) super(WhereClause, self).__init__(field, value) self.operator = operator + self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) + self.quote_field = quote_field def __unicode__(self): - return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id) + field = ('"{}"' if self.quote_field else '{}').format(self.field) + return u'{} {} {}'.format(field, self.operator, unicode(self.query_value)) def __hash__(self): return super(WhereClause, self).__hash__() ^ hash(self.operator) @@ -97,11 +109,18 @@ class WhereClause(BaseClause): return self.operator.__class__ == other.operator.__class__ return False + def get_context_size(self): + return self.query_value.get_context_size() + + def set_context_id(self, i): + super(WhereClause, self).set_context_id(i) + self.query_value.set_context_id(i) + def update_context(self, ctx): if isinstance(self.operator, InOperator): ctx[str(self.context_id)] = InQuoter(self.value) else: - super(WhereClause, self).update_context(ctx) + self.query_value.update_context(ctx) class AssignmentClause(BaseClause): diff --git a/cqlengine/tests/query/test_queryoperators.py b/cqlengine/tests/query/test_queryoperators.py index 8f5243fb..11dec4dd 100644 --- a/cqlengine/tests/query/test_queryoperators.py +++ b/cqlengine/tests/query/test_queryoperators.py @@ -1,10 +1,12 @@ from datetime import datetime -import time +from cqlengine.columns import DateTime from cqlengine.tests.base import BaseCassEngTestCase from cqlengine import columns, Model from cqlengine import functions from cqlengine import query +from cqlengine.statements import WhereClause +from cqlengine.operators import EqualsOperator class TestQuerySetOperation(BaseCassEngTestCase): @@ -13,22 +15,26 @@ class TestQuerySetOperation(BaseCassEngTestCase): Tests that queries with helper functions are generated properly """ now = datetime.now() - col = columns.DateTime() - col.set_column_name('time') - qry = query.EqualsOperator(col, functions.MaxTimeUUID(now)) + where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now)) + where.set_context_id(5) - assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.value.identifier) + self.assertEqual(str(where), '"time" = MaxTimeUUID(:5)') + ctx = {} + where.update_context(ctx) + self.assertEqual(ctx, {'5': DateTime().to_database(now)}) def test_mintimeuuid_function(self): """ Tests that queries with helper functions are generated properly """ now = datetime.now() - col = columns.DateTime() - col.set_column_name('time') - qry = query.EqualsOperator(col, functions.MinTimeUUID(now)) + where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now)) + where.set_context_id(5) - assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.value.identifier) + self.assertEqual(str(where), '"time" = MinTimeUUID(:5)') + ctx = {} + where.update_context(ctx) + self.assertEqual(ctx, {'5': DateTime().to_database(now)}) def test_token_function(self): @@ -39,12 +45,16 @@ class TestQuerySetOperation(BaseCassEngTestCase): func = functions.Token('a', 'b') q = TestModel.objects.filter(pk__token__gt=func) - self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) + where = q._where[0] + where.set_context_id(1) + self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2)) - # Token(tuple()) is also possible for convinience + # Token(tuple()) is also possible for convenience # it (allows for Token(obj.pk) syntax) func = functions.Token(('a', 'b')) q = TestModel.objects.filter(pk__token__gt=func) - self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) + where = q._where[0] + where.set_context_id(1) + self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))