diff --git a/cqlengine/exceptions.py b/cqlengine/exceptions.py index 3bb026f2..3a5444e9 100644 --- a/cqlengine/exceptions.py +++ b/cqlengine/exceptions.py @@ -1,4 +1,6 @@ #cqlengine exceptions -class ModelException(BaseException): pass -class ValidationError(BaseException): pass +class CQLEngineException(BaseException): pass +class ModelException(CQLEngineException): pass +class ValidationError(CQLEngineException): pass +class QueryException(CQLEngineException): pass diff --git a/cqlengine/query.py b/cqlengine/query.py index 70343d6b..f4103910 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -1,18 +1,58 @@ +from collections import namedtuple import copy from cqlengine.connection import get_connection +from cqlengine.exceptions import QueryException #CQL 3 reference: #http://www.datastax.com/docs/1.1/references/cql/index -class Query(object): +WhereFilter = namedtuple('WhereFilter', ['column', 'operator', 'value']) - pass +class QueryOperatorException(QueryException): pass + +class QueryOperator(object): + symbol = None + + @classmethod + def get_operator(cls, symbol): + if not hasattr(cls, 'opmap'): + QueryOperator.opmap = {} + def _recurse(klass): + if klass.symbol: + QueryOperator.opmap[klass.symbol.upper()] = klass + for subklass in klass.__subclasses__(): + _recurse(subklass) + pass + _recurse(QueryOperator) + try: + return QueryOperator.opmap[symbol.upper()] + except KeyError: + raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol)) + +class EqualsOperator(QueryOperator): + symbol = 'EQ' + +class InOperator(QueryOperator): + symbol = 'IN' + +class GreaterThanOperator(QueryOperator): + symbol = "GT" + +class GreaterThanOrEqualOperator(QueryOperator): + symbol = "GTE" + +class LessThanOperator(QueryOperator): + symbol = "LT" + +class LessThanOrEqualOperator(QueryOperator): + symbol = "LTE" class QuerySet(object): #TODO: querysets should be immutable #TODO: querysets should be executed lazily - #TODO: conflicting filter args should raise exception unless a force kwarg is supplied + #TODO: support specifying offset and limit (use slice) (maybe return a mutated queryset) + #TODO: support specifying columns to exclude or select only #CQL supports ==, >, >=, <, <=, IN (a,b,c,..n) #REVERSE, LIMIT @@ -21,9 +61,22 @@ class QuerySet(object): def __init__(self, model, query_args={}): super(QuerySet, self).__init__() self.model = model - self.query_args = query_args self.column_family_name = self.model.objects.column_family_name + #Where clause filters + self._where = [] + + #ordering arguments + self._order = [] + + #subset selection + self._limit = None + self._start = None + + #see the defer and only methods + self._defer_fields = [] + self._only_fields = [] + self._cursor = None #----query generation / execution---- @@ -31,7 +84,16 @@ class QuerySet(object): conn = get_connection() self._cursor = conn.cursor() - def _generate_querystring(self): + def _where_clause(self): + """ + Returns a where clause based on the given filter args + """ + pass + + def _select_query(self): + """ + Returns a select clause based on the given filter args + """ pass @property @@ -67,16 +129,36 @@ class QuerySet(object): pass def all(self): - return QuerySet(self.model) + clone = copy.deepcopy(self) + clone._where = [] + return clone + + def _parse_filter_arg(self, arg, val): + statement = arg.split('__') + if len(statement) == 1: + return WhereFilter(arg, None, val) + elif len(statement) == 2: + return WhereFilter(statement[0], statement[1], val) + else: + raise QueryException("Can't parse '{}'".format(arg)) def filter(self, **kwargs): - qargs = copy.deepcopy(self.query_args) - qargs.update(kwargs) - return QuerySet(self.model, query_args=qargs) + #add arguments to the where clause filters + clone = copy.deepcopy(self) + for arg, val in kwargs.items(): + raw_statement = self._parse_filter_arg(arg, val) + #resolve column and operator + try: + column = self.model._columns[raw_statement.column] + except KeyError: + raise QueryException("Can't resolve column name: '{}'".format(raw_statement.column)) - def exclude(self, **kwargs): - """ Need to invert the logic for all kwargs """ - pass + operator = QueryOperator.get_operator(raw_statement.operator) + + statement = WhereFilter(column, operator, val) + clone._where.append(statement) + + return clone def count(self): """ Returns the number of rows matched by this query """ @@ -95,6 +177,32 @@ class QuerySet(object): self._cursor.execute(qs, {self.model._pk_name:pk}) return self._get_next() + def _only_or_defer(self, action, fields): + clone = copy.deepcopy(self) + if clone._defer_fields or clone._only_fields: + raise QueryException("QuerySet alread has only or defer fields defined") + + #check for strange fields + missing_fields = [f for f in fields if f not in self.model._columns.keys()] + if missing_fields: + raise QueryException("Can't resolve fields {} in {}".format(', '.join(missing_fields), self.model.__name__)) + + if action == 'defer': + clone._defer_fields = fields + elif action == 'only': + clone._only_fields = fields + else: + raise ValueError + + return clone + + def only(self, fields): + """ Load only these fields for the returned query """ + return self._only_or_defer('only', fields) + + def defer(self, fields): + """ Don't load these fields for the returned query """ + return self._only_or_defer('defer', fields) #----writes---- def save(self, instance): @@ -137,7 +245,7 @@ class QuerySet(object): cur.execute(qs, field_values) #----delete--- - def delete(self): + def delete(self, columns=[]): """ Deletes the contents of a query """ diff --git a/cqlengine/tests/query/__init__.py b/cqlengine/tests/query/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py new file mode 100644 index 00000000..ef1bd9c3 --- /dev/null +++ b/cqlengine/tests/query/test_queryset.py @@ -0,0 +1,52 @@ +from cqlengine.tests.base import BaseCassEngTestCase + +from cqlengine.exceptions import ModelException +from cqlengine.models import Model +from cqlengine import columns + +class TestQuerySet(BaseCassEngTestCase): + + def test_query_filter_parsing(self): + """ + Tests the queryset filter method + """ + + def test_where_clause_generation(self): + """ + Tests the where clause creation + """ + + def test_querystring_generation(self): + """ + Tests the select querystring creation + """ + + def test_queryset_is_immutable(self): + """ + Tests that calling a queryset function that changes it's state returns a new queryset + """ + + def test_queryset_slicing(self): + """ + Check that the limit and start is implemented as iterator slices + """ + + def test_proper_delete_behavior(self): + """ + Tests that deleting the contents of a queryset works properly + """ + + def test_the_all_method_clears_where_filter(self): + """ + Tests that calling all on a queryset with previously defined filters returns a queryset with no filters + """ + + def test_defining_only_and_defer_fails(self): + """ + Tests that trying to add fields to either only or defer, or doing so more than once fails + """ + + def test_defining_only_or_defer_fields_fails(self): + """ + Tests that setting only or defer fields that don't exist raises an exception + """