From 50d349f26eabb9e14e7b23641931063d3f3b5009 Mon Sep 17 00:00:00 2001 From: Alan Boudreault Date: Thu, 14 Jul 2016 16:26:58 -0400 Subject: [PATCH] Fix pk__token equality filter --- cassandra/cqlengine/query.py | 5 +++-- tests/integration/cqlengine/query/test_queryoperators.py | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 02101f6e..c3d7506a 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -545,7 +545,7 @@ class AbstractQuerySet(object): if len(statement) == 1: return arg, None elif len(statement) == 2: - return statement[0], statement[1] + return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None) else: raise QueryException("Can't parse '{0}'".format(arg)) @@ -954,7 +954,8 @@ class ModelQuerySet(AbstractQuerySet): def _validate_select_where(self): """ Checks that a filterset will not create invalid select statement """ # check that there's either a =, a IN or a CONTAINS (collection) relationship with a primary key or indexed field - equal_ops = [self.model._get_column_by_db_name(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)] + equal_ops = [self.model._get_column_by_db_name(w.field) \ + for w in self._where if isinstance(w.operator, EqualsOperator) and not isinstance(w.value, Token)] token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) if not any(w.primary_key or w.index for w in equal_ops) and not token_comparison and not self._allow_filtering: raise QueryException(('Where clauses require either =, a IN or a CONTAINS (collection) ' diff --git a/tests/integration/cqlengine/query/test_queryoperators.py b/tests/integration/cqlengine/query/test_queryoperators.py index c2a2a742..055e8f3d 100644 --- a/tests/integration/cqlengine/query/test_queryoperators.py +++ b/tests/integration/cqlengine/query/test_queryoperators.py @@ -72,7 +72,7 @@ class TestTokenFunction(BaseCassEngTestCase): super(TestTokenFunction, self).tearDown() drop_table(TokenTestModel) - @execute_count(14) + @execute_count(15) def test_token_function(self): """ Tests that token functions work properly """ assert TokenTestModel.objects().count() == 0 @@ -91,6 +91,10 @@ class TestTokenFunction(BaseCassEngTestCase): assert len(seen_keys) == 10 assert all([i in seen_keys for i in range(10)]) + # pk__token equality + r = TokenTestModel.objects(pk__token=functions.Token(last_token)) + self.assertEqual(len(r), 1) + def test_compound_pk_token_function(self): class TestModel(Model):