diff --git a/cqlengine/statements.py b/cqlengine/statements.py index a6cdf4b3..633472d0 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -173,15 +173,18 @@ class SetUpdateClause(ContainerUpdateClause): def __unicode__(self): qs = [] ctx_id = self.context_id - if self.previous is None and not (self._assignments or self._additions or self._removals): + if (self.previous is None and + self._assignments is None and + self._additions is None and + self._removals is None): qs += ['"{}" = %({})s'.format(self.field, ctx_id)] - if self._assignments: + if self._assignments is not None: qs += ['"{}" = %({})s'.format(self.field, ctx_id)] ctx_id += 1 - if self._additions: + if self._additions is not None: qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] ctx_id += 1 - if self._removals: + if self._removals is not None: qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] return ', '.join(qs) @@ -204,22 +207,28 @@ class SetUpdateClause(ContainerUpdateClause): def get_context_size(self): if not self._analyzed: self._analyze() - if self.previous is None and not (self._assignments or self._additions or self._removals): + if (self.previous is None and + self._assignments is None and + self._additions is None and + self._removals is None): return 1 return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals)) def update_context(self, ctx): if not self._analyzed: self._analyze() ctx_id = self.context_id - if self.previous is None and not (self._assignments or self._additions or self._removals): + if (self.previous is None and + self._assignments is None and + self._additions is None and + self._removals is None): ctx[str(ctx_id)] = self._to_database({}) - if self._assignments: + if self._assignments is not None: ctx[str(ctx_id)] = self._to_database(self._assignments) ctx_id += 1 - if self._additions: + if self._additions is not None: ctx[str(ctx_id)] = self._to_database(self._additions) ctx_id += 1 - if self._removals: + if self._removals is not None: ctx[str(ctx_id)] = self._to_database(self._removals) @@ -239,11 +248,11 @@ class ListUpdateClause(ContainerUpdateClause): qs += ['"{}" = %({})s'.format(self.field, ctx_id)] ctx_id += 1 - if self._prepend: + if self._prepend is not None: qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)] ctx_id += 1 - if self._append: + if self._append is not None: qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] return ', '.join(qs) @@ -258,13 +267,13 @@ class ListUpdateClause(ContainerUpdateClause): if self._assignments is not None: ctx[str(ctx_id)] = self._to_database(self._assignments) ctx_id += 1 - if self._prepend: + if self._prepend is not None: # CQL seems to prepend element at a time, starting # with the element at idx 0, we can either reverse # it here, or have it inserted in reverse ctx[str(ctx_id)] = self._to_database(list(reversed(self._prepend))) ctx_id += 1 - if self._append: + if self._append is not None: ctx[str(ctx_id)] = self._to_database(self._append) def _analyze(self): diff --git a/cqlengine/tests/columns/test_validation.py b/cqlengine/tests/columns/test_validation.py index 9c327e4e..d665237c 100644 --- a/cqlengine/tests/columns/test_validation.py +++ b/cqlengine/tests/columns/test_validation.py @@ -30,8 +30,6 @@ from cqlengine.columns import Inet from cqlengine.management import sync_table, drop_table from cqlengine.models import Model -import sys - class TestDatetime(BaseCassEngTestCase): class DatetimeTest(Model): diff --git a/cqlengine/tests/statements/test_update_statement.py b/cqlengine/tests/statements/test_update_statement.py index f75a9530..ab5aeec0 100644 --- a/cqlengine/tests/statements/test_update_statement.py +++ b/cqlengine/tests/statements/test_update_statement.py @@ -1,8 +1,12 @@ from unittest import TestCase -from cqlengine.statements import UpdateStatement, WhereClause, AssignmentClause +from cqlengine.columns import Set, List from cqlengine.operators import * +from cqlengine.statements import (UpdateStatement, WhereClause, + AssignmentClause, SetUpdateClause, + ListUpdateClause) import six + class UpdateStatementTests(TestCase): def test_table_rendering(self): @@ -40,3 +44,27 @@ class UpdateStatementTests(TestCase): us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) self.assertIn('USING TTL 60', six.text_type(us)) + def test_update_set_add(self): + us = UpdateStatement('table') + us.add_assignment_clause(SetUpdateClause('a', Set.Quoter({1}), operation='add')) + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') + + def test_update_empty_set_add_does_not_assign(self): + us = UpdateStatement('table') + us.add_assignment_clause(SetUpdateClause('a', Set.Quoter(set()), operation='add')) + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') + + def test_update_empty_set_removal_does_not_assign(self): + us = UpdateStatement('table') + us.add_assignment_clause(SetUpdateClause('a', Set.Quoter(set()), operation='remove')) + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s') + + def test_update_list_prepend_with_empty_list(self): + us = UpdateStatement('table') + us.add_assignment_clause(ListUpdateClause('a', List.Quoter([]), operation='prepend')) + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"') + + def test_update_list_append_with_empty_list(self): + us = UpdateStatement('table') + us.add_assignment_clause(ListUpdateClause('a', List.Quoter([]), operation='append')) + self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')