hacking in the query functions, also getting all tests to pass

This commit is contained in:
Blake Eggleston
2013-11-03 09:53:51 -08:00
parent 3a5a09fe5a
commit f426e0138e
6 changed files with 93 additions and 50 deletions

View File

@@ -865,8 +865,15 @@ class _PartitionKeysToken(Column):
self.partition_columns = model._partition_keys.values()
super(_PartitionKeysToken, self).__init__(partition_key=True)
@property
def db_field_name(self):
return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns]))
def to_database(self, value):
raise NotImplementedError
from cqlengine.functions import Token
assert isinstance(value, Token)
value.set_columns(self.partition_columns)
return value
def get_cql(self):
return "token({})".format(", ".join(c.cql for c in self.partition_columns))

View File

@@ -9,24 +9,24 @@ class QueryValue(object):
be passed into .filter() keyword args
"""
_cql_string = ':{}'
format_string = ':{}'
def __init__(self, value, identifier=None):
def __init__(self, value):
self.value = value
self.identifier = uuid1().hex if identifier is None else identifier
self.context_id = None
def get_cql(self):
return self._cql_string.format(self.identifier)
def __unicode__(self):
return self.format_string.format(self.context_id)
def get_value(self):
return self.value
def set_context_id(self, ctx_id):
self.context_id = ctx_id
def get_dict(self, column):
return {self.identifier: column.to_database(self.get_value())}
def get_context_size(self):
return 1
def update_context(self, ctx):
ctx[str(self.context_id)] = self.value
@property
def cql(self):
return self.get_cql()
class BaseQueryFunction(QueryValue):
"""
@@ -42,7 +42,7 @@ class MinTimeUUID(BaseQueryFunction):
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
"""
_cql_string = 'MinTimeUUID(:{})'
format_string = 'MinTimeUUID(:{})'
def __init__(self, value):
"""
@@ -53,14 +53,11 @@ class MinTimeUUID(BaseQueryFunction):
raise ValidationError('datetime instance is required')
super(MinTimeUUID, self).__init__(value)
def get_value(self):
def update_context(self, ctx):
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
return long(((self.value - epoch).total_seconds() - offset) * 1000)
def get_dict(self, column):
return {self.identifier: self.get_value()}
class MaxTimeUUID(BaseQueryFunction):
"""
@@ -69,7 +66,7 @@ class MaxTimeUUID(BaseQueryFunction):
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
"""
_cql_string = 'MaxTimeUUID(:{})'
format_string = 'MaxTimeUUID(:{})'
def __init__(self, value):
"""
@@ -80,14 +77,11 @@ class MaxTimeUUID(BaseQueryFunction):
raise ValidationError('datetime instance is required')
super(MaxTimeUUID, self).__init__(value)
def get_value(self):
def update_context(self, ctx):
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
return long(((self.value - epoch).total_seconds() - offset) * 1000)
def get_dict(self, column):
return {self.identifier: self.get_value()}
class Token(BaseQueryFunction):
"""
@@ -99,15 +93,20 @@ class Token(BaseQueryFunction):
def __init__(self, *values):
if len(values) == 1 and isinstance(values[0], (list, tuple)):
values = values[0]
super(Token, self).__init__(values, [uuid1().hex for i in values])
super(Token, self).__init__(values)
self._columns = None
def get_dict(self, column):
items = zip(self.identifier, self.value, column.partition_columns)
return dict(
(id, col.to_database(val)) for id, val, col in items
)
def set_columns(self, columns):
self._columns = columns
def get_cql(self):
token_args = ', '.join(':{}'.format(id) for id in self.identifier)
def get_context_size(self):
return len(self.value)
def __unicode__(self):
token_args = ', '.join(':{}'.format(self.context_id + i) for i in range(self.get_context_size()))
return "token({})".format(token_args)
def update_context(self, ctx):
for i, (col, val) in enumerate(zip(self._columns, self.value)):
ctx[str(self.context_id + i)] = col.to_database(val)

View File

@@ -40,6 +40,10 @@ class NamedColumn(AbstractQueryableColumn):
""" :rtype: NamedColumn """
return self
@property
def db_field_name(self):
return self.name
@property
def cql(self):
return self.get_cql()

View File

@@ -10,7 +10,7 @@ from cqlengine.columns import Counter
from cqlengine.connection import connection_manager, execute, RowResult
from cqlengine.exceptions import CQLEngineException, ValidationError
from cqlengine.functions import QueryValue, Token
from cqlengine.functions import QueryValue, Token, BaseQueryFunction
#CQL 3 reference:
#http://www.datastax.com/docs/1.1/references/cql/index
@@ -480,12 +480,14 @@ class AbstractQuerySet(object):
for arg, val in kwargs.items():
col_name, col_op = self._parse_filter_arg(arg)
quote_field = True
#resolve column and operator
try:
column = self.model._get_column(col_name)
except KeyError:
if col_name == 'pk__token':
column = columns._PartitionKeysToken(self.model)
quote_field = False
else:
raise QueryException("Can't resolve column name: '{}'".format(col_name))
@@ -497,10 +499,12 @@ class AbstractQuerySet(object):
if not isinstance(val, (list, tuple)):
raise QueryException('IN queries must use a list/tuple value')
query_val = [column.to_database(v) for v in val]
elif isinstance(val, BaseQueryFunction):
query_val = val
else:
query_val = column.to_database(val)
clone._where.append(WhereClause(col_name, operator, query_val))
clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field))
return clone

View File

@@ -1,3 +1,4 @@
from cqlengine.functions import QueryValue
from cqlengine.operators import BaseWhereOperator, InOperator
@@ -78,16 +79,27 @@ class BaseClause(object):
class WhereClause(BaseClause):
""" a single where statement used in queries """
def __init__(self, field, operator, value):
def __init__(self, field, operator, value, quote_field=True):
"""
:param field:
:param operator:
:param value:
:param quote_field: hack to get the token function rendering properly
:return:
"""
if not isinstance(operator, BaseWhereOperator):
raise StatementException(
"operator must be of type {}, got {}".format(BaseWhereOperator, type(operator))
)
super(WhereClause, self).__init__(field, value)
self.operator = operator
self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
self.quote_field = quote_field
def __unicode__(self):
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id)
field = ('"{}"' if self.quote_field else '{}').format(self.field)
return u'{} {} {}'.format(field, self.operator, unicode(self.query_value))
def __hash__(self):
return super(WhereClause, self).__hash__() ^ hash(self.operator)
@@ -97,11 +109,18 @@ class WhereClause(BaseClause):
return self.operator.__class__ == other.operator.__class__
return False
def get_context_size(self):
return self.query_value.get_context_size()
def set_context_id(self, i):
super(WhereClause, self).set_context_id(i)
self.query_value.set_context_id(i)
def update_context(self, ctx):
if isinstance(self.operator, InOperator):
ctx[str(self.context_id)] = InQuoter(self.value)
else:
super(WhereClause, self).update_context(ctx)
self.query_value.update_context(ctx)
class AssignmentClause(BaseClause):

View File

@@ -1,10 +1,12 @@
from datetime import datetime
import time
from cqlengine.columns import DateTime
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine import columns, Model
from cqlengine import functions
from cqlengine import query
from cqlengine.statements import WhereClause
from cqlengine.operators import EqualsOperator
class TestQuerySetOperation(BaseCassEngTestCase):
@@ -13,22 +15,26 @@ class TestQuerySetOperation(BaseCassEngTestCase):
Tests that queries with helper functions are generated properly
"""
now = datetime.now()
col = columns.DateTime()
col.set_column_name('time')
qry = query.EqualsOperator(col, functions.MaxTimeUUID(now))
where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now))
where.set_context_id(5)
assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.value.identifier)
self.assertEqual(str(where), '"time" = MaxTimeUUID(:5)')
ctx = {}
where.update_context(ctx)
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
def test_mintimeuuid_function(self):
"""
Tests that queries with helper functions are generated properly
"""
now = datetime.now()
col = columns.DateTime()
col.set_column_name('time')
qry = query.EqualsOperator(col, functions.MinTimeUUID(now))
where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now))
where.set_context_id(5)
assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.value.identifier)
self.assertEqual(str(where), '"time" = MinTimeUUID(:5)')
ctx = {}
where.update_context(ctx)
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
def test_token_function(self):
@@ -39,12 +45,16 @@ class TestQuerySetOperation(BaseCassEngTestCase):
func = functions.Token('a', 'b')
q = TestModel.objects.filter(pk__token__gt=func)
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier))
where = q._where[0]
where.set_context_id(1)
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))
# Token(tuple()) is also possible for convinience
# Token(tuple()) is also possible for convenience
# it (allows for Token(obj.pk) syntax)
func = functions.Token(('a', 'b'))
q = TestModel.objects.filter(pk__token__gt=func)
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier))
where = q._where[0]
where.set_context_id(1)
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))