hacking in the query functions, also getting all tests to pass
This commit is contained in:
@@ -865,8 +865,15 @@ class _PartitionKeysToken(Column):
|
|||||||
self.partition_columns = model._partition_keys.values()
|
self.partition_columns = model._partition_keys.values()
|
||||||
super(_PartitionKeysToken, self).__init__(partition_key=True)
|
super(_PartitionKeysToken, self).__init__(partition_key=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_field_name(self):
|
||||||
|
return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns]))
|
||||||
|
|
||||||
def to_database(self, value):
|
def to_database(self, value):
|
||||||
raise NotImplementedError
|
from cqlengine.functions import Token
|
||||||
|
assert isinstance(value, Token)
|
||||||
|
value.set_columns(self.partition_columns)
|
||||||
|
return value
|
||||||
|
|
||||||
def get_cql(self):
|
def get_cql(self):
|
||||||
return "token({})".format(", ".join(c.cql for c in self.partition_columns))
|
return "token({})".format(", ".join(c.cql for c in self.partition_columns))
|
||||||
|
|||||||
@@ -9,24 +9,24 @@ class QueryValue(object):
|
|||||||
be passed into .filter() keyword args
|
be passed into .filter() keyword args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_cql_string = ':{}'
|
format_string = ':{}'
|
||||||
|
|
||||||
def __init__(self, value, identifier=None):
|
def __init__(self, value):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.identifier = uuid1().hex if identifier is None else identifier
|
self.context_id = None
|
||||||
|
|
||||||
def get_cql(self):
|
def __unicode__(self):
|
||||||
return self._cql_string.format(self.identifier)
|
return self.format_string.format(self.context_id)
|
||||||
|
|
||||||
def get_value(self):
|
def set_context_id(self, ctx_id):
|
||||||
return self.value
|
self.context_id = ctx_id
|
||||||
|
|
||||||
def get_dict(self, column):
|
def get_context_size(self):
|
||||||
return {self.identifier: column.to_database(self.get_value())}
|
return 1
|
||||||
|
|
||||||
|
def update_context(self, ctx):
|
||||||
|
ctx[str(self.context_id)] = self.value
|
||||||
|
|
||||||
@property
|
|
||||||
def cql(self):
|
|
||||||
return self.get_cql()
|
|
||||||
|
|
||||||
class BaseQueryFunction(QueryValue):
|
class BaseQueryFunction(QueryValue):
|
||||||
"""
|
"""
|
||||||
@@ -42,7 +42,7 @@ class MinTimeUUID(BaseQueryFunction):
|
|||||||
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
|
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_cql_string = 'MinTimeUUID(:{})'
|
format_string = 'MinTimeUUID(:{})'
|
||||||
|
|
||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
"""
|
"""
|
||||||
@@ -53,14 +53,11 @@ class MinTimeUUID(BaseQueryFunction):
|
|||||||
raise ValidationError('datetime instance is required')
|
raise ValidationError('datetime instance is required')
|
||||||
super(MinTimeUUID, self).__init__(value)
|
super(MinTimeUUID, self).__init__(value)
|
||||||
|
|
||||||
def get_value(self):
|
def update_context(self, ctx):
|
||||||
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
|
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
|
||||||
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
|
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
|
||||||
|
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
|
||||||
|
|
||||||
return long(((self.value - epoch).total_seconds() - offset) * 1000)
|
|
||||||
|
|
||||||
def get_dict(self, column):
|
|
||||||
return {self.identifier: self.get_value()}
|
|
||||||
|
|
||||||
class MaxTimeUUID(BaseQueryFunction):
|
class MaxTimeUUID(BaseQueryFunction):
|
||||||
"""
|
"""
|
||||||
@@ -69,7 +66,7 @@ class MaxTimeUUID(BaseQueryFunction):
|
|||||||
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
|
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_cql_string = 'MaxTimeUUID(:{})'
|
format_string = 'MaxTimeUUID(:{})'
|
||||||
|
|
||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
"""
|
"""
|
||||||
@@ -80,14 +77,11 @@ class MaxTimeUUID(BaseQueryFunction):
|
|||||||
raise ValidationError('datetime instance is required')
|
raise ValidationError('datetime instance is required')
|
||||||
super(MaxTimeUUID, self).__init__(value)
|
super(MaxTimeUUID, self).__init__(value)
|
||||||
|
|
||||||
def get_value(self):
|
def update_context(self, ctx):
|
||||||
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
|
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
|
||||||
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
|
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
|
||||||
|
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
|
||||||
|
|
||||||
return long(((self.value - epoch).total_seconds() - offset) * 1000)
|
|
||||||
|
|
||||||
def get_dict(self, column):
|
|
||||||
return {self.identifier: self.get_value()}
|
|
||||||
|
|
||||||
class Token(BaseQueryFunction):
|
class Token(BaseQueryFunction):
|
||||||
"""
|
"""
|
||||||
@@ -99,15 +93,20 @@ class Token(BaseQueryFunction):
|
|||||||
def __init__(self, *values):
|
def __init__(self, *values):
|
||||||
if len(values) == 1 and isinstance(values[0], (list, tuple)):
|
if len(values) == 1 and isinstance(values[0], (list, tuple)):
|
||||||
values = values[0]
|
values = values[0]
|
||||||
super(Token, self).__init__(values, [uuid1().hex for i in values])
|
super(Token, self).__init__(values)
|
||||||
|
self._columns = None
|
||||||
|
|
||||||
def get_dict(self, column):
|
def set_columns(self, columns):
|
||||||
items = zip(self.identifier, self.value, column.partition_columns)
|
self._columns = columns
|
||||||
return dict(
|
|
||||||
(id, col.to_database(val)) for id, val, col in items
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_cql(self):
|
def get_context_size(self):
|
||||||
token_args = ', '.join(':{}'.format(id) for id in self.identifier)
|
return len(self.value)
|
||||||
|
|
||||||
|
def __unicode__(self):
|
||||||
|
token_args = ', '.join(':{}'.format(self.context_id + i) for i in range(self.get_context_size()))
|
||||||
return "token({})".format(token_args)
|
return "token({})".format(token_args)
|
||||||
|
|
||||||
|
def update_context(self, ctx):
|
||||||
|
for i, (col, val) in enumerate(zip(self._columns, self.value)):
|
||||||
|
ctx[str(self.context_id + i)] = col.to_database(val)
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,10 @@ class NamedColumn(AbstractQueryableColumn):
|
|||||||
""" :rtype: NamedColumn """
|
""" :rtype: NamedColumn """
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_field_name(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cql(self):
|
def cql(self):
|
||||||
return self.get_cql()
|
return self.get_cql()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from cqlengine.columns import Counter
|
|||||||
from cqlengine.connection import connection_manager, execute, RowResult
|
from cqlengine.connection import connection_manager, execute, RowResult
|
||||||
|
|
||||||
from cqlengine.exceptions import CQLEngineException, ValidationError
|
from cqlengine.exceptions import CQLEngineException, ValidationError
|
||||||
from cqlengine.functions import QueryValue, Token
|
from cqlengine.functions import QueryValue, Token, BaseQueryFunction
|
||||||
|
|
||||||
#CQL 3 reference:
|
#CQL 3 reference:
|
||||||
#http://www.datastax.com/docs/1.1/references/cql/index
|
#http://www.datastax.com/docs/1.1/references/cql/index
|
||||||
@@ -480,12 +480,14 @@ class AbstractQuerySet(object):
|
|||||||
|
|
||||||
for arg, val in kwargs.items():
|
for arg, val in kwargs.items():
|
||||||
col_name, col_op = self._parse_filter_arg(arg)
|
col_name, col_op = self._parse_filter_arg(arg)
|
||||||
|
quote_field = True
|
||||||
#resolve column and operator
|
#resolve column and operator
|
||||||
try:
|
try:
|
||||||
column = self.model._get_column(col_name)
|
column = self.model._get_column(col_name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if col_name == 'pk__token':
|
if col_name == 'pk__token':
|
||||||
column = columns._PartitionKeysToken(self.model)
|
column = columns._PartitionKeysToken(self.model)
|
||||||
|
quote_field = False
|
||||||
else:
|
else:
|
||||||
raise QueryException("Can't resolve column name: '{}'".format(col_name))
|
raise QueryException("Can't resolve column name: '{}'".format(col_name))
|
||||||
|
|
||||||
@@ -497,10 +499,12 @@ class AbstractQuerySet(object):
|
|||||||
if not isinstance(val, (list, tuple)):
|
if not isinstance(val, (list, tuple)):
|
||||||
raise QueryException('IN queries must use a list/tuple value')
|
raise QueryException('IN queries must use a list/tuple value')
|
||||||
query_val = [column.to_database(v) for v in val]
|
query_val = [column.to_database(v) for v in val]
|
||||||
|
elif isinstance(val, BaseQueryFunction):
|
||||||
|
query_val = val
|
||||||
else:
|
else:
|
||||||
query_val = column.to_database(val)
|
query_val = column.to_database(val)
|
||||||
|
|
||||||
clone._where.append(WhereClause(col_name, operator, query_val))
|
clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field))
|
||||||
|
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from cqlengine.functions import QueryValue
|
||||||
from cqlengine.operators import BaseWhereOperator, InOperator
|
from cqlengine.operators import BaseWhereOperator, InOperator
|
||||||
|
|
||||||
|
|
||||||
@@ -78,16 +79,27 @@ class BaseClause(object):
|
|||||||
class WhereClause(BaseClause):
|
class WhereClause(BaseClause):
|
||||||
""" a single where statement used in queries """
|
""" a single where statement used in queries """
|
||||||
|
|
||||||
def __init__(self, field, operator, value):
|
def __init__(self, field, operator, value, quote_field=True):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param field:
|
||||||
|
:param operator:
|
||||||
|
:param value:
|
||||||
|
:param quote_field: hack to get the token function rendering properly
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if not isinstance(operator, BaseWhereOperator):
|
if not isinstance(operator, BaseWhereOperator):
|
||||||
raise StatementException(
|
raise StatementException(
|
||||||
"operator must be of type {}, got {}".format(BaseWhereOperator, type(operator))
|
"operator must be of type {}, got {}".format(BaseWhereOperator, type(operator))
|
||||||
)
|
)
|
||||||
super(WhereClause, self).__init__(field, value)
|
super(WhereClause, self).__init__(field, value)
|
||||||
self.operator = operator
|
self.operator = operator
|
||||||
|
self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
|
||||||
|
self.quote_field = quote_field
|
||||||
|
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id)
|
field = ('"{}"' if self.quote_field else '{}').format(self.field)
|
||||||
|
return u'{} {} {}'.format(field, self.operator, unicode(self.query_value))
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return super(WhereClause, self).__hash__() ^ hash(self.operator)
|
return super(WhereClause, self).__hash__() ^ hash(self.operator)
|
||||||
@@ -97,11 +109,18 @@ class WhereClause(BaseClause):
|
|||||||
return self.operator.__class__ == other.operator.__class__
|
return self.operator.__class__ == other.operator.__class__
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_context_size(self):
|
||||||
|
return self.query_value.get_context_size()
|
||||||
|
|
||||||
|
def set_context_id(self, i):
|
||||||
|
super(WhereClause, self).set_context_id(i)
|
||||||
|
self.query_value.set_context_id(i)
|
||||||
|
|
||||||
def update_context(self, ctx):
|
def update_context(self, ctx):
|
||||||
if isinstance(self.operator, InOperator):
|
if isinstance(self.operator, InOperator):
|
||||||
ctx[str(self.context_id)] = InQuoter(self.value)
|
ctx[str(self.context_id)] = InQuoter(self.value)
|
||||||
else:
|
else:
|
||||||
super(WhereClause, self).update_context(ctx)
|
self.query_value.update_context(ctx)
|
||||||
|
|
||||||
|
|
||||||
class AssignmentClause(BaseClause):
|
class AssignmentClause(BaseClause):
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import time
|
from cqlengine.columns import DateTime
|
||||||
|
|
||||||
from cqlengine.tests.base import BaseCassEngTestCase
|
from cqlengine.tests.base import BaseCassEngTestCase
|
||||||
from cqlengine import columns, Model
|
from cqlengine import columns, Model
|
||||||
from cqlengine import functions
|
from cqlengine import functions
|
||||||
from cqlengine import query
|
from cqlengine import query
|
||||||
|
from cqlengine.statements import WhereClause
|
||||||
|
from cqlengine.operators import EqualsOperator
|
||||||
|
|
||||||
class TestQuerySetOperation(BaseCassEngTestCase):
|
class TestQuerySetOperation(BaseCassEngTestCase):
|
||||||
|
|
||||||
@@ -13,22 +15,26 @@ class TestQuerySetOperation(BaseCassEngTestCase):
|
|||||||
Tests that queries with helper functions are generated properly
|
Tests that queries with helper functions are generated properly
|
||||||
"""
|
"""
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
col = columns.DateTime()
|
where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now))
|
||||||
col.set_column_name('time')
|
where.set_context_id(5)
|
||||||
qry = query.EqualsOperator(col, functions.MaxTimeUUID(now))
|
|
||||||
|
|
||||||
assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.value.identifier)
|
self.assertEqual(str(where), '"time" = MaxTimeUUID(:5)')
|
||||||
|
ctx = {}
|
||||||
|
where.update_context(ctx)
|
||||||
|
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
|
||||||
|
|
||||||
def test_mintimeuuid_function(self):
|
def test_mintimeuuid_function(self):
|
||||||
"""
|
"""
|
||||||
Tests that queries with helper functions are generated properly
|
Tests that queries with helper functions are generated properly
|
||||||
"""
|
"""
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
col = columns.DateTime()
|
where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now))
|
||||||
col.set_column_name('time')
|
where.set_context_id(5)
|
||||||
qry = query.EqualsOperator(col, functions.MinTimeUUID(now))
|
|
||||||
|
|
||||||
assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.value.identifier)
|
self.assertEqual(str(where), '"time" = MinTimeUUID(:5)')
|
||||||
|
ctx = {}
|
||||||
|
where.update_context(ctx)
|
||||||
|
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
|
||||||
|
|
||||||
def test_token_function(self):
|
def test_token_function(self):
|
||||||
|
|
||||||
@@ -39,12 +45,16 @@ class TestQuerySetOperation(BaseCassEngTestCase):
|
|||||||
func = functions.Token('a', 'b')
|
func = functions.Token('a', 'b')
|
||||||
|
|
||||||
q = TestModel.objects.filter(pk__token__gt=func)
|
q = TestModel.objects.filter(pk__token__gt=func)
|
||||||
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier))
|
where = q._where[0]
|
||||||
|
where.set_context_id(1)
|
||||||
|
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))
|
||||||
|
|
||||||
# Token(tuple()) is also possible for convinience
|
# Token(tuple()) is also possible for convenience
|
||||||
# it (allows for Token(obj.pk) syntax)
|
# it (allows for Token(obj.pk) syntax)
|
||||||
func = functions.Token(('a', 'b'))
|
func = functions.Token(('a', 'b'))
|
||||||
|
|
||||||
q = TestModel.objects.filter(pk__token__gt=func)
|
q = TestModel.objects.filter(pk__token__gt=func)
|
||||||
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier))
|
where = q._where[0]
|
||||||
|
where.set_context_id(1)
|
||||||
|
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user