Merge pull request #68 from bdeggleston/query-builder

Adding support for query expressions
This commit is contained in:
Jon Haddad
2013-06-15 22:42:58 -07:00
4 changed files with 202 additions and 56 deletions

View File

@@ -29,9 +29,12 @@ class hybrid_classmethod(object):
return self.instmethod.__get__(instance, owner) return self.instmethod.__get__(instance, owner)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
""" Just a hint to IDEs that it's ok to call this """ """
Just a hint to IDEs that it's ok to call this
"""
raise NotImplementedError raise NotImplementedError
class QuerySetDescriptor(object): class QuerySetDescriptor(object):
""" """
returns a fresh queryset for the given model returns a fresh queryset for the given model
@@ -39,15 +42,29 @@ class QuerySetDescriptor(object):
""" """
def __get__(self, obj, model): def __get__(self, obj, model):
""" :rtype: QuerySet """
if model.__abstract__: if model.__abstract__:
raise CQLEngineException('cannot execute queries against abstract models') raise CQLEngineException('cannot execute queries against abstract models')
return QuerySet(model) return QuerySet(model)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
""" Just a hint to IDEs that it's ok to call this """ """
Just a hint to IDEs that it's ok to call this
:rtype: QuerySet
"""
raise NotImplementedError raise NotImplementedError
class ColumnDescriptor(AbstractColumnDescriptor):
class ColumnQueryEvaluator(AbstractColumnDescriptor):
def __init__(self, column):
self.column = column
def _get_column(self):
return self.column
class ColumnDescriptor(object):
""" """
Handles the reading and writing of column values to and from Handles the reading and writing of column values to and from
a model instance's value manager, as well as creating a model instance's value manager, as well as creating
@@ -61,6 +78,7 @@ class ColumnDescriptor(AbstractColumnDescriptor):
:return: :return:
""" """
self.column = column self.column = column
self.query_evaluator = ColumnQueryEvaluator(self.column)
def __get__(self, instance, owner): def __get__(self, instance, owner):
""" """
@@ -74,7 +92,7 @@ class ColumnDescriptor(AbstractColumnDescriptor):
if instance: if instance:
return instance._values[self.column.column_name].getval() return instance._values[self.column.column_name].getval()
else: else:
return self.column return self.query_evaluator
def __set__(self, instance, value): def __set__(self, instance, value):
""" """
@@ -203,12 +221,12 @@ class BaseModel(object):
return cls.objects.all() return cls.objects.all()
@classmethod @classmethod
def filter(cls, **kwargs): def filter(cls, *args, **kwargs):
return cls.objects.filter(**kwargs) return cls.objects.filter(*args, **kwargs)
@classmethod @classmethod
def get(cls, **kwargs): def get(cls, *args, **kwargs):
return cls.objects.get(**kwargs) return cls.objects.get(*args, **kwargs)
def save(self): def save(self):
is_new = self.pk is None is_new = self.pk is None

View File

@@ -163,12 +163,17 @@ class AbstractColumnDescriptor(object):
def _get_column(self): def _get_column(self):
raise NotImplementedError raise NotImplementedError
def in_(self, item):
"""
Returns an in operator
used in where you'd typically want to use python's `in` operator
"""
return InOperator(self._get_column(), item)
def __eq__(self, other): def __eq__(self, other):
return EqualsOperator(self._get_column(), other) return EqualsOperator(self._get_column(), other)
def __contains__(self, item):
return InOperator(self._get_column(), item)
def __gt__(self, other): def __gt__(self, other):
return GreaterThanOperator(self._get_column(), other) return GreaterThanOperator(self._get_column(), other)
@@ -195,7 +200,6 @@ class NamedColumnDescriptor(AbstractColumnDescriptor):
def to_database(self, val): def to_database(self, val):
return val return val
C = NamedColumnDescriptor
class TableDescriptor(object): class TableDescriptor(object):
""" describes a cql table """ """ describes a cql table """
@@ -204,7 +208,6 @@ class TableDescriptor(object):
self.keyspace = keyspace self.keyspace = keyspace
self.name = name self.name = name
T = TableDescriptor
class KeyspaceDescriptor(object): class KeyspaceDescriptor(object):
""" Describes a cql keyspace """ """ Describes a cql keyspace """
@@ -219,7 +222,6 @@ class KeyspaceDescriptor(object):
""" """
return TableDescriptor(self.name, name) return TableDescriptor(self.name, name)
K = KeyspaceDescriptor
class BatchType(object): class BatchType(object):
Unlogged = 'UNLOGGED' Unlogged = 'UNLOGGED'
@@ -316,8 +318,8 @@ class QuerySet(object):
def __str__(self): def __str__(self):
return str(self.__unicode__()) return str(self.__unicode__())
def __call__(self, **kwargs): def __call__(self, *args, **kwargs):
return self.filter(**kwargs) return self.filter(*args, **kwargs)
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
clone = self.__class__(self.model) clone = self.__class__(self.model)
@@ -529,9 +531,21 @@ class QuerySet(object):
else: else:
raise QueryException("Can't parse '{}'".format(arg)) raise QueryException("Can't parse '{}'".format(arg))
def filter(self, **kwargs): def filter(self, *args, **kwargs):
"""
Adds WHERE arguments to the queryset, returning a new queryset
#TODO: show examples
:rtype: QuerySet
"""
#add arguments to the where clause filters #add arguments to the where clause filters
clone = copy.deepcopy(self) clone = copy.deepcopy(self)
for operator in args:
if not isinstance(operator, QueryOperator):
raise QueryException('{} is not a valid query operator'.format(operator))
clone._where.append(operator)
for arg, val in kwargs.items(): for arg, val in kwargs.items():
col_name, col_op = self._parse_filter_arg(arg) col_name, col_op = self._parse_filter_arg(arg)
#resolve column and operator #resolve column and operator
@@ -551,20 +565,16 @@ class QuerySet(object):
return clone return clone
def query(self, *args): def get(self, *args, **kwargs):
"""
Same end result as filter, but uses the new comparator style args
ie: Model.column == val
"""
def get(self, **kwargs):
""" """
Returns a single instance matching this query, optionally with additional filter kwargs. Returns a single instance matching this query, optionally with additional filter kwargs.
A DoesNotExistError will be raised if there are no rows matching the query A DoesNotExistError will be raised if there are no rows matching the query
A MultipleObjectsFoundError will be raised if there is more than one row matching the queyr A MultipleObjectsFoundError will be raised if there is more than one row matching the queyr
""" """
if kwargs: return self.filter(**kwargs).get() if args or kwargs:
return self.filter(*args, **kwargs).get()
self._execute_query() self._execute_query()
if len(self._result_cache) == 0: if len(self._result_cache) == 0:
raise self.model.DoesNotExist raise self.model.DoesNotExist

View File

@@ -3,7 +3,7 @@ from cqlengine.query import QueryException
from cqlengine.tests.base import BaseCassEngTestCase from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.exceptions import ModelException, CQLEngineException from cqlengine.exceptions import ModelException, CQLEngineException
from cqlengine.models import Model, ModelDefinitionException from cqlengine.models import Model, ModelDefinitionException, ColumnQueryEvaluator
from cqlengine import columns from cqlengine import columns
import cqlengine import cqlengine
@@ -270,7 +270,7 @@ class TestAbstractModelClasses(BaseCassEngTestCase):
def test_abstract_columns_are_inherited(self): def test_abstract_columns_are_inherited(self):
""" Tests that columns defined in the abstract class are inherited into the concrete class """ """ Tests that columns defined in the abstract class are inherited into the concrete class """
assert hasattr(ConcreteModelWithCol, 'pkey') assert hasattr(ConcreteModelWithCol, 'pkey')
assert isinstance(ConcreteModelWithCol.pkey, columns.Column) assert isinstance(ConcreteModelWithCol.pkey, ColumnQueryEvaluator)
assert isinstance(ConcreteModelWithCol._columns['pkey'], columns.Column) assert isinstance(ConcreteModelWithCol._columns['pkey'], columns.Column)
def test_concrete_class_table_creation_cycle(self): def test_concrete_class_table_creation_cycle(self):

View File

@@ -52,14 +52,44 @@ class TestQuerySetOperation(BaseCassEngTestCase):
assert isinstance(op, query.GreaterThanOrEqualOperator) assert isinstance(op, query.GreaterThanOrEqualOperator)
assert op.value == 1 assert op.value == 1
def test_query_expression_parsing(self):
""" Tests that query experessions are evaluated properly """
query1 = TestModel.filter(TestModel.test_id == 5)
assert len(query1._where) == 1
op = query1._where[0]
assert isinstance(op, query.EqualsOperator)
assert op.value == 5
query2 = query1.filter(TestModel.expected_result >= 1)
assert len(query2._where) == 2
op = query2._where[1]
assert isinstance(op, query.GreaterThanOrEqualOperator)
assert op.value == 1
def test_using_invalid_column_names_in_filter_kwargs_raises_error(self): def test_using_invalid_column_names_in_filter_kwargs_raises_error(self):
""" """
Tests that using invalid or nonexistant column names for filter args raises an error Tests that using invalid or nonexistant column names for filter args raises an error
""" """
with self.assertRaises(query.QueryException): with self.assertRaises(query.QueryException):
query0 = TestModel.objects(nonsense=5) TestModel.objects(nonsense=5)
def test_where_clause_generation(self): def test_using_nonexistant_column_names_in_query_args_raises_error(self):
"""
Tests that using invalid or nonexistant columns for query args raises an error
"""
with self.assertRaises(AttributeError):
TestModel.objects(TestModel.nonsense == 5)
def test_using_non_query_operators_in_query_args_raises_error(self):
"""
Tests that providing query args that are not query operator instances raises an error
"""
with self.assertRaises(query.QueryException):
TestModel.objects(5)
def test_filter_method_where_clause_generation(self):
""" """
Tests the where clause creation Tests the where clause creation
""" """
@@ -73,6 +103,19 @@ class TestQuerySetOperation(BaseCassEngTestCase):
where = query2._where_clause() where = query2._where_clause()
assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids) assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids)
def test_query_expression_where_clause_generation(self):
"""
Tests the where clause creation
"""
query1 = TestModel.objects(TestModel.test_id == 5)
ids = [o.query_value.identifier for o in query1._where]
where = query1._where_clause()
assert where == '"test_id" = :{}'.format(*ids)
query2 = query1.filter(TestModel.expected_result >= 1)
ids = [o.query_value.identifier for o in query2._where]
where = query2._where_clause()
assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids)
def test_querystring_generation(self): def test_querystring_generation(self):
""" """
@@ -88,6 +131,7 @@ class TestQuerySetOperation(BaseCassEngTestCase):
query2 = query1.filter(expected_result__gte=1) query2 = query1.filter(expected_result__gte=1)
assert len(query2._where) == 2 assert len(query2._where) == 2
assert len(query1._where) == 1
def test_the_all_method_duplicates_queryset(self): def test_the_all_method_duplicates_queryset(self):
""" """
@@ -163,12 +207,21 @@ class BaseQuerySetUsage(BaseCassEngTestCase):
class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
def test_count(self): def test_count(self):
""" Tests that adding filtering statements affects the count query as expected """
assert TestModel.objects.count() == 12 assert TestModel.objects.count() == 12
q = TestModel.objects(test_id=0) q = TestModel.objects(test_id=0)
assert q.count() == 4 assert q.count() == 4
def test_query_expression_count(self):
""" Tests that adding query statements affects the count query as expected """
assert TestModel.objects.count() == 12
q = TestModel.objects(TestModel.test_id == 0)
assert q.count() == 4
def test_iteration(self): def test_iteration(self):
""" Tests that iterating over a query set pulls back all of the expected results """
q = TestModel.objects(test_id=0) q = TestModel.objects(test_id=0)
#tuple of expected attempt_id, expected_result values #tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)]) compare_set = set([(0,5), (1,10), (2,15), (3,20)])
@@ -178,6 +231,7 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
compare_set.remove(val) compare_set.remove(val)
assert len(compare_set) == 0 assert len(compare_set) == 0
# test with regular filtering
q = TestModel.objects(attempt_id=3).allow_filtering() q = TestModel.objects(attempt_id=3).allow_filtering()
assert len(q) == 3 assert len(q) == 3
#tuple of expected test_id, expected_result values #tuple of expected test_id, expected_result values
@@ -188,36 +242,49 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
compare_set.remove(val) compare_set.remove(val)
assert len(compare_set) == 0 assert len(compare_set) == 0
# test with query method
q = TestModel.objects(TestModel.attempt_id == 3).allow_filtering()
assert len(q) == 3
#tuple of expected test_id, expected_result values
compare_set = set([(0,20), (1,20), (2,75)])
for t in q:
val = t.test_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
def test_multiple_iterations_work_properly(self): def test_multiple_iterations_work_properly(self):
""" Tests that iterating over a query set more than once works """ """ Tests that iterating over a query set more than once works """
q = TestModel.objects(test_id=0) # test with both the filtering method and the query method
#tuple of expected attempt_id, expected_result values for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)):
compare_set = set([(0,5), (1,10), (2,15), (3,20)]) #tuple of expected attempt_id, expected_result values
for t in q: compare_set = set([(0,5), (1,10), (2,15), (3,20)])
val = t.attempt_id, t.expected_result for t in q:
assert val in compare_set val = t.attempt_id, t.expected_result
compare_set.remove(val) assert val in compare_set
assert len(compare_set) == 0 compare_set.remove(val)
assert len(compare_set) == 0
#try it again #try it again
compare_set = set([(0,5), (1,10), (2,15), (3,20)]) compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q: for t in q:
val = t.attempt_id, t.expected_result val = t.attempt_id, t.expected_result
assert val in compare_set assert val in compare_set
compare_set.remove(val) compare_set.remove(val)
assert len(compare_set) == 0 assert len(compare_set) == 0
def test_multiple_iterators_are_isolated(self): def test_multiple_iterators_are_isolated(self):
""" """
tests that the use of one iterator does not affect the behavior of another tests that the use of one iterator does not affect the behavior of another
""" """
q = TestModel.objects(test_id=0).order_by('attempt_id') for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)):
expected_order = [0,1,2,3] q = q.order_by('attempt_id')
iter1 = iter(q) expected_order = [0,1,2,3]
iter2 = iter(q) iter1 = iter(q)
for attempt_id in expected_order: iter2 = iter(q)
assert iter1.next().attempt_id == attempt_id for attempt_id in expected_order:
assert iter2.next().attempt_id == attempt_id assert iter1.next().attempt_id == attempt_id
assert iter2.next().attempt_id == attempt_id
def test_get_success_case(self): def test_get_success_case(self):
""" """
@@ -240,6 +307,27 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
assert m.test_id == 0 assert m.test_id == 0
assert m.attempt_id == 0 assert m.attempt_id == 0
def test_query_expression_get_success_case(self):
"""
Tests that the .get() method works on new and existing querysets
"""
m = TestModel.get(TestModel.test_id == 0, TestModel.attempt_id == 0)
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
q = TestModel.objects(TestModel.test_id == 0, TestModel.attempt_id == 0)
m = q.get()
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
q = TestModel.objects(TestModel.test_id == 0)
m = q.get(TestModel.attempt_id == 0)
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
def test_get_doesnotexist_exception(self): def test_get_doesnotexist_exception(self):
""" """
Tests that get calls that don't return a result raises a DoesNotExist error Tests that get calls that don't return a result raises a DoesNotExist error
@@ -273,10 +361,14 @@ class TestQuerySetOrdering(BaseQuerySetUsage):
assert model.attempt_id == expect assert model.attempt_id == expect
def test_ordering_by_non_second_primary_keys_fail(self): def test_ordering_by_non_second_primary_keys_fail(self):
# kwarg filtering
with self.assertRaises(query.QueryException): with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('test_id') q = TestModel.objects(test_id=0).order_by('test_id')
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(TestModel.test_id == 0).order_by('test_id')
def test_ordering_by_non_primary_keys_fails(self): def test_ordering_by_non_primary_keys_fails(self):
with self.assertRaises(query.QueryException): with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('description') q = TestModel.objects(test_id=0).order_by('description')
@@ -343,7 +435,7 @@ class TestQuerySetValidation(BaseQuerySetUsage):
""" """
with self.assertRaises(query.QueryException): with self.assertRaises(query.QueryException):
q = TestModel.objects(test_result=25) q = TestModel.objects(test_result=25)
[i for i in q] list([i for i in q])
def test_primary_key_or_index_must_have_equal_relation_filter(self): def test_primary_key_or_index_must_have_equal_relation_filter(self):
""" """
@@ -351,7 +443,7 @@ class TestQuerySetValidation(BaseQuerySetUsage):
""" """
with self.assertRaises(query.QueryException): with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id__gt=0) q = TestModel.objects(test_id__gt=0)
[i for i in q] list([i for i in q])
def test_indexed_field_can_be_queried(self): def test_indexed_field_can_be_queried(self):
@@ -359,7 +451,6 @@ class TestQuerySetValidation(BaseQuerySetUsage):
Tests that queries on an indexed field will work without any primary key relations specified Tests that queries on an indexed field will work without any primary key relations specified
""" """
q = IndexedTestModel.objects(test_result=25) q = IndexedTestModel.objects(test_result=25)
count = q.count()
assert q.count() == 4 assert q.count() == 4
class TestQuerySetDelete(BaseQuerySetUsage): class TestQuerySetDelete(BaseQuerySetUsage):
@@ -438,6 +529,7 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4') TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4')
time.sleep(0.2) time.sleep(0.2)
# test kwarg filtering
q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
q = [d for d in q] q = [d for d in q]
assert len(q) == 2 assert len(q) == 2
@@ -451,13 +543,39 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
assert '3' in datas assert '3' in datas
assert '4' in datas assert '4' in datas
# test query expression filtering
q = TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint)
)
q = [d for d in q]
assert len(q) == 2
datas = [d.data for d in q]
assert '1' in datas
assert '2' in datas
q = TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint)
)
assert len(q) == 2
datas = [d.data for d in q]
assert '3' in datas
assert '4' in datas
class TestInOperator(BaseQuerySetUsage): class TestInOperator(BaseQuerySetUsage):
def test_success_case(self): def test_kwarg_success_case(self):
""" Tests the in operator works with the kwarg query method """
q = TestModel.filter(test_id__in=[0,1]) q = TestModel.filter(test_id__in=[0,1])
assert q.count() == 8 assert q.count() == 8
def test_query_expression_success_case(self):
""" Tests the in operator works with the query expression query method """
q = TestModel.filter(TestModel.test_id.in_([0, 1]))
assert q.count() == 8
class TestValuesList(BaseQuerySetUsage): class TestValuesList(BaseQuerySetUsage):
def test_values_list(self): def test_values_list(self):