Merge pull request #68 from bdeggleston/query-builder

Adding support for query expressions
This commit is contained in:
Jon Haddad
2013-06-15 22:42:58 -07:00
4 changed files with 202 additions and 56 deletions

View File

@@ -29,9 +29,12 @@ class hybrid_classmethod(object):
return self.instmethod.__get__(instance, owner)
def __call__(self, *args, **kwargs):
""" Just a hint to IDEs that it's ok to call this """
"""
Just a hint to IDEs that it's ok to call this
"""
raise NotImplementedError
class QuerySetDescriptor(object):
"""
returns a fresh queryset for the given model
@@ -39,15 +42,29 @@ class QuerySetDescriptor(object):
"""
def __get__(self, obj, model):
""" :rtype: QuerySet """
if model.__abstract__:
raise CQLEngineException('cannot execute queries against abstract models')
return QuerySet(model)
def __call__(self, *args, **kwargs):
""" Just a hint to IDEs that it's ok to call this """
"""
Just a hint to IDEs that it's ok to call this
:rtype: QuerySet
"""
raise NotImplementedError
class ColumnDescriptor(AbstractColumnDescriptor):
class ColumnQueryEvaluator(AbstractColumnDescriptor):
def __init__(self, column):
self.column = column
def _get_column(self):
return self.column
class ColumnDescriptor(object):
"""
Handles the reading and writing of column values to and from
a model instance's value manager, as well as creating
@@ -61,6 +78,7 @@ class ColumnDescriptor(AbstractColumnDescriptor):
:return:
"""
self.column = column
self.query_evaluator = ColumnQueryEvaluator(self.column)
def __get__(self, instance, owner):
"""
@@ -74,7 +92,7 @@ class ColumnDescriptor(AbstractColumnDescriptor):
if instance:
return instance._values[self.column.column_name].getval()
else:
return self.column
return self.query_evaluator
def __set__(self, instance, value):
"""
@@ -203,12 +221,12 @@ class BaseModel(object):
return cls.objects.all()
@classmethod
def filter(cls, **kwargs):
return cls.objects.filter(**kwargs)
def filter(cls, *args, **kwargs):
return cls.objects.filter(*args, **kwargs)
@classmethod
def get(cls, **kwargs):
return cls.objects.get(**kwargs)
def get(cls, *args, **kwargs):
return cls.objects.get(*args, **kwargs)
def save(self):
is_new = self.pk is None

View File

@@ -163,12 +163,17 @@ class AbstractColumnDescriptor(object):
def _get_column(self):
raise NotImplementedError
def in_(self, item):
"""
Returns an in operator
used in where you'd typically want to use python's `in` operator
"""
return InOperator(self._get_column(), item)
def __eq__(self, other):
return EqualsOperator(self._get_column(), other)
def __contains__(self, item):
return InOperator(self._get_column(), item)
def __gt__(self, other):
return GreaterThanOperator(self._get_column(), other)
@@ -195,7 +200,6 @@ class NamedColumnDescriptor(AbstractColumnDescriptor):
def to_database(self, val):
return val
C = NamedColumnDescriptor
class TableDescriptor(object):
""" describes a cql table """
@@ -204,7 +208,6 @@ class TableDescriptor(object):
self.keyspace = keyspace
self.name = name
T = TableDescriptor
class KeyspaceDescriptor(object):
""" Describes a cql keyspace """
@@ -219,7 +222,6 @@ class KeyspaceDescriptor(object):
"""
return TableDescriptor(self.name, name)
K = KeyspaceDescriptor
class BatchType(object):
Unlogged = 'UNLOGGED'
@@ -316,8 +318,8 @@ class QuerySet(object):
def __str__(self):
return str(self.__unicode__())
def __call__(self, **kwargs):
return self.filter(**kwargs)
def __call__(self, *args, **kwargs):
return self.filter(*args, **kwargs)
def __deepcopy__(self, memo):
clone = self.__class__(self.model)
@@ -529,9 +531,21 @@ class QuerySet(object):
else:
raise QueryException("Can't parse '{}'".format(arg))
def filter(self, **kwargs):
def filter(self, *args, **kwargs):
"""
Adds WHERE arguments to the queryset, returning a new queryset
#TODO: show examples
:rtype: QuerySet
"""
#add arguments to the where clause filters
clone = copy.deepcopy(self)
for operator in args:
if not isinstance(operator, QueryOperator):
raise QueryException('{} is not a valid query operator'.format(operator))
clone._where.append(operator)
for arg, val in kwargs.items():
col_name, col_op = self._parse_filter_arg(arg)
#resolve column and operator
@@ -551,20 +565,16 @@ class QuerySet(object):
return clone
def query(self, *args):
"""
Same end result as filter, but uses the new comparator style args
ie: Model.column == val
"""
def get(self, **kwargs):
def get(self, *args, **kwargs):
"""
Returns a single instance matching this query, optionally with additional filter kwargs.
A DoesNotExistError will be raised if there are no rows matching the query
A MultipleObjectsFoundError will be raised if there is more than one row matching the queyr
"""
if kwargs: return self.filter(**kwargs).get()
if args or kwargs:
return self.filter(*args, **kwargs).get()
self._execute_query()
if len(self._result_cache) == 0:
raise self.model.DoesNotExist

View File

@@ -3,7 +3,7 @@ from cqlengine.query import QueryException
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.exceptions import ModelException, CQLEngineException
from cqlengine.models import Model, ModelDefinitionException
from cqlengine.models import Model, ModelDefinitionException, ColumnQueryEvaluator
from cqlengine import columns
import cqlengine
@@ -270,7 +270,7 @@ class TestAbstractModelClasses(BaseCassEngTestCase):
def test_abstract_columns_are_inherited(self):
""" Tests that columns defined in the abstract class are inherited into the concrete class """
assert hasattr(ConcreteModelWithCol, 'pkey')
assert isinstance(ConcreteModelWithCol.pkey, columns.Column)
assert isinstance(ConcreteModelWithCol.pkey, ColumnQueryEvaluator)
assert isinstance(ConcreteModelWithCol._columns['pkey'], columns.Column)
def test_concrete_class_table_creation_cycle(self):

View File

@@ -52,14 +52,44 @@ class TestQuerySetOperation(BaseCassEngTestCase):
assert isinstance(op, query.GreaterThanOrEqualOperator)
assert op.value == 1
def test_query_expression_parsing(self):
""" Tests that query experessions are evaluated properly """
query1 = TestModel.filter(TestModel.test_id == 5)
assert len(query1._where) == 1
op = query1._where[0]
assert isinstance(op, query.EqualsOperator)
assert op.value == 5
query2 = query1.filter(TestModel.expected_result >= 1)
assert len(query2._where) == 2
op = query2._where[1]
assert isinstance(op, query.GreaterThanOrEqualOperator)
assert op.value == 1
def test_using_invalid_column_names_in_filter_kwargs_raises_error(self):
"""
Tests that using invalid or nonexistant column names for filter args raises an error
"""
with self.assertRaises(query.QueryException):
query0 = TestModel.objects(nonsense=5)
TestModel.objects(nonsense=5)
def test_where_clause_generation(self):
def test_using_nonexistant_column_names_in_query_args_raises_error(self):
"""
Tests that using invalid or nonexistant columns for query args raises an error
"""
with self.assertRaises(AttributeError):
TestModel.objects(TestModel.nonsense == 5)
def test_using_non_query_operators_in_query_args_raises_error(self):
"""
Tests that providing query args that are not query operator instances raises an error
"""
with self.assertRaises(query.QueryException):
TestModel.objects(5)
def test_filter_method_where_clause_generation(self):
"""
Tests the where clause creation
"""
@@ -73,6 +103,19 @@ class TestQuerySetOperation(BaseCassEngTestCase):
where = query2._where_clause()
assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids)
def test_query_expression_where_clause_generation(self):
"""
Tests the where clause creation
"""
query1 = TestModel.objects(TestModel.test_id == 5)
ids = [o.query_value.identifier for o in query1._where]
where = query1._where_clause()
assert where == '"test_id" = :{}'.format(*ids)
query2 = query1.filter(TestModel.expected_result >= 1)
ids = [o.query_value.identifier for o in query2._where]
where = query2._where_clause()
assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids)
def test_querystring_generation(self):
"""
@@ -88,6 +131,7 @@ class TestQuerySetOperation(BaseCassEngTestCase):
query2 = query1.filter(expected_result__gte=1)
assert len(query2._where) == 2
assert len(query1._where) == 1
def test_the_all_method_duplicates_queryset(self):
"""
@@ -163,12 +207,21 @@ class BaseQuerySetUsage(BaseCassEngTestCase):
class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
def test_count(self):
""" Tests that adding filtering statements affects the count query as expected """
assert TestModel.objects.count() == 12
q = TestModel.objects(test_id=0)
assert q.count() == 4
def test_query_expression_count(self):
""" Tests that adding query statements affects the count query as expected """
assert TestModel.objects.count() == 12
q = TestModel.objects(TestModel.test_id == 0)
assert q.count() == 4
def test_iteration(self):
""" Tests that iterating over a query set pulls back all of the expected results """
q = TestModel.objects(test_id=0)
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
@@ -178,6 +231,7 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
compare_set.remove(val)
assert len(compare_set) == 0
# test with regular filtering
q = TestModel.objects(attempt_id=3).allow_filtering()
assert len(q) == 3
#tuple of expected test_id, expected_result values
@@ -188,9 +242,21 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
compare_set.remove(val)
assert len(compare_set) == 0
# test with query method
q = TestModel.objects(TestModel.attempt_id == 3).allow_filtering()
assert len(q) == 3
#tuple of expected test_id, expected_result values
compare_set = set([(0,20), (1,20), (2,75)])
for t in q:
val = t.test_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
def test_multiple_iterations_work_properly(self):
""" Tests that iterating over a query set more than once works """
q = TestModel.objects(test_id=0)
# test with both the filtering method and the query method
for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)):
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
@@ -211,7 +277,8 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
"""
tests that the use of one iterator does not affect the behavior of another
"""
q = TestModel.objects(test_id=0).order_by('attempt_id')
for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)):
q = q.order_by('attempt_id')
expected_order = [0,1,2,3]
iter1 = iter(q)
iter2 = iter(q)
@@ -240,6 +307,27 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
assert m.test_id == 0
assert m.attempt_id == 0
def test_query_expression_get_success_case(self):
"""
Tests that the .get() method works on new and existing querysets
"""
m = TestModel.get(TestModel.test_id == 0, TestModel.attempt_id == 0)
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
q = TestModel.objects(TestModel.test_id == 0, TestModel.attempt_id == 0)
m = q.get()
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
q = TestModel.objects(TestModel.test_id == 0)
m = q.get(TestModel.attempt_id == 0)
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
def test_get_doesnotexist_exception(self):
"""
Tests that get calls that don't return a result raises a DoesNotExist error
@@ -273,10 +361,14 @@ class TestQuerySetOrdering(BaseQuerySetUsage):
assert model.attempt_id == expect
def test_ordering_by_non_second_primary_keys_fail(self):
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('test_id')
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(TestModel.test_id == 0).order_by('test_id')
def test_ordering_by_non_primary_keys_fails(self):
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('description')
@@ -343,7 +435,7 @@ class TestQuerySetValidation(BaseQuerySetUsage):
"""
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_result=25)
[i for i in q]
list([i for i in q])
def test_primary_key_or_index_must_have_equal_relation_filter(self):
"""
@@ -351,7 +443,7 @@ class TestQuerySetValidation(BaseQuerySetUsage):
"""
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id__gt=0)
[i for i in q]
list([i for i in q])
def test_indexed_field_can_be_queried(self):
@@ -359,7 +451,6 @@ class TestQuerySetValidation(BaseQuerySetUsage):
Tests that queries on an indexed field will work without any primary key relations specified
"""
q = IndexedTestModel.objects(test_result=25)
count = q.count()
assert q.count() == 4
class TestQuerySetDelete(BaseQuerySetUsage):
@@ -438,6 +529,7 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4')
time.sleep(0.2)
# test kwarg filtering
q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
q = [d for d in q]
assert len(q) == 2
@@ -451,13 +543,39 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
assert '3' in datas
assert '4' in datas
# test query expression filtering
q = TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint)
)
q = [d for d in q]
assert len(q) == 2
datas = [d.data for d in q]
assert '1' in datas
assert '2' in datas
q = TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint)
)
assert len(q) == 2
datas = [d.data for d in q]
assert '3' in datas
assert '4' in datas
class TestInOperator(BaseQuerySetUsage):
def test_success_case(self):
def test_kwarg_success_case(self):
""" Tests the in operator works with the kwarg query method """
q = TestModel.filter(test_id__in=[0,1])
assert q.count() == 8
def test_query_expression_success_case(self):
""" Tests the in operator works with the query expression query method """
q = TestModel.filter(TestModel.test_id.in_([0, 1]))
assert q.count() == 8
class TestValuesList(BaseQuerySetUsage):
def test_values_list(self):