hacking in the query functions, also getting all tests to pass

This commit is contained in:
Blake Eggleston
2013-11-03 09:53:51 -08:00
parent 3a5a09fe5a
commit f426e0138e
6 changed files with 93 additions and 50 deletions

View File

@@ -865,8 +865,15 @@ class _PartitionKeysToken(Column):
self.partition_columns = model._partition_keys.values() self.partition_columns = model._partition_keys.values()
super(_PartitionKeysToken, self).__init__(partition_key=True) super(_PartitionKeysToken, self).__init__(partition_key=True)
@property
def db_field_name(self):
return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns]))
def to_database(self, value): def to_database(self, value):
raise NotImplementedError from cqlengine.functions import Token
assert isinstance(value, Token)
value.set_columns(self.partition_columns)
return value
def get_cql(self): def get_cql(self):
return "token({})".format(", ".join(c.cql for c in self.partition_columns)) return "token({})".format(", ".join(c.cql for c in self.partition_columns))

View File

@@ -9,24 +9,24 @@ class QueryValue(object):
be passed into .filter() keyword args be passed into .filter() keyword args
""" """
_cql_string = ':{}' format_string = ':{}'
def __init__(self, value, identifier=None): def __init__(self, value):
self.value = value self.value = value
self.identifier = uuid1().hex if identifier is None else identifier self.context_id = None
def get_cql(self): def __unicode__(self):
return self._cql_string.format(self.identifier) return self.format_string.format(self.context_id)
def get_value(self): def set_context_id(self, ctx_id):
return self.value self.context_id = ctx_id
def get_dict(self, column): def get_context_size(self):
return {self.identifier: column.to_database(self.get_value())} return 1
def update_context(self, ctx):
ctx[str(self.context_id)] = self.value
@property
def cql(self):
return self.get_cql()
class BaseQueryFunction(QueryValue): class BaseQueryFunction(QueryValue):
""" """
@@ -42,7 +42,7 @@ class MinTimeUUID(BaseQueryFunction):
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
""" """
_cql_string = 'MinTimeUUID(:{})' format_string = 'MinTimeUUID(:{})'
def __init__(self, value): def __init__(self, value):
""" """
@@ -53,14 +53,11 @@ class MinTimeUUID(BaseQueryFunction):
raise ValidationError('datetime instance is required') raise ValidationError('datetime instance is required')
super(MinTimeUUID, self).__init__(value) super(MinTimeUUID, self).__init__(value)
def get_value(self): def update_context(self, ctx):
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
return long(((self.value - epoch).total_seconds() - offset) * 1000)
def get_dict(self, column):
return {self.identifier: self.get_value()}
class MaxTimeUUID(BaseQueryFunction): class MaxTimeUUID(BaseQueryFunction):
""" """
@@ -69,7 +66,7 @@ class MaxTimeUUID(BaseQueryFunction):
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
""" """
_cql_string = 'MaxTimeUUID(:{})' format_string = 'MaxTimeUUID(:{})'
def __init__(self, value): def __init__(self, value):
""" """
@@ -80,14 +77,11 @@ class MaxTimeUUID(BaseQueryFunction):
raise ValidationError('datetime instance is required') raise ValidationError('datetime instance is required')
super(MaxTimeUUID, self).__init__(value) super(MaxTimeUUID, self).__init__(value)
def get_value(self): def update_context(self, ctx):
epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo) epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
ctx[str(self.context_id)] = long(((self.value - epoch).total_seconds() - offset) * 1000)
return long(((self.value - epoch).total_seconds() - offset) * 1000)
def get_dict(self, column):
return {self.identifier: self.get_value()}
class Token(BaseQueryFunction): class Token(BaseQueryFunction):
""" """
@@ -99,15 +93,20 @@ class Token(BaseQueryFunction):
def __init__(self, *values): def __init__(self, *values):
if len(values) == 1 and isinstance(values[0], (list, tuple)): if len(values) == 1 and isinstance(values[0], (list, tuple)):
values = values[0] values = values[0]
super(Token, self).__init__(values, [uuid1().hex for i in values]) super(Token, self).__init__(values)
self._columns = None
def get_dict(self, column): def set_columns(self, columns):
items = zip(self.identifier, self.value, column.partition_columns) self._columns = columns
return dict(
(id, col.to_database(val)) for id, val, col in items
)
def get_cql(self): def get_context_size(self):
token_args = ', '.join(':{}'.format(id) for id in self.identifier) return len(self.value)
def __unicode__(self):
token_args = ', '.join(':{}'.format(self.context_id + i) for i in range(self.get_context_size()))
return "token({})".format(token_args) return "token({})".format(token_args)
def update_context(self, ctx):
for i, (col, val) in enumerate(zip(self._columns, self.value)):
ctx[str(self.context_id + i)] = col.to_database(val)

View File

@@ -40,6 +40,10 @@ class NamedColumn(AbstractQueryableColumn):
""" :rtype: NamedColumn """ """ :rtype: NamedColumn """
return self return self
@property
def db_field_name(self):
return self.name
@property @property
def cql(self): def cql(self):
return self.get_cql() return self.get_cql()

View File

@@ -10,7 +10,7 @@ from cqlengine.columns import Counter
from cqlengine.connection import connection_manager, execute, RowResult from cqlengine.connection import connection_manager, execute, RowResult
from cqlengine.exceptions import CQLEngineException, ValidationError from cqlengine.exceptions import CQLEngineException, ValidationError
from cqlengine.functions import QueryValue, Token from cqlengine.functions import QueryValue, Token, BaseQueryFunction
#CQL 3 reference: #CQL 3 reference:
#http://www.datastax.com/docs/1.1/references/cql/index #http://www.datastax.com/docs/1.1/references/cql/index
@@ -480,12 +480,14 @@ class AbstractQuerySet(object):
for arg, val in kwargs.items(): for arg, val in kwargs.items():
col_name, col_op = self._parse_filter_arg(arg) col_name, col_op = self._parse_filter_arg(arg)
quote_field = True
#resolve column and operator #resolve column and operator
try: try:
column = self.model._get_column(col_name) column = self.model._get_column(col_name)
except KeyError: except KeyError:
if col_name == 'pk__token': if col_name == 'pk__token':
column = columns._PartitionKeysToken(self.model) column = columns._PartitionKeysToken(self.model)
quote_field = False
else: else:
raise QueryException("Can't resolve column name: '{}'".format(col_name)) raise QueryException("Can't resolve column name: '{}'".format(col_name))
@@ -497,10 +499,12 @@ class AbstractQuerySet(object):
if not isinstance(val, (list, tuple)): if not isinstance(val, (list, tuple)):
raise QueryException('IN queries must use a list/tuple value') raise QueryException('IN queries must use a list/tuple value')
query_val = [column.to_database(v) for v in val] query_val = [column.to_database(v) for v in val]
elif isinstance(val, BaseQueryFunction):
query_val = val
else: else:
query_val = column.to_database(val) query_val = column.to_database(val)
clone._where.append(WhereClause(col_name, operator, query_val)) clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field))
return clone return clone

View File

@@ -1,3 +1,4 @@
from cqlengine.functions import QueryValue
from cqlengine.operators import BaseWhereOperator, InOperator from cqlengine.operators import BaseWhereOperator, InOperator
@@ -78,16 +79,27 @@ class BaseClause(object):
class WhereClause(BaseClause): class WhereClause(BaseClause):
""" a single where statement used in queries """ """ a single where statement used in queries """
def __init__(self, field, operator, value): def __init__(self, field, operator, value, quote_field=True):
"""
:param field:
:param operator:
:param value:
:param quote_field: hack to get the token function rendering properly
:return:
"""
if not isinstance(operator, BaseWhereOperator): if not isinstance(operator, BaseWhereOperator):
raise StatementException( raise StatementException(
"operator must be of type {}, got {}".format(BaseWhereOperator, type(operator)) "operator must be of type {}, got {}".format(BaseWhereOperator, type(operator))
) )
super(WhereClause, self).__init__(field, value) super(WhereClause, self).__init__(field, value)
self.operator = operator self.operator = operator
self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
self.quote_field = quote_field
def __unicode__(self): def __unicode__(self):
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id) field = ('"{}"' if self.quote_field else '{}').format(self.field)
return u'{} {} {}'.format(field, self.operator, unicode(self.query_value))
def __hash__(self): def __hash__(self):
return super(WhereClause, self).__hash__() ^ hash(self.operator) return super(WhereClause, self).__hash__() ^ hash(self.operator)
@@ -97,11 +109,18 @@ class WhereClause(BaseClause):
return self.operator.__class__ == other.operator.__class__ return self.operator.__class__ == other.operator.__class__
return False return False
def get_context_size(self):
return self.query_value.get_context_size()
def set_context_id(self, i):
super(WhereClause, self).set_context_id(i)
self.query_value.set_context_id(i)
def update_context(self, ctx): def update_context(self, ctx):
if isinstance(self.operator, InOperator): if isinstance(self.operator, InOperator):
ctx[str(self.context_id)] = InQuoter(self.value) ctx[str(self.context_id)] = InQuoter(self.value)
else: else:
super(WhereClause, self).update_context(ctx) self.query_value.update_context(ctx)
class AssignmentClause(BaseClause): class AssignmentClause(BaseClause):

View File

@@ -1,10 +1,12 @@
from datetime import datetime from datetime import datetime
import time from cqlengine.columns import DateTime
from cqlengine.tests.base import BaseCassEngTestCase from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine import columns, Model from cqlengine import columns, Model
from cqlengine import functions from cqlengine import functions
from cqlengine import query from cqlengine import query
from cqlengine.statements import WhereClause
from cqlengine.operators import EqualsOperator
class TestQuerySetOperation(BaseCassEngTestCase): class TestQuerySetOperation(BaseCassEngTestCase):
@@ -13,22 +15,26 @@ class TestQuerySetOperation(BaseCassEngTestCase):
Tests that queries with helper functions are generated properly Tests that queries with helper functions are generated properly
""" """
now = datetime.now() now = datetime.now()
col = columns.DateTime() where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now))
col.set_column_name('time') where.set_context_id(5)
qry = query.EqualsOperator(col, functions.MaxTimeUUID(now))
assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.value.identifier) self.assertEqual(str(where), '"time" = MaxTimeUUID(:5)')
ctx = {}
where.update_context(ctx)
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
def test_mintimeuuid_function(self): def test_mintimeuuid_function(self):
""" """
Tests that queries with helper functions are generated properly Tests that queries with helper functions are generated properly
""" """
now = datetime.now() now = datetime.now()
col = columns.DateTime() where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now))
col.set_column_name('time') where.set_context_id(5)
qry = query.EqualsOperator(col, functions.MinTimeUUID(now))
assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.value.identifier) self.assertEqual(str(where), '"time" = MinTimeUUID(:5)')
ctx = {}
where.update_context(ctx)
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
def test_token_function(self): def test_token_function(self):
@@ -39,12 +45,16 @@ class TestQuerySetOperation(BaseCassEngTestCase):
func = functions.Token('a', 'b') func = functions.Token('a', 'b')
q = TestModel.objects.filter(pk__token__gt=func) q = TestModel.objects.filter(pk__token__gt=func)
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) where = q._where[0]
where.set_context_id(1)
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))
# Token(tuple()) is also possible for convinience # Token(tuple()) is also possible for convenience
# it (allows for Token(obj.pk) syntax) # it (allows for Token(obj.pk) syntax)
func = functions.Token(('a', 'b')) func = functions.Token(('a', 'b'))
q = TestModel.objects.filter(pk__token__gt=func) q = TestModel.objects.filter(pk__token__gt=func)
self.assertEquals(q._where[0].cql, 'token("p1", "p2") > token(:{}, :{})'.format(*func.identifier)) where = q._where[0]
where.set_context_id(1)
self.assertEquals(str(where), 'token("p1", "p2") > token(:{}, :{})'.format(1, 2))