Implements the ability to blind update add to a set.

This commit is contained in:
Danny Cosson
2014-02-09 14:39:26 -05:00
parent 79fc15ecd8
commit 95d99bf359
3 changed files with 32 additions and 9 deletions

View File

@@ -663,24 +663,34 @@ class ModelQuerySet(AbstractQuerySet):
nulled_columns = set()
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp)
for name, val in values.items():
col = self.model._columns.get(name)
col_name, col_op = self._parse_filter_arg(name)
col = self.model._columns.get(col_name)
# check for nonexistant columns
if col is None:
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, name))
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, col_name))
# check for primary key update attempts
if col.is_primary_key:
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(name, self.__module__, self.model.__name__))
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(col_name, self.__module__, self.model.__name__))
val = col.validate(val)
if val is None:
nulled_columns.add(name)
nulled_columns.add(col_name)
continue
# add the update statements
if isinstance(col, Counter):
# TODO: implement counter updates
raise NotImplementedError
elif isinstance(col, BaseContainerColumn):
if isinstance(col, List): klass = ListUpdateClause
elif isinstance(col, Map): klass = MapUpdateClause
elif isinstance(col, Set): klass = SetUpdateClause
else: raise RuntimeError
us.add_assignment_clause(klass(
col_name, col.to_database(val), operation=col_op))
else:
us.add_assignment_clause(AssignmentClause(name, col.to_database(val)))
us.add_assignment_clause(AssignmentClause(
col_name, col.to_database(val)))
if us.assignments:
self._execute(us)

View File

@@ -136,10 +136,11 @@ class AssignmentClause(BaseClause):
class ContainerUpdateClause(AssignmentClause):
def __init__(self, field, value, previous=None, column=None):
def __init__(self, field, value, operation=None, previous=None, column=None):
super(ContainerUpdateClause, self).__init__(field, value)
self.previous = previous
self._assignments = None
self._operation = operation
self._analyzed = False
self._column = column
@@ -159,8 +160,8 @@ class ContainerUpdateClause(AssignmentClause):
class SetUpdateClause(ContainerUpdateClause):
""" updates a set collection """
def __init__(self, field, value, previous=None, column=None):
super(SetUpdateClause, self).__init__(field, value, previous, column=column)
def __init__(self, field, value, operation=None, previous=None, column=None):
super(SetUpdateClause, self).__init__(field, value, operation, previous, column=column)
self._additions = None
self._removals = None
@@ -182,6 +183,8 @@ class SetUpdateClause(ContainerUpdateClause):
""" works out the updates to be performed """
if self.value is None or self.value == self.previous:
pass
elif self._operation == "add":
self._additions = self.value
elif self.previous is None:
self._assignments = self.value
else:

View File

@@ -13,7 +13,7 @@ class TestQueryUpdateModel(Model):
cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False)
text = columns.Text(required=False, index=True)
text_set = columns.Set(columns.Text, required=False)
class QueryUpdateTests(BaseCassEngTestCase):
@@ -115,3 +115,13 @@ class QueryUpdateTests(BaseCassEngTestCase):
def test_counter_updates(self):
pass
def test_set_add_updates(self):
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster, text_set={"foo"})
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(text_set__add={'bar'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_set, {"foo", "bar"})