diff --git a/cqlengine/query.py b/cqlengine/query.py index cd5d7cec..dee495c5 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -340,10 +340,14 @@ class AbstractQuerySet(object): """ returns the fields to select """ return [] + def _validate_select(self): + """ put select query validation here """ + def _select_query(self): """ Returns a select clause based on the given filter args """ + self._validate_select() return SelectStatement( self.column_family_name, fields=self._select_fields(), @@ -627,7 +631,9 @@ class AbstractQuerySet(object): execute(qs, self._where_values()) def __eq__(self, q): - return set(self._where) == set(q._where) + if len(self._where) == len(q._where): + return all([w in q._where for w in self._where]) + return False def __ne__(self, q): return not (self != q) @@ -667,30 +673,25 @@ class ModelQuerySet(AbstractQuerySet): """ """ - def _validate_where_syntax(self): + def _validate_select(self): """ Checks that a filterset will not create invalid cql """ - #check that there's either a = or IN relationship with a primary key or indexed field - equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)] - token_ops = [w for w in self._where if isinstance(w.value, Token)] - if not any([w.column.primary_key or w.column.index for w in equal_ops]) and not token_ops: + equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)] + token_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, Token)] + if not any([w.primary_key or w.index for w in equal_ops]) and not token_ops: raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field') if not self._allow_filtering: #if the query is not on an indexed field - if not any([w.column.index for w in equal_ops]): - if not any([w.column.partition_key for w in equal_ops]) and not token_ops: + if not any([w.index for w in equal_ops]): + if not any([w.partition_key for w in equal_ops]) and not token_ops: raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset') - if any(not w.column.partition_key for w in token_ops): + if any(not w.partition_key for w in token_ops): raise QueryException('The token() function is only supported on the partition key') - def _where_clause(self): - """ Returns a where clause based on the given filter args """ - self._validate_where_syntax() - return super(ModelQuerySet, self)._where_clause() - def _get_select_statement(self): """ Returns the fields to be returned by the select query """ + self._validate_select() if self._defer_fields or self._only_fields: fields = self.model._columns.keys() if self._defer_fields: diff --git a/cqlengine/statements.py b/cqlengine/statements.py index 281701a9..80ec6a5a 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -50,6 +50,9 @@ class BaseClause(object): def __str__(self): return unicode(self).encode('utf-8') + def __hash__(self): + return hash(self.field) ^ hash(self.value) + def __eq__(self, other): if isinstance(other, self.__class__): return self.field == other.field and self.value == other.value @@ -86,6 +89,9 @@ class WhereClause(BaseClause): def __unicode__(self): return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id) + def __hash__(self): + return super(WhereClause, self).__hash__() ^ hash(self.operator) + def __eq__(self, other): if super(WhereClause, self).__eq__(other): return self.operator.__class__ == other.operator.__class__ diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py index 8017c157..97e60335 100644 --- a/cqlengine/tests/query/test_queryset.py +++ b/cqlengine/tests/query/test_queryset.py @@ -491,12 +491,12 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage): assert q._cur is None - class TimeUUIDQueryModel(Model): partition = columns.UUID(primary_key=True) time = columns.TimeUUID(primary_key=True) data = columns.Text(required=False) + class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase): @classmethod