From 124d197a77c0b62749775f3e477c93a79ec23bec Mon Sep 17 00:00:00 2001 From: Eli Green Date: Thu, 2 Jun 2016 16:47:11 +0200 Subject: [PATCH] Avoid LWTExceptions when updating columns that are part of the condition --- cassandra/cqlengine/query.py | 14 +++++++++++++- .../integration/cqlengine/test_lwt_conditional.py | 15 +++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 05962b4e..25a60737 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -1151,6 +1151,7 @@ class ModelQuerySet(AbstractQuerySet): return nulled_columns = set() + updated_columns = set() us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, val in values.items(): @@ -1171,13 +1172,17 @@ class ModelQuerySet(AbstractQuerySet): continue us.add_update(col, val, operation=col_op) + updated_columns.add(col_name) if us.assignments: self._execute(us) + null_conditional = [condition for condition in self._conditional + if condition.field not in updated_columns] + if nulled_columns: ds = DeleteStatement(self.column_family_name, fields=nulled_columns, - where=self._where, conditionals=self._conditional, if_exists=self._if_exists) + where=self._where, conditionals=null_conditional, if_exists=self._if_exists) self._execute(ds) @@ -1262,6 +1267,8 @@ class DMLQuery(object): conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.instance._clustering_keys.items(): null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) + + updated_columns = set() # get defined fields and their column names for name, col in self.model._columns.items(): # if clustering key is null, don't include non static columns @@ -1279,6 +1286,7 @@ class DMLQuery(object): static_changed_only = static_changed_only and col.static statement.add_update(col, val, previous=val_mgr.previous_value) + updated_columns.add(col.db_field_name) if statement.assignments: for name, col in self.model._primary_keys.items(): @@ -1288,6 +1296,10 @@ class DMLQuery(object): statement.add_where(col, EqualsOperator(), getattr(self.instance, name)) self._execute(statement) + # remove conditions on fields that have been updated + self._conditional = [condition for condition in self._conditional + if condition.field not in updated_columns] + if not null_clustering_key: self._delete_null_columns() diff --git a/tests/integration/cqlengine/test_lwt_conditional.py b/tests/integration/cqlengine/test_lwt_conditional.py index d273df9c..8395154c 100644 --- a/tests/integration/cqlengine/test_lwt_conditional.py +++ b/tests/integration/cqlengine/test_lwt_conditional.py @@ -234,3 +234,18 @@ class TestConditional(BaseCassEngTestCase): self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None) self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + + def test_column_delete_after_update(self): + # DML path + t = TestConditionalModel.create(text='something', count=5) + t.iff(count=5).update(text=None, count=6) + + self.assertIsNone(t.text) + self.assertEqual(t.count, 6) + + # QuerySet path + t = TestConditionalModel.create(text='something', count=5) + TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None, count=6) + + self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + self.assertEqual(TestConditionalModel.objects(id=t.id).first().count, 6)