adding equality and hash methods
This commit is contained in:
@@ -340,10 +340,14 @@ class AbstractQuerySet(object):
|
|||||||
""" returns the fields to select """
|
""" returns the fields to select """
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _validate_select(self):
|
||||||
|
""" put select query validation here """
|
||||||
|
|
||||||
def _select_query(self):
|
def _select_query(self):
|
||||||
"""
|
"""
|
||||||
Returns a select clause based on the given filter args
|
Returns a select clause based on the given filter args
|
||||||
"""
|
"""
|
||||||
|
self._validate_select()
|
||||||
return SelectStatement(
|
return SelectStatement(
|
||||||
self.column_family_name,
|
self.column_family_name,
|
||||||
fields=self._select_fields(),
|
fields=self._select_fields(),
|
||||||
@@ -627,7 +631,9 @@ class AbstractQuerySet(object):
|
|||||||
execute(qs, self._where_values())
|
execute(qs, self._where_values())
|
||||||
|
|
||||||
def __eq__(self, q):
|
def __eq__(self, q):
|
||||||
return set(self._where) == set(q._where)
|
if len(self._where) == len(q._where):
|
||||||
|
return all([w in q._where for w in self._where])
|
||||||
|
return False
|
||||||
|
|
||||||
def __ne__(self, q):
|
def __ne__(self, q):
|
||||||
return not (self != q)
|
return not (self != q)
|
||||||
@@ -667,30 +673,25 @@ class ModelQuerySet(AbstractQuerySet):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def _validate_where_syntax(self):
|
def _validate_select(self):
|
||||||
""" Checks that a filterset will not create invalid cql """
|
""" Checks that a filterset will not create invalid cql """
|
||||||
|
|
||||||
#check that there's either a = or IN relationship with a primary key or indexed field
|
#check that there's either a = or IN relationship with a primary key or indexed field
|
||||||
equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)]
|
equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)]
|
||||||
token_ops = [w for w in self._where if isinstance(w.value, Token)]
|
token_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, Token)]
|
||||||
if not any([w.column.primary_key or w.column.index for w in equal_ops]) and not token_ops:
|
if not any([w.primary_key or w.index for w in equal_ops]) and not token_ops:
|
||||||
raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
|
raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field')
|
||||||
|
|
||||||
if not self._allow_filtering:
|
if not self._allow_filtering:
|
||||||
#if the query is not on an indexed field
|
#if the query is not on an indexed field
|
||||||
if not any([w.column.index for w in equal_ops]):
|
if not any([w.index for w in equal_ops]):
|
||||||
if not any([w.column.partition_key for w in equal_ops]) and not token_ops:
|
if not any([w.partition_key for w in equal_ops]) and not token_ops:
|
||||||
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
|
raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset')
|
||||||
if any(not w.column.partition_key for w in token_ops):
|
if any(not w.partition_key for w in token_ops):
|
||||||
raise QueryException('The token() function is only supported on the partition key')
|
raise QueryException('The token() function is only supported on the partition key')
|
||||||
|
|
||||||
def _where_clause(self):
|
|
||||||
""" Returns a where clause based on the given filter args """
|
|
||||||
self._validate_where_syntax()
|
|
||||||
return super(ModelQuerySet, self)._where_clause()
|
|
||||||
|
|
||||||
def _get_select_statement(self):
|
def _get_select_statement(self):
|
||||||
""" Returns the fields to be returned by the select query """
|
""" Returns the fields to be returned by the select query """
|
||||||
|
self._validate_select()
|
||||||
if self._defer_fields or self._only_fields:
|
if self._defer_fields or self._only_fields:
|
||||||
fields = self.model._columns.keys()
|
fields = self.model._columns.keys()
|
||||||
if self._defer_fields:
|
if self._defer_fields:
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ class BaseClause(object):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return unicode(self).encode('utf-8')
|
return unicode(self).encode('utf-8')
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.field) ^ hash(self.value)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
return self.field == other.field and self.value == other.value
|
return self.field == other.field and self.value == other.value
|
||||||
@@ -86,6 +89,9 @@ class WhereClause(BaseClause):
|
|||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id)
|
return u'"{}" {} :{}'.format(self.field, self.operator, self.context_id)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return super(WhereClause, self).__hash__() ^ hash(self.operator)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if super(WhereClause, self).__eq__(other):
|
if super(WhereClause, self).__eq__(other):
|
||||||
return self.operator.__class__ == other.operator.__class__
|
return self.operator.__class__ == other.operator.__class__
|
||||||
|
|||||||
@@ -491,12 +491,12 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
|
|||||||
assert q._cur is None
|
assert q._cur is None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TimeUUIDQueryModel(Model):
|
class TimeUUIDQueryModel(Model):
|
||||||
partition = columns.UUID(primary_key=True)
|
partition = columns.UUID(primary_key=True)
|
||||||
time = columns.TimeUUID(primary_key=True)
|
time = columns.TimeUUID(primary_key=True)
|
||||||
data = columns.Text(required=False)
|
data = columns.Text(required=False)
|
||||||
|
|
||||||
|
|
||||||
class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
|
class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user