starting to write the QuerySet class

This commit is contained in:
Blake Eggleston
2012-11-20 23:22:58 -08:00
parent 4152a37e80
commit 61177eb340
4 changed files with 177 additions and 15 deletions

View File

@@ -1,4 +1,6 @@
#cqlengine exceptions
class ModelException(BaseException): pass
class ValidationError(BaseException): pass
class CQLEngineException(BaseException): pass
class ModelException(CQLEngineException): pass
class ValidationError(CQLEngineException): pass
class QueryException(CQLEngineException): pass

View File

@@ -1,18 +1,58 @@
from collections import namedtuple
import copy
from cqlengine.connection import get_connection
from cqlengine.exceptions import QueryException
#CQL 3 reference:
#http://www.datastax.com/docs/1.1/references/cql/index
class Query(object):
WhereFilter = namedtuple('WhereFilter', ['column', 'operator', 'value'])
pass
class QueryOperatorException(QueryException): pass
class QueryOperator(object):
symbol = None
@classmethod
def get_operator(cls, symbol):
if not hasattr(cls, 'opmap'):
QueryOperator.opmap = {}
def _recurse(klass):
if klass.symbol:
QueryOperator.opmap[klass.symbol.upper()] = klass
for subklass in klass.__subclasses__():
_recurse(subklass)
pass
_recurse(QueryOperator)
try:
return QueryOperator.opmap[symbol.upper()]
except KeyError:
raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol))
class EqualsOperator(QueryOperator):
symbol = 'EQ'
class InOperator(QueryOperator):
symbol = 'IN'
class GreaterThanOperator(QueryOperator):
symbol = "GT"
class GreaterThanOrEqualOperator(QueryOperator):
symbol = "GTE"
class LessThanOperator(QueryOperator):
symbol = "LT"
class LessThanOrEqualOperator(QueryOperator):
symbol = "LTE"
class QuerySet(object):
#TODO: querysets should be immutable
#TODO: querysets should be executed lazily
#TODO: conflicting filter args should raise exception unless a force kwarg is supplied
#TODO: support specifying offset and limit (use slice) (maybe return a mutated queryset)
#TODO: support specifying columns to exclude or select only
#CQL supports ==, >, >=, <, <=, IN (a,b,c,..n)
#REVERSE, LIMIT
@@ -21,9 +61,22 @@ class QuerySet(object):
def __init__(self, model, query_args={}):
super(QuerySet, self).__init__()
self.model = model
self.query_args = query_args
self.column_family_name = self.model.objects.column_family_name
#Where clause filters
self._where = []
#ordering arguments
self._order = []
#subset selection
self._limit = None
self._start = None
#see the defer and only methods
self._defer_fields = []
self._only_fields = []
self._cursor = None
#----query generation / execution----
@@ -31,7 +84,16 @@ class QuerySet(object):
conn = get_connection()
self._cursor = conn.cursor()
def _generate_querystring(self):
def _where_clause(self):
"""
Returns a where clause based on the given filter args
"""
pass
def _select_query(self):
"""
Returns a select clause based on the given filter args
"""
pass
@property
@@ -67,16 +129,36 @@ class QuerySet(object):
pass
def all(self):
return QuerySet(self.model)
clone = copy.deepcopy(self)
clone._where = []
return clone
def _parse_filter_arg(self, arg, val):
statement = arg.split('__')
if len(statement) == 1:
return WhereFilter(arg, None, val)
elif len(statement) == 2:
return WhereFilter(statement[0], statement[1], val)
else:
raise QueryException("Can't parse '{}'".format(arg))
def filter(self, **kwargs):
qargs = copy.deepcopy(self.query_args)
qargs.update(kwargs)
return QuerySet(self.model, query_args=qargs)
#add arguments to the where clause filters
clone = copy.deepcopy(self)
for arg, val in kwargs.items():
raw_statement = self._parse_filter_arg(arg, val)
#resolve column and operator
try:
column = self.model._columns[raw_statement.column]
except KeyError:
raise QueryException("Can't resolve column name: '{}'".format(raw_statement.column))
def exclude(self, **kwargs):
""" Need to invert the logic for all kwargs """
pass
operator = QueryOperator.get_operator(raw_statement.operator)
statement = WhereFilter(column, operator, val)
clone._where.append(statement)
return clone
def count(self):
""" Returns the number of rows matched by this query """
@@ -95,6 +177,32 @@ class QuerySet(object):
self._cursor.execute(qs, {self.model._pk_name:pk})
return self._get_next()
def _only_or_defer(self, action, fields):
clone = copy.deepcopy(self)
if clone._defer_fields or clone._only_fields:
raise QueryException("QuerySet alread has only or defer fields defined")
#check for strange fields
missing_fields = [f for f in fields if f not in self.model._columns.keys()]
if missing_fields:
raise QueryException("Can't resolve fields {} in {}".format(', '.join(missing_fields), self.model.__name__))
if action == 'defer':
clone._defer_fields = fields
elif action == 'only':
clone._only_fields = fields
else:
raise ValueError
return clone
def only(self, fields):
""" Load only these fields for the returned query """
return self._only_or_defer('only', fields)
def defer(self, fields):
""" Don't load these fields for the returned query """
return self._only_or_defer('defer', fields)
#----writes----
def save(self, instance):
@@ -137,7 +245,7 @@ class QuerySet(object):
cur.execute(qs, field_values)
#----delete---
def delete(self):
def delete(self, columns=[]):
"""
Deletes the contents of a query
"""

View File

View File

@@ -0,0 +1,52 @@
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.exceptions import ModelException
from cqlengine.models import Model
from cqlengine import columns
class TestQuerySet(BaseCassEngTestCase):
def test_query_filter_parsing(self):
"""
Tests the queryset filter method
"""
def test_where_clause_generation(self):
"""
Tests the where clause creation
"""
def test_querystring_generation(self):
"""
Tests the select querystring creation
"""
def test_queryset_is_immutable(self):
"""
Tests that calling a queryset function that changes it's state returns a new queryset
"""
def test_queryset_slicing(self):
"""
Check that the limit and start is implemented as iterator slices
"""
def test_proper_delete_behavior(self):
"""
Tests that deleting the contents of a queryset works properly
"""
def test_the_all_method_clears_where_filter(self):
"""
Tests that calling all on a queryset with previously defined filters returns a queryset with no filters
"""
def test_defining_only_and_defer_fails(self):
"""
Tests that trying to add fields to either only or defer, or doing so more than once fails
"""
def test_defining_only_or_defer_fields_fails(self):
"""
Tests that setting only or defer fields that don't exist raises an exception
"""