482 lines
15 KiB
Python
482 lines
15 KiB
Python
import copy
import numbers
from collections import namedtuple
from hashlib import md5
from time import time

from cqlengine.connection import connection_manager
from cqlengine.exceptions import CQLEngineException
|
|
|
|
# CQL 3 reference:
# http://www.datastax.com/docs/1.1/references/cql/index
|
|
|
|
class QueryException(CQLEngineException):
    """Raised when a query cannot be built or is not valid."""
|
|
class QueryOperatorException(QueryException):
    """Raised when a query operator is misconfigured or cannot be resolved."""
|
|
|
|
class QueryOperator(object):
    """
    Base class for one comparison in a WHERE clause (``column <op> value``).

    Concrete operators subclass this and define both ``symbol`` and
    ``cql_symbol``; instantiating a class that leaves either one as None
    raises QueryOperatorException.
    """

    # The symbol that identifies this operator in filter kwargs
    # ie: colname__<symbol>
    symbol = None

    # The comparator symbol this operator uses in cql
    cql_symbol = None

    def __init__(self, column, value):
        """
        :param column: the column object being compared; its
            ``db_field_name`` attribute is used when rendering cql
        :param value: the value the column is compared against
        :raises QueryOperatorException: if symbol or cql_symbol is undefined
        """
        self.column = column
        self.value = value

        # the identifier is a unique key that will be used in string
        # replacement on query strings, it's created from a hash
        # of this object's id and the time.
        # md5 needs bytes on python 3, so the seed is encoded first
        # (a no-op for the ascii-only seed on python 2)
        seed = str(id(self)) + str(time())
        self.identifier = md5(seed.encode('utf-8')).hexdigest()

        # perform validation on this operator
        self.validate_operator()
        self.validate_value()

    @property
    def cql(self):
        """
        Returns this operator's portion of the WHERE clause, in the form
        ``<db_field_name> <cql_symbol> :<identifier>``
        """
        return '{} {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier)

    def validate_operator(self):
        """
        Checks that this operator class defines both symbol and cql_symbol

        :raises QueryOperatorException: if either attribute is None
        """
        if self.symbol is None:
            raise QueryOperatorException(
                "{} is not a valid operator, use one with 'symbol' defined".format(
                    self.__class__.__name__
                )
            )
        if self.cql_symbol is None:
            raise QueryOperatorException(
                "{} is not a valid operator, use one with 'cql_symbol' defined".format(
                    self.__class__.__name__
                )
            )

    def validate_value(self):
        """
        Checks that the compare value works with this operator

        Doesn't do anything by default
        """
        pass

    def get_dict(self):
        """
        Returns this operator's contribution to the cql.query arg dictionary

        ie: if this column's name is colname, and the identifier is colval,
        this should return the dict: {'colval':<self.value>}
        SELECT * FROM column_family WHERE colname=:colval
        """
        return {self.identifier: self.value}

    @classmethod
    def _build_opmap(cls):
        """
        Walks the QueryOperator subclass tree and rebuilds the
        upper-cased-symbol -> operator class map stored on QueryOperator
        """
        opmap = {}

        def _recurse(klass):
            # parents are registered before their subclasses, so on a symbol
            # clash the class visited last wins
            if klass.symbol:
                opmap[klass.symbol.upper()] = klass
            for subklass in klass.__subclasses__():
                _recurse(subklass)

        _recurse(QueryOperator)
        QueryOperator.opmap = opmap
        return opmap

    @classmethod
    def get_operator(cls, symbol):
        """
        Returns the operator class registered for the given filter symbol

        The lookup is case insensitive. The symbol map is cached on
        QueryOperator and rebuilt when a symbol is missing, so operator
        subclasses defined after the first lookup are still found.

        :param symbol: the operator symbol, ie: 'EQ', 'in', 'gte'
        :raises QueryOperatorException: if no operator uses the symbol
        """
        key = symbol.upper()
        opmap = getattr(QueryOperator, 'opmap', None)
        if opmap is None or key not in opmap:
            opmap = QueryOperator._build_opmap()
        try:
            return opmap[key]
        except KeyError:
            raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol))
|
|
|
|
class EqualsOperator(QueryOperator):
    """Equality comparison: ``colname__eq`` in filter kwargs, ``=`` in cql."""

    symbol = 'EQ'
    cql_symbol = '='
|
|
|
|
class InOperator(EqualsOperator):
    """
    Membership comparison: ``colname__in`` in filter kwargs, ``IN`` in cql.

    Derives from EqualsOperator, so isinstance checks against
    EqualsOperator also match this operator.
    """

    symbol = 'IN'
    cql_symbol = 'IN'
|
|
|
|
class GreaterThanOperator(QueryOperator):
    """Greater-than comparison: ``colname__gt`` in filter kwargs, ``>`` in cql."""

    symbol = 'GT'
    cql_symbol = '>'
|
|
|
|
class GreaterThanOrEqualOperator(QueryOperator):
    """Greater-or-equal comparison: ``colname__gte`` in filter kwargs, ``>=`` in cql."""

    symbol = 'GTE'
    cql_symbol = '>='
|
|
|
|
class LessThanOperator(QueryOperator):
    """Less-than comparison: ``colname__lt`` in filter kwargs, ``<`` in cql."""

    symbol = 'LT'
    cql_symbol = '<'
|
|
|
|
class LessThanOrEqualOperator(QueryOperator):
    """Less-or-equal comparison: ``colname__lte`` in filter kwargs, ``<=`` in cql."""

    symbol = 'LTE'
    cql_symbol = '<='
|
|
|
|
class QuerySet(object):
    """
    Builds and executes cql statements for a single model.

    The filtering / ordering / limiting methods never modify the queryset
    they are called on; they return a modified deep copy (or the same
    queryset when nothing would change), so querysets can be chained.
    """
    #TODO: delete empty columns on save
    #TODO: support specifying columns to exclude or select only
    #TODO: cache results in this instance, but don't copy them on deepcopy

    # attributes that hold per-execution state and must not leak into clones
    _uncopied_attrs = ('_result_cache', '_cursor')

    def __init__(self, model):
        """
        :param model: the model class this queryset reads and writes; it
            must provide column_family_name() and the _columns mapping
        """
        super(QuerySet, self).__init__()
        self.model = model
        self.column_family_name = self.model.column_family_name()

        #Where clause filters
        self._where = []

        #ordering arguments
        self._order = None

        #CQL has a default limit of 10000, it's defined here
        #because explicit is better than implicit
        self._limit = 10000

        #see the defer and only methods
        self._defer_fields = []
        self._only_fields = []

        #results cache
        self._result_cache = None

        #cursor for the query currently being iterated, and its row count
        self._cursor = None
        self._rowcount = None

    def __unicode__(self):
        return self._select_query()

    def __str__(self):
        return str(self.__unicode__())

    def __call__(self, **kwargs):
        """ Shortcut for filter(**kwargs) """
        return self.filter(**kwargs)

    def __deepcopy__(self, memo):
        """
        Copies the query definition but not its execution state.

        The results cache and the open cursor are reset on the clone;
        copying the cursor would make a clone with different filters
        iterate over this queryset's results.
        """
        clone = self.__class__(self.model)
        for k, v in self.__dict__.items():
            if k in self._uncopied_attrs:
                clone.__dict__[k] = None
            else:
                clone.__dict__[k] = copy.deepcopy(v, memo)

        return clone

    def __len__(self):
        return self.count()

    #----query generation / execution----

    def _validate_where_syntax(self):
        """
        Checks that a filterset will not create invalid cql

        :raises QueryException: if no "=" / "IN" filter targets a primary
            key or indexed column
        """
        #check that there's either a = or IN relationship with a primary key or indexed field
        equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)]
        if not any(w.column.primary_key or w.column.index for w in equal_ops):
            raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
        #TODO: abuse this to see if we can get cql to raise an exception

    def _where_clause(self):
        """ Returns a where clause based on the given filter args """
        self._validate_where_syntax()
        return ' AND '.join([f.cql for f in self._where])

    def _where_values(self):
        """ Returns the value dict to be passed to the cql query """
        values = {}
        for where in self._where:
            values.update(where.get_dict())
        return values

    def _select_query(self):
        """
        Returns a select clause based on the given filter args
        """
        fields = self.model._columns.keys()
        if self._defer_fields:
            fields = [f for f in fields if f not in self._defer_fields]
        elif self._only_fields:
            fields = [f for f in fields if f in self._only_fields]
        db_fields = [self.model._columns[f].db_field_name for f in fields]

        qs = ['SELECT {}'.format(', '.join(db_fields))]
        qs += ['FROM {}'.format(self.column_family_name)]

        if self._where:
            qs += ['WHERE {}'.format(self._where_clause())]

        if self._order:
            qs += ['ORDER BY {}'.format(self._order)]

        # a falsy limit (None or 0) leaves the LIMIT clause out entirely
        if self._limit:
            qs += ['LIMIT {}'.format(self._limit)]

        return ' '.join(qs)

    #----Reads------

    def __iter__(self):
        """
        Executes the select query, unless a cursor is already open, and
        returns this queryset as its own iterator
        """
        #TODO: cache results
        if self._cursor is None:
            #TODO: the query and caching should happen in the same function
            with connection_manager() as con:
                self._cursor = con.execute(self._select_query(), self._where_values())
                self._rowcount = self._cursor.rowcount
        return self

    def __getitem__(self, s):
        """
        Supports ``queryset[:n]``, which returns a queryset limited to n rows.

        :raises QueryException: if the slice has a start or a step
        :raises NotImplementedError: for integer indexes
        """
        if isinstance(s, slice):
            #return a new query with limit defined
            #start and step are not supported
            if s.start: raise QueryException('CQL does not support START')
            if s.step: raise QueryException('step is not supported')
            return self.limit(s.stop)
        else:
            #return the object at this index
            # int() promotes to long on python 2 and also exists on python 3
            s = int(s)
            raise NotImplementedError

    def _construct_instance(self, values):
        """ Builds a model instance from a {db column name: value} dict """
        #translate column names to model names
        field_dict = {}
        db_map = self.model._db_map
        for key, val in values.items():
            if key in db_map:
                field_dict[db_map[key]] = val
            else:
                field_dict[key] = val
        return self.model(**field_dict)

    def _get_next(self):
        """ Gets the next cursor result, or None when the cursor is exhausted """
        cur = self._cursor
        values = cur.fetchone()
        if values is None: return
        names = [i[0] for i in cur.description]
        value_dict = dict(zip(names, values))
        return self._construct_instance(value_dict)

    def next(self):
        """
        Returns the next model instance from the open cursor

        :raises StopIteration: when the cursor is exhausted; the cursor is
            dropped so the next iteration runs the query again
        """
        instance = self._get_next()
        if instance is None:
            #TODO: this is inefficient, we should be caching the results
            self._cursor = None
            raise StopIteration
        return instance

    # python 3 iterator protocol
    __next__ = next

    def first(self):
        """
        Returns the next row of this query, or None if there isn't one

        NOTE(review): if this queryset is already being iterated, the open
        cursor is reused, so this returns the next unread row
        """
        return iter(self)._get_next()

    def all(self):
        """ Returns a copy of this queryset with all filters removed """
        clone = copy.deepcopy(self)
        clone._where = []
        return clone

    def _parse_filter_arg(self, arg):
        """
        Parses a filter arg in the format:
        <colname>__<op>
        :returns: colname, op tuple
        :raises QueryException: if arg contains more than one '__'
        """
        statement = arg.split('__')
        if len(statement) == 1:
            return arg, None
        elif len(statement) == 2:
            return statement[0], statement[1]
        else:
            raise QueryException("Can't parse '{}'".format(arg))

    def filter(self, **kwargs):
        """
        Returns a copy of this queryset with the given filters added

        Keys are ``<colname>`` or ``<colname>__<op>``; equals is used when
        no operator is supplied.

        :raises QueryException: if a column name can't be resolved
        """
        #add arguments to the where clause filters
        clone = copy.deepcopy(self)
        for arg, val in kwargs.items():
            col_name, col_op = self._parse_filter_arg(arg)
            #resolve column and operator
            try:
                column = self.model._columns[col_name]
            except KeyError:
                raise QueryException("Can't resolve column name: '{}'".format(col_name))

            #get query operator, or use equals if not supplied
            operator_class = QueryOperator.get_operator(col_op or 'EQ')
            operator = operator_class(column, val)

            clone._where.append(operator)

        return clone

    def order_by(self, colname):
        """
        orders the result set.
        ordering can only select one column, and it must be the second column in a composite primary key

        Default order is ascending, prepend a '-' to the column name for descending

        Passing None returns a copy with the ordering removed.

        :raises QueryException: if the column can't be resolved, is not a
            primary key, or is the first primary key
        """
        if colname is None:
            clone = copy.deepcopy(self)
            clone._order = None
            return clone

        # only the leading '-' is a direction marker, the rest of the
        # name is left untouched
        if colname.startswith('-'):
            order_type = 'DESC'
            colname = colname[1:]
        else:
            order_type = 'ASC'

        column = self.model._columns.get(colname)
        if column is None:
            raise QueryException("Can't resolve the column name: '{}'".format(colname))

        #validate the column selection
        if not column.primary_key:
            raise QueryException(
                "Can't order on '{}', can only order on (clustered) primary keys".format(colname))

        pks = [v for v in self.model._columns.values() if v.primary_key]
        if column == pks[0]:
            raise QueryException(
                "Can't order by the first primary key, clustering (secondary) keys only")

        clone = copy.deepcopy(self)
        clone._order = '{} {}'.format(column.db_field_name, order_type)
        return clone

    def count(self):
        """ Returns the number of rows matched by this query """
        #TODO: check for previous query execution and return row count if it exists
        qs = ['SELECT COUNT(*)']
        qs += ['FROM {}'.format(self.column_family_name)]
        if self._where:
            qs += ['WHERE {}'.format(self._where_clause())]
        qs = ' '.join(qs)

        with connection_manager() as con:
            cur = con.execute(qs, self._where_values())
            return cur.fetchone()[0]

    def limit(self, v):
        """
        Sets the limit on the number of results returned
        CQL has a default limit of 10,000

        Passing None removes the LIMIT clause from the generated query.

        :returns: this queryset if the limit is unchanged, otherwise a copy
        :raises TypeError: if v is neither None nor an integer
        :raises QueryException: if v is negative
        """
        if not (v is None or isinstance(v, numbers.Integral)):
            raise TypeError("limit must be an integer or None, got {!r}".format(v))
        if v == self._limit:
            return self

        # None must be excluded explicitly: python 2 evaluates None < 0 as
        # True and python 3 raises TypeError for the comparison
        if v is not None and v < 0:
            raise QueryException("Negative limit is not allowed")

        clone = copy.deepcopy(self)
        clone._limit = v
        return clone

    def _only_or_defer(self, action, fields):
        """
        Shared implementation of only() and defer()

        :param action: 'only' or 'defer'
        :param fields: iterable of model field names
        :raises QueryException: if only / defer fields are already set, or
            a field name is not a column of the model
        :raises ValueError: if action is not 'only' or 'defer'
        """
        clone = copy.deepcopy(self)
        if clone._defer_fields or clone._only_fields:
            raise QueryException("QuerySet alread has only or defer fields defined")

        #check for strange fields
        missing_fields = [f for f in fields if f not in self.model._columns.keys()]
        if missing_fields:
            raise QueryException(
                "Can't resolve fields {} in {}".format(
                    ', '.join(missing_fields), self.model.__name__))

        # store a copy so later changes to the caller's list don't
        # alter this query
        if action == 'defer':
            clone._defer_fields = list(fields)
        elif action == 'only':
            clone._only_fields = list(fields)
        else:
            raise ValueError

        return clone

    def only(self, fields):
        """ Load only these fields for the returned query """
        return self._only_or_defer('only', fields)

    def defer(self, fields):
        """ Don't load these fields for the returned query """
        return self._only_or_defer('defer', fields)

    #----writes----
    def save(self, instance):
        """
        Creates / updates a row.
        This is a blind insert call.
        All validation and cleaning needs to happen
        prior to calling this.

        Columns whose value is None are left out of the insert. Non primary
        key columns flagged as deleted on the instance are then removed
        with a separate DELETE statement.

        :raises QueryException: if the instance has no values to insert
        """
        assert type(instance) == self.model

        #organize data
        value_pairs = []
        values = instance.as_dict()

        #get defined fields and their column names
        for name, col in self.model._columns.items():
            val = values.get(name)
            if val is None: continue
            value_pairs.append((col.db_field_name, val))

        if not value_pairs:
            raise QueryException("Can't save an instance with no values set")

        #construct query string
        field_names = [name for name, _ in value_pairs]
        field_values = dict(value_pairs)
        qs = ["INSERT INTO {}".format(self.column_family_name)]
        qs += ["({})".format(', '.join(field_names))]
        qs += ['VALUES']
        qs += ["({})".format(', '.join([':'+f for f in field_names]))]
        qs = ' '.join(qs)

        with connection_manager() as con:
            con.execute(qs, field_values)

        #delete deleted / nulled columns
        deleted = [k for k, v in instance._values.items() if v.deleted]
        del_fields = [self.model._columns[f] for f in deleted]
        del_fields = [f.db_field_name for f in del_fields if not f.primary_key]
        # without this guard an empty column list would produce
        # "DELETE FROM ... WHERE <pk>", which removes the whole row
        if del_fields:
            pks = self.model._primary_keys
            pk_clauses = ['{0} = :{0}'.format(col.db_field_name) for col in pks.values()]
            qs = ['DELETE {}'.format(', '.join(del_fields))]
            qs += ['FROM {}'.format(self.column_family_name)]
            qs += ['WHERE']
            qs += [' AND '.join(pk_clauses)]
            qs = ' '.join(qs)

            pk_dict = dict([(v.db_field_name, getattr(instance, k)) for k, v in pks.items()])

            with connection_manager() as con:
                con.execute(qs, pk_dict)

    def create(self, **kwargs):
        """ Instantiates the model with kwargs and returns the result of its save() """
        return self.model(**kwargs).save()

    #----delete---
    def delete(self, columns=None):
        """
        Deletes the contents of a query

        :param columns: currently ignored
        """
        qs = ['DELETE FROM {}'.format(self.column_family_name)]
        if self._where:
            qs += ['WHERE {}'.format(self._where_clause())]
        qs = ' '.join(qs)

        with connection_manager() as con:
            con.execute(qs, self._where_values())

    def delete_instance(self, instance):
        """ Deletes one instance """
        pk_name = self.model._pk_name
        qs = ['DELETE FROM {}'.format(self.column_family_name)]
        qs += ['WHERE {0}=:{0}'.format(pk_name)]
        qs = ' '.join(qs)

        with connection_manager() as con:
            con.execute(qs, {pk_name:instance.pk})
|
|
|
|
|