Implements the ability to blind update add to a set.

This commit is contained in:
Danny Cosson
2014-02-09 14:39:26 -05:00
parent 79fc15ecd8
commit 95d99bf359
3 changed files with 32 additions and 9 deletions

View File

@@ -663,24 +663,34 @@ class ModelQuerySet(AbstractQuerySet):
nulled_columns = set() nulled_columns = set()
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp) us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp)
for name, val in values.items(): for name, val in values.items():
col = self.model._columns.get(name) col_name, col_op = self._parse_filter_arg(name)
col = self.model._columns.get(col_name)
# check for nonexistant columns # check for nonexistant columns
if col is None: if col is None:
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, name)) raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, col_name))
# check for primary key update attempts # check for primary key update attempts
if col.is_primary_key: if col.is_primary_key:
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(name, self.__module__, self.model.__name__)) raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(col_name, self.__module__, self.model.__name__))
val = col.validate(val) val = col.validate(val)
if val is None: if val is None:
nulled_columns.add(name) nulled_columns.add(col_name)
continue continue
# add the update statements # add the update statements
if isinstance(col, Counter): if isinstance(col, Counter):
# TODO: implement counter updates # TODO: implement counter updates
raise NotImplementedError raise NotImplementedError
elif isinstance(col, BaseContainerColumn):
if isinstance(col, List): klass = ListUpdateClause
elif isinstance(col, Map): klass = MapUpdateClause
elif isinstance(col, Set): klass = SetUpdateClause
else: raise RuntimeError
us.add_assignment_clause(klass(
col_name, col.to_database(val), operation=col_op))
else: else:
us.add_assignment_clause(AssignmentClause(name, col.to_database(val))) us.add_assignment_clause(AssignmentClause(
col_name, col.to_database(val)))
if us.assignments: if us.assignments:
self._execute(us) self._execute(us)

View File

@@ -136,10 +136,11 @@ class AssignmentClause(BaseClause):
class ContainerUpdateClause(AssignmentClause): class ContainerUpdateClause(AssignmentClause):
def __init__(self, field, value, previous=None, column=None): def __init__(self, field, value, operation=None, previous=None, column=None):
super(ContainerUpdateClause, self).__init__(field, value) super(ContainerUpdateClause, self).__init__(field, value)
self.previous = previous self.previous = previous
self._assignments = None self._assignments = None
self._operation = operation
self._analyzed = False self._analyzed = False
self._column = column self._column = column
@@ -159,8 +160,8 @@ class ContainerUpdateClause(AssignmentClause):
class SetUpdateClause(ContainerUpdateClause): class SetUpdateClause(ContainerUpdateClause):
""" updates a set collection """ """ updates a set collection """
def __init__(self, field, value, previous=None, column=None): def __init__(self, field, value, operation=None, previous=None, column=None):
super(SetUpdateClause, self).__init__(field, value, previous, column=column) super(SetUpdateClause, self).__init__(field, value, operation, previous, column=column)
self._additions = None self._additions = None
self._removals = None self._removals = None
@@ -182,6 +183,8 @@ class SetUpdateClause(ContainerUpdateClause):
""" works out the updates to be performed """ """ works out the updates to be performed """
if self.value is None or self.value == self.previous: if self.value is None or self.value == self.previous:
pass pass
elif self._operation == "add":
self._additions = self.value
elif self.previous is None: elif self.previous is None:
self._assignments = self.value self._assignments = self.value
else: else:

View File

@@ -13,7 +13,7 @@ class TestQueryUpdateModel(Model):
cluster = columns.Integer(primary_key=True) cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False) count = columns.Integer(required=False)
text = columns.Text(required=False, index=True) text = columns.Text(required=False, index=True)
text_set = columns.Set(columns.Text, required=False)
class QueryUpdateTests(BaseCassEngTestCase): class QueryUpdateTests(BaseCassEngTestCase):
@@ -115,3 +115,13 @@ class QueryUpdateTests(BaseCassEngTestCase):
def test_counter_updates(self): def test_counter_updates(self):
pass pass
def test_set_add_updates(self):
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster, text_set={"foo"})
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(text_set__add={'bar'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_set, {"foo", "bar"})