adding equality and hash methods
This commit is contained in:
@@ -340,10 +340,14 @@ class AbstractQuerySet(object):
|
||||
""" returns the fields to select """
|
||||
return []
|
||||
|
||||
def _validate_select(self):
|
||||
""" put select query validation here """
|
||||
|
||||
def _select_query(self):
|
||||
"""
|
||||
Returns a select clause based on the given filter args
|
||||
"""
|
||||
self._validate_select()
|
||||
return SelectStatement(
|
||||
self.column_family_name,
|
||||
fields=self._select_fields(),
|
||||
@@ -627,7 +631,9 @@ class AbstractQuerySet(object):
|
||||
execute(qs, self._where_values())
|
||||
|
||||
def __eq__(self, q):
|
||||
return set(self._where) == set(q._where)
|
||||
if len(self._where) == len(q._where):
|
||||
return all([w in q._where for w in self._where])
|
||||
return False
|
||||
|
||||
def __ne__(self, q):
|
||||
return not (self != q)
|
||||
@@ -667,30 +673,25 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
"""
|
||||
|
||||
"""
|
||||
def _validate_where_syntax(self):
|
||||
def _validate_select(self):
|
||||
""" Checks that a filterset will not create invalid cql """
|
||||
|
||||
#check that there's either a = or IN relationship with a primary key or indexed field
|
||||
equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)]
|
||||
token_ops = [w for w in self._where if isinstance(w.value, Token)]
|
||||
if not any([w.column.primary_key or w.column.index for w in equal_ops]) and not token_ops:
|
||||
equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
|
||||
token_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, Token)]
|
||||
if not any([w.primary_key or w.index for w in equal_ops]) and not token_ops:
|
||||
raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
|
||||
|
||||
if not self._allow_filtering:
|
||||
#if the query is not on an indexed field
|
||||
if not any([w.column.index for w in equal_ops]):
|
||||
if not any([w.column.partition_key for w in equal_ops]) and not token_ops:
|
||||
if not any([w.index for w in equal_ops]):
|
||||
if not any([w.partition_key for w in equal_ops]) and not token_ops:
|
||||
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
|
||||
if any(not w.column.partition_key for w in token_ops):
|
||||
if any(not w.partition_key for w in token_ops):
|
||||
raise QueryException('The token() function is only supported on the partition key')
|
||||
|
||||
def _where_clause(self):
|
||||
""" Returns a where clause based on the given filter args """
|
||||
self._validate_where_syntax()
|
||||
return super(ModelQuerySet, self)._where_clause()
|
||||
|
||||
def _get_select_statement(self):
|
||||
""" Returns the fields to be returned by the select query """
|
||||
self._validate_select()
|
||||
if self._defer_fields or self._only_fields:
|
||||
fields = self.model._columns.keys()
|
||||
if self._defer_fields:
|
||||
|
@@ -50,6 +50,9 @@ class BaseClause(object):
|
||||
def __str__(self):
|
||||
return unicode(self).encode('utf-8')
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.field) ^ hash(self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, self.__class__):
|
||||
return self.field == other.field and self.value == other.value
|
||||
@@ -86,6 +89,9 @@ class WhereClause(BaseClause):
|
||||
def __unicode__(self):
|
||||
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id)
|
||||
|
||||
def __hash__(self):
|
||||
return super(WhereClause, self).__hash__() ^ hash(self.operator)
|
||||
|
||||
def __eq__(self, other):
|
||||
if super(WhereClause, self).__eq__(other):
|
||||
return self.operator.__class__ == other.operator.__class__
|
||||
|
@@ -491,12 +491,12 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
|
||||
assert q._cur is None
|
||||
|
||||
|
||||
|
||||
class TimeUUIDQueryModel(Model):
|
||||
partition = columns.UUID(primary_key=True)
|
||||
time = columns.TimeUUID(primary_key=True)
|
||||
data = columns.Text(required=False)
|
||||
|
||||
|
||||
class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
|
Reference in New Issue
Block a user