adding equality and hash methods

This commit is contained in:
Blake Eggleston
2013-10-27 17:14:27 -07:00
parent 14008640fc
commit 0807b299dd
3 changed files with 22 additions and 15 deletions

View File

@@ -340,10 +340,14 @@ class AbstractQuerySet(object):
""" returns the fields to select """
return []
def _validate_select(self):
""" put select query validation here """
def _select_query(self):
"""
Returns a select clause based on the given filter args
"""
self._validate_select()
return SelectStatement(
self.column_family_name,
fields=self._select_fields(),
@@ -627,7 +631,9 @@ class AbstractQuerySet(object):
execute(qs, self._where_values())
def __eq__(self, q):
return set(self._where) == set(q._where)
if len(self._where) == len(q._where):
return all([w in q._where for w in self._where])
return False
def __ne__(self, q):
return not (self != q)
@@ -667,30 +673,25 @@ class ModelQuerySet(AbstractQuerySet):
"""
"""
def _validate_where_syntax(self):
def _validate_select(self):
""" Checks that a filterset will not create invalid cql """
#check that there's either a = or IN relationship with a primary key or indexed field
equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)]
token_ops = [w for w in self._where if isinstance(w.value, Token)]
if not any([w.column.primary_key or w.column.index for w in equal_ops]) and not token_ops:
equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
token_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, Token)]
if not any([w.primary_key or w.index for w in equal_ops]) and not token_ops:
raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
if not self._allow_filtering:
#if the query is not on an indexed field
if not any([w.column.index for w in equal_ops]):
if not any([w.column.partition_key for w in equal_ops]) and not token_ops:
if not any([w.index for w in equal_ops]):
if not any([w.partition_key for w in equal_ops]) and not token_ops:
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
if any(not w.column.partition_key for w in token_ops):
if any(not w.partition_key for w in token_ops):
raise QueryException('The token() function is only supported on the partition key')
def _where_clause(self):
""" Returns a where clause based on the given filter args """
self._validate_where_syntax()
return super(ModelQuerySet, self)._where_clause()
def _get_select_statement(self):
""" Returns the fields to be returned by the select query """
self._validate_select()
if self._defer_fields or self._only_fields:
fields = self.model._columns.keys()
if self._defer_fields:

View File

@@ -50,6 +50,9 @@ class BaseClause(object):
def __str__(self):
return unicode(self).encode('utf-8')
def __hash__(self):
return hash(self.field) ^ hash(self.value)
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.field == other.field and self.value == other.value
@@ -86,6 +89,9 @@ class WhereClause(BaseClause):
def __unicode__(self):
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id)
def __hash__(self):
return super(WhereClause, self).__hash__() ^ hash(self.operator)
def __eq__(self, other):
if super(WhereClause, self).__eq__(other):
return self.operator.__class__ == other.operator.__class__

View File

@@ -491,12 +491,12 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
assert q._cur is None
class TimeUUIDQueryModel(Model):
partition = columns.UUID(primary_key=True)
time = columns.TimeUUID(primary_key=True)
data = columns.Text(required=False)
class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
@classmethod