hacking in the query functions, also getting all tests to pass
This commit is contained in:
@@ -865,8 +865,15 @@ class _PartitionKeysToken(Column):
|
||||
self.partition_columns = model._partition_keys.values()
|
||||
super(_PartitionKeysToken, self).__init__(partition_key=True)
|
||||
|
||||
@property
|
||||
def db_field_name(self):
|
||||
return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns]))
|
||||
|
||||
def to_database(self, value):
|
||||
raise NotImplementedError
|
||||
from cqlengine.functions import Token
|
||||
assert isinstance(value, Token)
|
||||
value.set_columns(self.partition_columns)
|
||||
return value
|
||||
|
||||
def get_cql(self):
|
||||
return "token({})".format(", ".join(c.cql for c in self.partition_columns))
|
||||
|
||||
@@ -9,24 +9,24 @@ class QueryValue(object):
|
||||
be passed into .filter() keyword args
|
||||
"""
|
||||
|
||||
_cql_string = ':{}'
|
||||
format_string = ':{}'
|
||||
|
||||
def __init__(self, value, identifier=None):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
self.identifier = uuid1().hex if identifier is None else identifier
|
||||
self.context_id = None
|
||||
|
||||
def get_cql(self):
|
||||
return self._cql_string.format(self.identifier)
|
||||
def __unicode__(self):
|
||||
return self.format_string.format(self.context_id)
|
||||
|
||||
def get_value(self):
|
||||
return self.value
|
||||
def set_context_id(self, ctx_id):
|
||||
self.context_id = ctx_id
|
||||
|
||||
def get_dict(self, column):
|
||||
return {self.identifier: column.to_database(self.get_value())}
|
||||
def get_context_size(self):
|
||||
return 1
|
||||
|
||||
def update_context(self, ctx):
|
||||
ctx[str(self.context_id)] = self.value
|
||||
|
||||
@property
|
||||
def cql(self):
|
||||
return self.get_cql()
|
||||
|
||||
class BaseQueryFunction(QueryValue):
|
||||
"""
|
||||
@@ -42,7 +42,7 @@ class MinTimeUUID(BaseQueryFunction):
|
||||
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
|
||||
"""
|
||||
|
||||
_cql_string = 'MinTimeUUID(:{})'
|
||||
format_string = 'MinTimeUUID(:{})'
|
||||
|
||||
def __init__(self, value):
|
||||
"""
|
||||
@@ -53,14 +53,11 @@ class MinTimeUUID(BaseQueryFunction):
|
||||
raise ValidationError('datetime instance is required')
|
||||
super(MinTimeUUID, self).__init__(value)
|
||||
|
||||
def get_value(self):
|
||||
def update_context(self, ctx):
|
||||
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
|
||||
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
|
||||
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
|
||||
|
||||
return long(((self.value - epoch).total_seconds() - offset) * 1000)
|
||||
|
||||
def get_dict(self, column):
|
||||
return {self.identifier: self.get_value()}
|
||||
|
||||
class MaxTimeUUID(BaseQueryFunction):
|
||||
"""
|
||||
@@ -69,7 +66,7 @@ class MaxTimeUUID(BaseQueryFunction):
|
||||
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
|
||||
"""
|
||||
|
||||
_cql_string = 'MaxTimeUUID(:{})'
|
||||
format_string = 'MaxTimeUUID(:{})'
|
||||
|
||||
def __init__(self, value):
|
||||
"""
|
||||
@@ -80,14 +77,11 @@ class MaxTimeUUID(BaseQueryFunction):
|
||||
raise ValidationError('datetime instance is required')
|
||||
super(MaxTimeUUID, self).__init__(value)
|
||||
|
||||
def get_value(self):
|
||||
def update_context(self, ctx):
|
||||
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
|
||||
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
|
||||
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
|
||||
|
||||
return long(((self.value - epoch).total_seconds() - offset) * 1000)
|
||||
|
||||
def get_dict(self, column):
|
||||
return {self.identifier: self.get_value()}
|
||||
|
||||
class Token(BaseQueryFunction):
|
||||
"""
|
||||
@@ -99,15 +93,20 @@ class Token(BaseQueryFunction):
|
||||
def __init__(self, *values):
|
||||
if len(values) == 1 and isinstance(values[0], (list, tuple)):
|
||||
values = values[0]
|
||||
super(Token, self).__init__(values, [uuid1().hex for i in values])
|
||||
super(Token, self).__init__(values)
|
||||
self._columns = None
|
||||
|
||||
def get_dict(self, column):
|
||||
items = zip(self.identifier, self.value, column.partition_columns)
|
||||
return dict(
|
||||
(id, col.to_database(val)) for id, val, col in items
|
||||
)
|
||||
def set_columns(self, columns):
|
||||
self._columns = columns
|
||||
|
||||
def get_cql(self):
|
||||
token_args = ', '.join(':{}'.format(id) for id in self.identifier)
|
||||
def get_context_size(self):
|
||||
return len(self.value)
|
||||
|
||||
def __unicode__(self):
|
||||
token_args = ', '.join(':{}'.format(self.context_id + i) for i in range(self.get_context_size()))
|
||||
return "token({})".format(token_args)
|
||||
|
||||
def update_context(self, ctx):
|
||||
for i, (col, val) in enumerate(zip(self._columns, self.value)):
|
||||
ctx[str(self.context_id + i)] = col.to_database(val)
|
||||
|
||||
|
||||
@@ -40,6 +40,10 @@ class NamedColumn(AbstractQueryableColumn):
|
||||
""" :rtype: NamedColumn """
|
||||
return self
|
||||
|
||||
@property
|
||||
def db_field_name(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def cql(self):
|
||||
return self.get_cql()
|
||||
|
||||
@@ -10,7 +10,7 @@ from cqlengine.columns import Counter
|
||||
from cqlengine.connection import connection_manager, execute, RowResult
|
||||
|
||||
from cqlengine.exceptions import CQLEngineException, ValidationError
|
||||
from cqlengine.functions import QueryValue, Token
|
||||
from cqlengine.functions import QueryValue, Token, BaseQueryFunction
|
||||
|
||||
#CQL 3 reference:
|
||||
#http://www.datastax.com/docs/1.1/references/cql/index
|
||||
@@ -480,12 +480,14 @@ class AbstractQuerySet(object):
|
||||
|
||||
for arg, val in kwargs.items():
|
||||
col_name, col_op = self._parse_filter_arg(arg)
|
||||
quote_field = True
|
||||
#resolve column and operator
|
||||
try:
|
||||
column = self.model._get_column(col_name)
|
||||
except KeyError:
|
||||
if col_name == 'pk__token':
|
||||
column = columns._PartitionKeysToken(self.model)
|
||||
quote_field = False
|
||||
else:
|
||||
raise QueryException("Can't resolve column name: '{}'".format(col_name))
|
||||
|
||||
@@ -497,10 +499,12 @@ class AbstractQuerySet(object):
|
||||
if not isinstance(val, (list, tuple)):
|
||||
raise QueryException('IN queries must use a list/tuple value')
|
||||
query_val = [column.to_database(v) for v in val]
|
||||
elif isinstance(val, BaseQueryFunction):
|
||||
query_val = val
|
||||
else:
|
||||
query_val = column.to_database(val)
|
||||
|
||||
clone._where.append(WhereClause(col_name, operator, query_val))
|
||||
clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field))
|
||||
|
||||
return clone
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from cqlengine.functions import QueryValue
|
||||
from cqlengine.operators import BaseWhereOperator, InOperator
|
||||
|
||||
|
||||
@@ -78,16 +79,27 @@ class BaseClause(object):
|
||||
class WhereClause(BaseClause):
|
||||
""" a single where statement used in queries """
|
||||
|
||||
def __init__(self, field, operator, value):
|
||||
def __init__(self, field, operator, value, quote_field=True):
|
||||
"""
|
||||
|
||||
:param field:
|
||||
:param operator:
|
||||
:param value:
|
||||
:param quote_field: hack to get the token function rendering properly
|
||||
:return:
|
||||
"""
|
||||
if not isinstance(operator, BaseWhereOperator):
|
||||
raise StatementException(
|
||||
"operator must be of type {}, got {}".format(BaseWhereOperator, type(operator))
|
||||
)
|
||||
super(WhereClause, self).__init__(field, value)
|
||||
self.operator = operator
|
||||
self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
|
||||
self.quote_field = quote_field
|
||||
|
||||
def __unicode__(self):
|
||||
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id)
|
||||
field = ('"{}"' if self.quote_field else '{}').format(self.field)
|
||||
return u'{} {} {}'.format(field, self.operator, unicode(self.query_value))
|
||||
|
||||
def __hash__(self):
|
||||
return super(WhereClause, self).__hash__() ^ hash(self.operator)
|
||||
@@ -97,11 +109,18 @@ class WhereClause(BaseClause):
|
||||
return self.operator.__class__ == other.operator.__class__
|
||||
return False
|
||||
|
||||
def get_context_size(self):
|
||||
return self.query_value.get_context_size()
|
||||
|
||||
def set_context_id(self, i):
|
||||
super(WhereClause, self).set_context_id(i)
|
||||
self.query_value.set_context_id(i)
|
||||
|
||||
def update_context(self, ctx):
|
||||
if isinstance(self.operator, InOperator):
|
||||
ctx[str(self.context_id)] = InQuoter(self.value)
|
||||
else:
|
||||
super(WhereClause, self).update_context(ctx)
|
||||
self.query_value.update_context(ctx)
|
||||
|
||||
|
||||
class AssignmentClause(BaseClause):
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from datetime import datetime
|
||||
import time
|
||||
from cqlengine.columns import DateTime
|
||||
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
from cqlengine import columns, Model
|
||||
from cqlengine import functions
|
||||
from cqlengine import query
|
||||
from cqlengine.statements import WhereClause
|
||||
from cqlengine.operators import EqualsOperator
|
||||
|
||||
class TestQuerySetOperation(BaseCassEngTestCase):
|
||||
|
||||
@@ -13,22 +15,26 @@ class TestQuerySetOperation(BaseCassEngTestCase):
|
||||
Tests that queries with helper functions are generated properly
|
||||
"""
|
||||
now = datetime.now()
|
||||
col = columns.DateTime()
|
||||
col.set_column_name('time')
|
||||
qry = query.EqualsOperator(col, functions.MaxTimeUUID(now))
|
||||
where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now))
|
||||
where.set_context_id(5)
|
||||
|
||||
assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.value.identifier)
|
||||
self.assertEqual(str(where), '"time" = MaxTimeUUID(:5)')
|
||||
ctx = {}
|
||||
where.update_context(ctx)
|
||||
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
|
||||
|
||||
def test_mintimeuuid_function(self):
|
||||
"""
|
||||
Tests that queries with helper functions are generated properly
|
||||
"""
|
||||
now = datetime.now()
|
||||
col = columns.DateTime()
|
||||
col.set_column_name('time')
|
||||
qry = query.EqualsOperator(col, functions.MinTimeUUID(now))
|
||||
where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now))
|
||||
where.set_context_id(5)
|
||||
|
||||
assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.value.identifier)
|
||||
self.assertEqual(str(where), '"time" = MinTimeUUID(:5)')
|
||||
ctx = {}
|
||||
where.update_context(ctx)
|
||||
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
|
||||
|
||||
def test_token_function(self):
|
||||
|
||||
@@ -39,12 +45,16 @@ class TestQuerySetOperation(BaseCassEngTestCase):
|
||||
func = functions.Token('a', 'b')
|
||||
|
||||
q = TestModel.objects.filter(pk__token__gt=func)
|
||||
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier))
|
||||
where = q._where[0]
|
||||
where.set_context_id(1)
|
||||
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))
|
||||
|
||||
# Token(tuple()) is also possible for convinience
|
||||
# Token(tuple()) is also possible for convenience
|
||||
# it (allows for Token(obj.pk) syntax)
|
||||
func = functions.Token(('a', 'b'))
|
||||
|
||||
q = TestModel.objects.filter(pk__token__gt=func)
|
||||
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier))
|
||||
where = q._where[0]
|
||||
where.set_context_id(1)
|
||||
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user