hacking in the query functions, also getting all tests to pass
This commit is contained in:
		| @@ -865,8 +865,15 @@ class _PartitionKeysToken(Column): | |||||||
|         self.partition_columns = model._partition_keys.values() |         self.partition_columns = model._partition_keys.values() | ||||||
|         super(_PartitionKeysToken, self).__init__(partition_key=True) |         super(_PartitionKeysToken, self).__init__(partition_key=True) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def db_field_name(self): | ||||||
|  |         return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns])) | ||||||
|  |  | ||||||
|     def to_database(self, value): |     def to_database(self, value): | ||||||
|         raise NotImplementedError |         from cqlengine.functions import Token | ||||||
|  |         assert isinstance(value, Token) | ||||||
|  |         value.set_columns(self.partition_columns) | ||||||
|  |         return value | ||||||
|  |  | ||||||
|     def get_cql(self): |     def get_cql(self): | ||||||
|         return "token({})".format(", ".join(c.cql for c in self.partition_columns)) |         return "token({})".format(", ".join(c.cql for c in self.partition_columns)) | ||||||
|   | |||||||
| @@ -9,24 +9,24 @@ class QueryValue(object): | |||||||
|     be passed into .filter() keyword args |     be passed into .filter() keyword args | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     _cql_string = ':{}' |     format_string = ':{}' | ||||||
|  |  | ||||||
|     def __init__(self, value, identifier=None): |     def __init__(self, value): | ||||||
|         self.value = value |         self.value = value | ||||||
|         self.identifier = uuid1().hex if identifier is None else identifier |         self.context_id = None | ||||||
|  |  | ||||||
|     def get_cql(self): |     def __unicode__(self): | ||||||
|         return self._cql_string.format(self.identifier) |         return self.format_string.format(self.context_id) | ||||||
|  |  | ||||||
|     def get_value(self): |     def set_context_id(self, ctx_id): | ||||||
|         return self.value |         self.context_id = ctx_id | ||||||
|  |  | ||||||
|     def get_dict(self, column): |     def get_context_size(self): | ||||||
|         return {self.identifier: column.to_database(self.get_value())} |         return 1 | ||||||
|  |  | ||||||
|  |     def update_context(self, ctx): | ||||||
|  |         ctx[str(self.context_id)] = self.value | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def cql(self): |  | ||||||
|         return self.get_cql() |  | ||||||
|  |  | ||||||
| class BaseQueryFunction(QueryValue): | class BaseQueryFunction(QueryValue): | ||||||
|     """ |     """ | ||||||
| @@ -42,7 +42,7 @@ class MinTimeUUID(BaseQueryFunction): | |||||||
|     http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun |     http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     _cql_string = 'MinTimeUUID(:{})' |     format_string = 'MinTimeUUID(:{})' | ||||||
|  |  | ||||||
|     def __init__(self, value): |     def __init__(self, value): | ||||||
|         """ |         """ | ||||||
| @@ -53,14 +53,11 @@ class MinTimeUUID(BaseQueryFunction): | |||||||
|             raise ValidationError('datetime instance is required') |             raise ValidationError('datetime instance is required') | ||||||
|         super(MinTimeUUID, self).__init__(value) |         super(MinTimeUUID, self).__init__(value) | ||||||
|  |  | ||||||
|     def get_value(self): |     def update_context(self, ctx): | ||||||
|         epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) |         epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) | ||||||
|         offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 |         offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 | ||||||
|  |         ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000) | ||||||
|  |  | ||||||
|         return long(((self.value - epoch).total_seconds() - offset) * 1000) |  | ||||||
|  |  | ||||||
|     def get_dict(self, column): |  | ||||||
|         return {self.identifier: self.get_value()} |  | ||||||
|  |  | ||||||
| class MaxTimeUUID(BaseQueryFunction): | class MaxTimeUUID(BaseQueryFunction): | ||||||
|     """ |     """ | ||||||
| @@ -69,7 +66,7 @@ class MaxTimeUUID(BaseQueryFunction): | |||||||
|     http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun |     http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     _cql_string = 'MaxTimeUUID(:{})' |     format_string = 'MaxTimeUUID(:{})' | ||||||
|  |  | ||||||
|     def __init__(self, value): |     def __init__(self, value): | ||||||
|         """ |         """ | ||||||
| @@ -80,14 +77,11 @@ class MaxTimeUUID(BaseQueryFunction): | |||||||
|             raise ValidationError('datetime instance is required') |             raise ValidationError('datetime instance is required') | ||||||
|         super(MaxTimeUUID, self).__init__(value) |         super(MaxTimeUUID, self).__init__(value) | ||||||
|  |  | ||||||
|     def get_value(self): |     def update_context(self, ctx): | ||||||
|         epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) |         epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) | ||||||
|         offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 |         offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 | ||||||
|  |         ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000) | ||||||
|  |  | ||||||
|         return long(((self.value - epoch).total_seconds() - offset) * 1000) |  | ||||||
|  |  | ||||||
|     def get_dict(self, column): |  | ||||||
|         return {self.identifier: self.get_value()} |  | ||||||
|  |  | ||||||
| class Token(BaseQueryFunction): | class Token(BaseQueryFunction): | ||||||
|     """ |     """ | ||||||
| @@ -99,15 +93,20 @@ class Token(BaseQueryFunction): | |||||||
|     def __init__(self, *values): |     def __init__(self, *values): | ||||||
|         if len(values) == 1 and isinstance(values[0], (list, tuple)): |         if len(values) == 1 and isinstance(values[0], (list, tuple)): | ||||||
|             values = values[0] |             values = values[0] | ||||||
|         super(Token, self).__init__(values, [uuid1().hex for i in values]) |         super(Token, self).__init__(values) | ||||||
|  |         self._columns = None | ||||||
|  |  | ||||||
|     def get_dict(self, column): |     def set_columns(self, columns): | ||||||
|         items = zip(self.identifier, self.value, column.partition_columns) |         self._columns = columns | ||||||
|         return dict( |  | ||||||
|             (id, col.to_database(val)) for id, val, col in items |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def get_cql(self): |     def get_context_size(self): | ||||||
|         token_args = ', '.join(':{}'.format(id) for id in self.identifier) |         return len(self.value) | ||||||
|  |  | ||||||
|  |     def __unicode__(self): | ||||||
|  |         token_args = ', '.join(':{}'.format(self.context_id + i) for i in range(self.get_context_size())) | ||||||
|         return "token({})".format(token_args) |         return "token({})".format(token_args) | ||||||
|  |  | ||||||
|  |     def update_context(self, ctx): | ||||||
|  |         for i, (col, val) in enumerate(zip(self._columns, self.value)): | ||||||
|  |             ctx[str(self.context_id + i)] = col.to_database(val) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -40,6 +40,10 @@ class NamedColumn(AbstractQueryableColumn): | |||||||
|         """ :rtype: NamedColumn """ |         """ :rtype: NamedColumn """ | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def db_field_name(self): | ||||||
|  |         return self.name | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def cql(self): |     def cql(self): | ||||||
|         return self.get_cql() |         return self.get_cql() | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ from cqlengine.columns import Counter | |||||||
| from cqlengine.connection import connection_manager, execute, RowResult | from cqlengine.connection import connection_manager, execute, RowResult | ||||||
|  |  | ||||||
| from cqlengine.exceptions import CQLEngineException, ValidationError | from cqlengine.exceptions import CQLEngineException, ValidationError | ||||||
| from cqlengine.functions import QueryValue, Token | from cqlengine.functions import QueryValue, Token, BaseQueryFunction | ||||||
|  |  | ||||||
| #CQL 3 reference: | #CQL 3 reference: | ||||||
| #http://www.datastax.com/docs/1.1/references/cql/index | #http://www.datastax.com/docs/1.1/references/cql/index | ||||||
| @@ -480,12 +480,14 @@ class AbstractQuerySet(object): | |||||||
|  |  | ||||||
|         for arg, val in kwargs.items(): |         for arg, val in kwargs.items(): | ||||||
|             col_name, col_op = self._parse_filter_arg(arg) |             col_name, col_op = self._parse_filter_arg(arg) | ||||||
|  |             quote_field = True | ||||||
|             #resolve column and operator |             #resolve column and operator | ||||||
|             try: |             try: | ||||||
|                 column = self.model._get_column(col_name) |                 column = self.model._get_column(col_name) | ||||||
|             except KeyError: |             except KeyError: | ||||||
|                 if col_name == 'pk__token': |                 if col_name == 'pk__token': | ||||||
|                     column = columns._PartitionKeysToken(self.model) |                     column = columns._PartitionKeysToken(self.model) | ||||||
|  |                     quote_field = False | ||||||
|                 else: |                 else: | ||||||
|                     raise QueryException("Can't resolve column name: '{}'".format(col_name)) |                     raise QueryException("Can't resolve column name: '{}'".format(col_name)) | ||||||
|  |  | ||||||
| @@ -497,10 +499,12 @@ class AbstractQuerySet(object): | |||||||
|                 if not isinstance(val, (list, tuple)): |                 if not isinstance(val, (list, tuple)): | ||||||
|                     raise QueryException('IN queries must use a list/tuple value') |                     raise QueryException('IN queries must use a list/tuple value') | ||||||
|                 query_val = [column.to_database(v) for v in val] |                 query_val = [column.to_database(v) for v in val] | ||||||
|  |             elif isinstance(val, BaseQueryFunction): | ||||||
|  |                 query_val = val | ||||||
|             else: |             else: | ||||||
|                 query_val = column.to_database(val) |                 query_val = column.to_database(val) | ||||||
|  |  | ||||||
|             clone._where.append(WhereClause(col_name, operator, query_val)) |             clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) | ||||||
|  |  | ||||||
|         return clone |         return clone | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
|  | from cqlengine.functions import QueryValue | ||||||
| from cqlengine.operators import BaseWhereOperator, InOperator | from cqlengine.operators import BaseWhereOperator, InOperator | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -78,16 +79,27 @@ class BaseClause(object): | |||||||
| class WhereClause(BaseClause): | class WhereClause(BaseClause): | ||||||
|     """ a single where statement used in queries """ |     """ a single where statement used in queries """ | ||||||
|  |  | ||||||
|     def __init__(self, field, operator, value): |     def __init__(self, field, operator, value, quote_field=True): | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         :param field: | ||||||
|  |         :param operator: | ||||||
|  |         :param value: | ||||||
|  |         :param quote_field: hack to get the token function rendering properly | ||||||
|  |         :return: | ||||||
|  |         """ | ||||||
|         if not isinstance(operator, BaseWhereOperator): |         if not isinstance(operator, BaseWhereOperator): | ||||||
|             raise StatementException( |             raise StatementException( | ||||||
|                 "operator must be of type {}, got {}".format(BaseWhereOperator, type(operator)) |                 "operator must be of type {}, got {}".format(BaseWhereOperator, type(operator)) | ||||||
|             ) |             ) | ||||||
|         super(WhereClause, self).__init__(field, value) |         super(WhereClause, self).__init__(field, value) | ||||||
|         self.operator = operator |         self.operator = operator | ||||||
|  |         self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) | ||||||
|  |         self.quote_field = quote_field | ||||||
|  |  | ||||||
|     def __unicode__(self): |     def __unicode__(self): | ||||||
|         return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id) |         field = ('"{}"' if self.quote_field else '{}').format(self.field) | ||||||
|  |         return u'{} {} {}'.format(field, self.operator, unicode(self.query_value)) | ||||||
|  |  | ||||||
|     def __hash__(self): |     def __hash__(self): | ||||||
|         return super(WhereClause, self).__hash__() ^ hash(self.operator) |         return super(WhereClause, self).__hash__() ^ hash(self.operator) | ||||||
| @@ -97,11 +109,18 @@ class WhereClause(BaseClause): | |||||||
|             return self.operator.__class__ == other.operator.__class__ |             return self.operator.__class__ == other.operator.__class__ | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|  |     def get_context_size(self): | ||||||
|  |         return self.query_value.get_context_size() | ||||||
|  |  | ||||||
|  |     def set_context_id(self, i): | ||||||
|  |         super(WhereClause, self).set_context_id(i) | ||||||
|  |         self.query_value.set_context_id(i) | ||||||
|  |  | ||||||
|     def update_context(self, ctx): |     def update_context(self, ctx): | ||||||
|         if isinstance(self.operator, InOperator): |         if isinstance(self.operator, InOperator): | ||||||
|             ctx[str(self.context_id)] = InQuoter(self.value) |             ctx[str(self.context_id)] = InQuoter(self.value) | ||||||
|         else: |         else: | ||||||
|             super(WhereClause, self).update_context(ctx) |             self.query_value.update_context(ctx) | ||||||
|  |  | ||||||
|  |  | ||||||
| class AssignmentClause(BaseClause): | class AssignmentClause(BaseClause): | ||||||
|   | |||||||
| @@ -1,10 +1,12 @@ | |||||||
| from datetime import datetime | from datetime import datetime | ||||||
| import time | from cqlengine.columns import DateTime | ||||||
|  |  | ||||||
| from cqlengine.tests.base import BaseCassEngTestCase | from cqlengine.tests.base import BaseCassEngTestCase | ||||||
| from cqlengine import columns, Model | from cqlengine import columns, Model | ||||||
| from cqlengine import functions | from cqlengine import functions | ||||||
| from cqlengine import query | from cqlengine import query | ||||||
|  | from cqlengine.statements import WhereClause | ||||||
|  | from cqlengine.operators import EqualsOperator | ||||||
|  |  | ||||||
| class TestQuerySetOperation(BaseCassEngTestCase): | class TestQuerySetOperation(BaseCassEngTestCase): | ||||||
|  |  | ||||||
| @@ -13,22 +15,26 @@ class TestQuerySetOperation(BaseCassEngTestCase): | |||||||
|         Tests that queries with helper functions are generated properly |         Tests that queries with helper functions are generated properly | ||||||
|         """ |         """ | ||||||
|         now = datetime.now() |         now = datetime.now() | ||||||
|         col = columns.DateTime() |         where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now)) | ||||||
|         col.set_column_name('time') |         where.set_context_id(5) | ||||||
|         qry = query.EqualsOperator(col, functions.MaxTimeUUID(now)) |  | ||||||
|  |  | ||||||
|         assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.value.identifier) |         self.assertEqual(str(where), '"time" = MaxTimeUUID(:5)') | ||||||
|  |         ctx = {} | ||||||
|  |         where.update_context(ctx) | ||||||
|  |         self.assertEqual(ctx, {'5': DateTime().to_database(now)}) | ||||||
|  |  | ||||||
|     def test_mintimeuuid_function(self): |     def test_mintimeuuid_function(self): | ||||||
|         """ |         """ | ||||||
|         Tests that queries with helper functions are generated properly |         Tests that queries with helper functions are generated properly | ||||||
|         """ |         """ | ||||||
|         now = datetime.now() |         now = datetime.now() | ||||||
|         col = columns.DateTime() |         where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now)) | ||||||
|         col.set_column_name('time') |         where.set_context_id(5) | ||||||
|         qry = query.EqualsOperator(col, functions.MinTimeUUID(now)) |  | ||||||
|  |  | ||||||
|         assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.value.identifier) |         self.assertEqual(str(where), '"time" = MinTimeUUID(:5)') | ||||||
|  |         ctx = {} | ||||||
|  |         where.update_context(ctx) | ||||||
|  |         self.assertEqual(ctx, {'5': DateTime().to_database(now)}) | ||||||
|  |  | ||||||
|     def test_token_function(self): |     def test_token_function(self): | ||||||
|  |  | ||||||
| @@ -39,12 +45,16 @@ class TestQuerySetOperation(BaseCassEngTestCase): | |||||||
|         func = functions.Token('a', 'b') |         func = functions.Token('a', 'b') | ||||||
|  |  | ||||||
|         q = TestModel.objects.filter(pk__token__gt=func) |         q = TestModel.objects.filter(pk__token__gt=func) | ||||||
|         self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) |         where = q._where[0] | ||||||
|  |         where.set_context_id(1) | ||||||
|  |         self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2)) | ||||||
|  |  | ||||||
|         # Token(tuple()) is also possible for convinience |         # Token(tuple()) is also possible for convenience | ||||||
|         # it (allows for Token(obj.pk) syntax) |         # it (allows for Token(obj.pk) syntax) | ||||||
|         func = functions.Token(('a', 'b')) |         func = functions.Token(('a', 'b')) | ||||||
|  |  | ||||||
|         q = TestModel.objects.filter(pk__token__gt=func) |         q = TestModel.objects.filter(pk__token__gt=func) | ||||||
|         self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) |         where = q._where[0] | ||||||
|  |         where.set_context_id(1) | ||||||
|  |         self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2)) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Blake Eggleston
					Blake Eggleston