Files
deb-python-cassandra-driver/cqlengine/query.py

571 lines
19 KiB
Python

from collections import namedtuple
import copy
from hashlib import md5
from time import time
from cqlengine.connection import connection_manager
from cqlengine.exceptions import CQLEngineException
from cqlengine.functions import BaseQueryFunction
#CQL 3 reference:
#http://www.datastax.com/docs/1.1/references/cql/index
class QueryException(CQLEngineException):
    """Raised when a query cannot be constructed or validated."""
    pass


class QueryOperatorException(QueryException):
    """Raised when a filter operator is missing or cannot be resolved."""
    pass
class QueryOperator(object):
    """
    Base class for the comparison operators used in WHERE clauses.

    Subclasses define ``symbol`` (the suffix used in filter kwargs, ie:
    colname__<symbol>) and ``cql_symbol`` (the comparator emitted in CQL).
    """

    # The symbol that identifies this operator in filter kwargs
    # ie: colname__<symbol>
    symbol = None

    # The comparator symbol this operator uses in cql
    cql_symbol = None

    def __init__(self, column, value):
        self.column = column
        self.value = value

        # the identifier is a unique key that will be used in string
        # replacement on query strings, it's created from a hash
        # of this object's id and the time.
        # NOTE: md5 requires bytes on Python 3, so the seed is explicitly
        # encoded (passing a plain str breaks there)
        seed = str(id(self)) + str(time())
        self.identifier = md5(seed.encode('utf-8')).hexdigest()

        # perform validation on this operator
        self.validate_operator()
        self.validate_value()

    @property
    def cql(self):
        """
        Returns this operator's portion of the WHERE clause.

        Query functions render their own CQL; plain values are referenced
        through the substitution identifier.
        """
        if isinstance(self.value, BaseQueryFunction):
            return '"{}" {} {}'.format(self.column.db_field_name, self.cql_symbol, self.value.to_cql(self.identifier))
        else:
            return '"{}" {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier)

    def validate_operator(self):
        """
        Checks that this operator can be used on the column provided

        :raises QueryOperatorException: if ``symbol`` or ``cql_symbol``
            is not defined on the class
        """
        if self.symbol is None:
            raise QueryOperatorException(
                "{} is not a valid operator, use one with 'symbol' defined".format(
                    self.__class__.__name__
                )
            )
        if self.cql_symbol is None:
            raise QueryOperatorException(
                "{} is not a valid operator, use one with 'cql_symbol' defined".format(
                    self.__class__.__name__
                )
            )

    def validate_value(self):
        """
        Checks that the compare value works with this operator

        Doesn't do anything by default
        """
        pass

    def get_dict(self):
        """
        Returns this operator's contribution to the cql.query arg dictionary

        ie: if this column's name is colname, and the identifier is colval,
        this should return the dict: {'colval':<self.value>}
        SELECT * FROM column_family WHERE colname=:colval
        """
        if isinstance(self.value, BaseQueryFunction):
            return {self.identifier: self.column.to_database(self.value.get_value())}
        else:
            return {self.identifier: self.column.to_database(self.value)}

    @classmethod
    def get_operator(cls, symbol):
        """
        Resolves a filter-kwarg symbol (eg: 'in', 'gte') to an operator class.

        The symbol -> class map is built lazily from QueryOperator's subclass
        tree on first use and cached on the class.

        :raises QueryOperatorException: if the symbol is unknown
        """
        if not hasattr(cls, 'opmap'):
            QueryOperator.opmap = {}

            def _recurse(klass):
                if klass.symbol:
                    QueryOperator.opmap[klass.symbol.upper()] = klass
                for subklass in klass.__subclasses__():
                    _recurse(subklass)

            _recurse(QueryOperator)
        try:
            return QueryOperator.opmap[symbol.upper()]
        except KeyError:
            raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol))
class EqualsOperator(QueryOperator):
    """Equality comparison; the default operator for filter kwargs."""
    symbol = 'EQ'
    cql_symbol = '='


class InOperator(EqualsOperator):
    """Membership comparison (colname__in)."""
    symbol = 'IN'
    cql_symbol = 'IN'


class GreaterThanOperator(QueryOperator):
    """Strict greater-than comparison (colname__gt)."""
    symbol = "GT"
    cql_symbol = '>'


class GreaterThanOrEqualOperator(QueryOperator):
    """Greater-than-or-equal comparison (colname__gte)."""
    symbol = "GTE"
    cql_symbol = '>='


class LessThanOperator(QueryOperator):
    """Strict less-than comparison (colname__lt)."""
    symbol = "LT"
    cql_symbol = '<'


class LessThanOrEqualOperator(QueryOperator):
    """Less-than-or-equal comparison (colname__lte)."""
    symbol = "LTE"
    cql_symbol = '<='
class QuerySet(object):
    """
    Lazily-evaluated, chainable query against a single model's column family.
    """

    def __init__(self, model):
        """
        :param model: the model class this queryset reads and writes
        """
        super(QuerySet, self).__init__()
        self.model = model
        self.column_family_name = self.model.column_family_name()

        # WHERE clause filters (QueryOperator instances)
        self._where = []

        # ORDER BY clause (rendered string) and ALLOW FILTERING flag
        self._order = None
        self._allow_filtering = False

        # CQL has a default limit of 10000, it's defined here
        # because explicit is better than implicit
        self._limit = 10000

        # see the defer and only methods
        self._defer_fields = []
        self._only_fields = []

        # open connection / cursor and the lazily-filled results cache
        self._con = None
        self._cur = None
        self._result_cache = None
        self._result_idx = None
def __unicode__(self):
    """Render this queryset as its SELECT statement."""
    return self._select_query()

def __str__(self):
    """Stringified SELECT statement (see __unicode__)."""
    return str(self.__unicode__())

def __call__(self, **kwargs):
    """Calling a queryset is shorthand for .filter(**kwargs)."""
    return self.filter(**kwargs)
def __deepcopy__(self, memo):
    """
    Copy this queryset, dropping the open connection / cursor and any
    cached results so the clone starts with a clean execution state.
    """
    clone = self.__class__(self.model)
    transient = ('_con', '_cur', '_result_cache', '_result_idx')
    for attr, value in self.__dict__.items():
        if attr in transient:
            clone.__dict__[attr] = None
        else:
            clone.__dict__[attr] = copy.deepcopy(value, memo)
    return clone
def __len__(self):
    """Number of rows matched by this query (executes a COUNT)."""
    return self.count()

def __del__(self):
    """Return the connection to the pool when the queryset is collected."""
    if self._con:
        self._con.close()
        self._con = None
        self._cur = None
#----query generation / execution----
def _validate_where_syntax(self):
    """
    Checks that a filterset will not create invalid cql

    :raises QueryException: if the where clause lacks a usable equality
        comparison, or filters on a clustering key without a partition
        key while ALLOW FILTERING is off
    """
    # check that there's either a = or IN relationship with a primary key
    # or indexed field
    equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)]
    if not any(w.column.primary_key or w.column.index for w in equal_ops):
        raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
    if not self._allow_filtering:
        # if the query is not on an indexed field
        if not any(w.column.index for w in equal_ops):
            # fixed typo in the user-facing message ("querset")
            if not any(w.column._partition_key for w in equal_ops):
                raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the queryset')
    #TODO: abuse this to see if we can get cql to raise an exception
def _where_clause(self):
    """Returns a where clause based on the given filter args"""
    self._validate_where_syntax()
    return ' AND '.join(op.cql for op in self._where)
def _where_values(self):
    """Returns the value dict to be passed to the cql query"""
    merged = {}
    for clause in self._where:
        merged.update(clause.get_dict())
    return merged
def _select_query(self):
    """
    Returns a select clause based on the given filter args
    """
    # apply defer / only restrictions to the selected columns
    fields = self.model._columns.keys()
    if self._defer_fields:
        fields = [f for f in fields if f not in self._defer_fields]
    elif self._only_fields:
        fields = [f for f in fields if f in self._only_fields]
    db_fields = [self.model._columns[f].db_field_name for f in fields]
    column_list = ', '.join('"{}"'.format(f) for f in db_fields)

    parts = ['SELECT {}'.format(column_list)]
    parts.append('FROM {}'.format(self.column_family_name))
    if self._where:
        parts.append('WHERE {}'.format(self._where_clause()))
    if self._order:
        parts.append('ORDER BY {}'.format(self._order))
    if self._limit:
        parts.append('LIMIT {}'.format(self._limit))
    if self._allow_filtering:
        parts.append('ALLOW FILTERING')
    return ' '.join(parts)
#----Reads------
def _execute_query(self):
    """Run the SELECT once and prime the lazy results cache."""
    if self._result_cache is not None:
        return
    self._con = connection_manager()
    self._cur = self._con.execute(self._select_query(), self._where_values())
    # one slot per row; rows are materialized on demand
    self._result_cache = [None] * self._cur.rowcount
def _fill_result_cache_to_idx(self, idx):
    """
    Materialize cached rows up to and including ``idx``, fetching only
    the rows not yet loaded. Returns the connection to the pool once
    every row has been cached.
    """
    self._execute_query()
    if self._result_idx is None:
        self._result_idx = -1

    remaining = idx - self._result_idx
    if remaining < 1:
        return

    names = [col[0] for col in self._cur.description]
    for row in self._cur.fetchmany(remaining):
        self._result_idx += 1
        self._result_cache[self._result_idx] = self._construct_instance(dict(zip(names, row)))

    # return the connection to the connection pool if we have all objects
    if self._result_cache and self._result_cache[-1] is not None:
        self._con.close()
        self._con = None
        self._cur = None
def __iter__(self):
    """Iterate the full result set, loading rows lazily as needed."""
    self._execute_query()
    for idx, instance in enumerate(self._result_cache):
        if instance is None:
            self._fill_result_cache_to_idx(idx)
        yield self._result_cache[idx]
def __getitem__(self, s):
    """
    Index or slice access into the (lazily loaded) result set.

    Loads only the rows needed to satisfy the request, then answers the
    lookup from the results cache.

    :raises IndexError: if an integer index is out of range
    """
    self._execute_query()
    num_results = len(self._result_cache)

    if isinstance(s, slice):
        # calculate the amount of results that need to be loaded
        # (the original read s.step here, loading the wrong rows)
        end = num_results if s.stop is None else s.stop
        if end < 0:
            end += num_results
        # fill through the last index the slice can touch
        self._fill_result_cache_to_idx(min(end, num_results) - 1)
        return self._result_cache[s.start:s.stop:s.step]
    else:
        # return the object at this index
        # (int() subsumes Python 2's long(); long doesn't exist on Python 3)
        s = int(s)
        # handle negative indexing
        if s < 0:
            s += num_results
        if s >= num_results:
            raise IndexError
        self._fill_result_cache_to_idx(s)
        return self._result_cache[s]
def _construct_instance(self, values):
    """
    Build a model instance from a row dict, translating db column names
    back to model field names (unknown names pass through unchanged).
    """
    db_map = self.model._db_map
    field_dict = dict(
        (db_map.get(db_name, db_name), value)
        for db_name, value in values.items()
    )
    return self.model(**field_dict)
def first(self):
    """
    Returns the first result of this query, or None if there are no rows.
    """
    # iterator.next() is Python 2 only; the next() builtin works on 2 and 3
    try:
        return next(iter(self))
    except StopIteration:
        return None
def all(self):
    """
    Returns a copy of this queryset with every where-clause filter removed.
    """
    unfiltered = copy.deepcopy(self)
    unfiltered._where = []
    return unfiltered
def _parse_filter_arg(self, arg):
    """
    Parses a filter arg in the format:
    <colname>__<op>
    :returns: colname, op tuple
    """
    parts = arg.split('__')
    if len(parts) == 1:
        return arg, None
    if len(parts) == 2:
        return parts[0], parts[1]
    raise QueryException("Can't parse '{}'".format(arg))
def filter(self, **kwargs):
    """
    Returns a copy of this queryset with the given filters added.

    Kwargs are parsed as <column>__<operator>; the operator defaults to
    equality when omitted.
    :raises QueryException: if a column name can't be resolved
    """
    clone = copy.deepcopy(self)
    for arg, val in kwargs.items():
        col_name, col_op = self._parse_filter_arg(arg)
        # resolve the column, failing loudly on unknown names
        try:
            column = self.model._columns[col_name]
        except KeyError:
            raise QueryException("Can't resolve column name: '{}'".format(col_name))
        # get query operator, or use equals if not supplied
        op_class = QueryOperator.get_operator(col_op or 'EQ')
        clone._where.append(op_class(column, val))
    return clone
def get(self, **kwargs):
    """
    Returns a single instance matching this query, optionally with additional filter kwargs.

    A DoesNotExistError will be raised if there are no rows matching the query
    A MultipleObjectsFoundError will be raised if there is more than one row matching the query
    """
    if kwargs:
        return self.filter(**kwargs).get()
    self._execute_query()
    found = len(self._result_cache)
    if found == 0:
        raise self.model.DoesNotExist
    if found > 1:
        raise self.model.MultipleObjectsReturned(
            '{} objects found'.format(found))
    return self[0]
def order_by(self, colname):
    """
    orders the result set.
    ordering can only select one column, and it must be the second column in a composite primary key
    Default order is ascending, prepend a '-' to the column name for descending
    """
    if colname is None:
        clone = copy.deepcopy(self)
        clone._order = None
        return clone

    order_type = 'DESC' if colname.startswith('-') else 'ASC'
    colname = colname.replace('-', '')

    column = self.model._columns.get(colname)
    if column is None:
        raise QueryException("Can't resolve the column name: '{}'".format(colname))

    # validate the column selection
    if not column.primary_key:
        raise QueryException(
            "Can't order on '{}', can only order on (clustered) primary keys".format(colname))
    pks = [col for col in self.model._columns.values() if col.primary_key]
    if column == pks[0]:
        raise QueryException(
            "Can't order by the first primary key, clustering (secondary) keys only")

    clone = copy.deepcopy(self)
    clone._order = '"{}" {}'.format(column.db_field_name, order_type)
    return clone
def count(self):
    """ Returns the number of rows matched by this query """
    # answer from the cache when the query has already run
    if self._result_cache is not None:
        return len(self._result_cache)

    parts = ['SELECT COUNT(*)', 'FROM {}'.format(self.column_family_name)]
    if self._where:
        parts.append('WHERE {}'.format(self._where_clause()))
    if self._allow_filtering:
        parts.append('ALLOW FILTERING')
    query = ' '.join(parts)
    with connection_manager() as con:
        cur = con.execute(query, self._where_values())
        return cur.fetchone()[0]
def limit(self, v):
    """
    Sets the limit on the number of results returned
    CQL has a default limit of 10,000

    :param v: the new limit, or None to clear the explicit limit
    :raises TypeError: if v is neither None nor an integer
    :raises QueryException: if v is negative

    NOTE: the original validated "v < 0" even for None, which raised a
    bogus QueryException on Python 2 (None < 0 is True there) and a
    TypeError on Python 3; None is now accepted cleanly.
    """
    if v is not None:
        # 'long' only exists on Python 2; accept it there without
        # breaking on Python 3
        try:
            integer_types = (int, long)
        except NameError:
            integer_types = (int,)
        if not isinstance(v, integer_types):
            raise TypeError
        if v < 0:
            raise QueryException("Negative limit is not allowed")
    if v == self._limit:
        return self
    clone = copy.deepcopy(self)
    clone._limit = v
    return clone
def allow_filtering(self):
    """
    Enables the unwise practice of querying on a clustering
    key without also defining a partition key
    """
    permissive = copy.deepcopy(self)
    permissive._allow_filtering = True
    return permissive
def _only_or_defer(self, action, fields):
    """
    Shared implementation of only() / defer().

    :param action: 'only' or 'defer'
    :param fields: iterable of model field names
    :raises QueryException: if field restrictions were already set, or a
        field name doesn't exist on the model
    :raises ValueError: if action is not 'only' or 'defer'
    """
    clone = copy.deepcopy(self)
    # fixed typo in the user-facing message ("alread")
    if clone._defer_fields or clone._only_fields:
        raise QueryException("QuerySet already has only or defer fields defined")

    # check for strange fields
    missing_fields = [f for f in fields if f not in self.model._columns.keys()]
    if missing_fields:
        raise QueryException(
            "Can't resolve fields {} in {}".format(
                ', '.join(missing_fields), self.model.__name__))

    # store a copy of the field list so later mutation by the caller
    # can't leak into the clone
    if action == 'defer':
        clone._defer_fields = list(fields)
    elif action == 'only':
        clone._only_fields = list(fields)
    else:
        raise ValueError
    return clone
def only(self, fields):
    """ Load only these fields for the returned query """
    return self._only_or_defer('only', fields)

def defer(self, fields):
    """ Don't load these fields for the returned query """
    return self._only_or_defer('defer', fields)
#----writes----
def save(self, instance):
    """
    Creates / updates a row.
    This is a blind insert call.
    All validation and cleaning needs to happen
    prior to calling this.

    Columns whose values were deleted / nulled are removed afterwards with
    a DELETE keyed on the instance's primary keys.
    """
    assert type(instance) == self.model

    # organize data: (db field name, value) for every non-None column
    value_pairs = []
    values = instance.as_dict()
    for name, col in self.model._columns.items():
        val = values.get(name)
        if val is None:
            continue
        value_pairs.append((col.db_field_name, val))

    # construct query string
    # NOTE: zip(*value_pairs)[0] breaks on Python 3 (zip returns an
    # iterator); a comprehension works on both
    field_names = [field for field, _ in value_pairs]
    field_values = dict(value_pairs)
    qs = ["INSERT INTO {}".format(self.column_family_name)]
    qs += ["({})".format(', '.join(['"{}"'.format(f) for f in field_names]))]
    qs += ['VALUES']
    qs += ["({})".format(', '.join([':' + f for f in field_names]))]
    qs = ' '.join(qs)
    with connection_manager() as con:
        con.execute(qs, field_values)

    # delete deleted / nulled columns
    deleted = [k for k, v in instance._values.items() if v.deleted]
    if deleted:
        del_fields = [self.model._columns[f] for f in deleted]
        del_fields = [f.db_field_name for f in del_fields if not f.primary_key]
        pks = self.model._primary_keys
        qs = ['DELETE {}'.format(', '.join(['"{}"'.format(f) for f in del_fields]))]
        qs += ['FROM {}'.format(self.column_family_name)]
        qs += ['WHERE']
        # one "pk = :pk" term per primary key column.
        # BUGFIX: the original lambda formatted v.column.db_field_name — a
        # leaked loop variable — instead of its own argument, so every
        # term used the same (wrong) column name
        eq = lambda col: '{0} = :{0}'.format(col.db_field_name)
        qs += [' AND '.join([eq(f) for f in pks.values()])]
        qs = ' '.join(qs)
        pk_dict = dict([(v.db_field_name, getattr(instance, k)) for k, v in pks.items()])
        with connection_manager() as con:
            con.execute(qs, pk_dict)
def create(self, **kwargs):
    """Instantiate the model with the given kwargs and immediately save it."""
    instance = self.model(**kwargs)
    return instance.save()
#----delete---
def delete(self, columns=None):
    """
    Deletes the contents of a query

    :param columns: unused; kept for interface compatibility (the
        original used a mutable [] default, which is unsafe in Python)
    :raises QueryException: if the where clause does not constrain the
        partition key
    """
    # validate where clause
    # NOTE: dict.values()[0] breaks on Python 3 (dict views aren't
    # indexable); next(iter(...)) works on both
    partition_key = next(iter(self.model._primary_keys.values()))
    if not any(c.column == partition_key for c in self._where):
        raise QueryException("The partition key must be defined on delete queries")
    qs = ['DELETE FROM {}'.format(self.column_family_name)]
    qs += ['WHERE {}'.format(self._where_clause())]
    qs = ' '.join(qs)
    with connection_manager() as con:
        con.execute(qs, self._where_values())
def delete_instance(self, instance):
    """ Deletes one instance """
    pk_name = self.model._pk_name
    query = ' '.join([
        'DELETE FROM {}'.format(self.column_family_name),
        'WHERE {0}=:{0}'.format(pk_name),
    ])
    with connection_manager() as con:
        con.execute(query, {pk_name: instance.pk})