Implements the ability to blind update add to a set.
This commit is contained in:
@@ -663,24 +663,34 @@ class ModelQuerySet(AbstractQuerySet):
|
|||||||
nulled_columns = set()
|
nulled_columns = set()
|
||||||
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp)
|
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp)
|
||||||
for name, val in values.items():
|
for name, val in values.items():
|
||||||
col = self.model._columns.get(name)
|
col_name, col_op = self._parse_filter_arg(name)
|
||||||
|
col = self.model._columns.get(col_name)
|
||||||
# check for nonexistant columns
|
# check for nonexistant columns
|
||||||
if col is None:
|
if col is None:
|
||||||
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, name))
|
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, col_name))
|
||||||
# check for primary key update attempts
|
# check for primary key update attempts
|
||||||
if col.is_primary_key:
|
if col.is_primary_key:
|
||||||
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(name, self.__module__, self.model.__name__))
|
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(col_name, self.__module__, self.model.__name__))
|
||||||
|
|
||||||
val = col.validate(val)
|
val = col.validate(val)
|
||||||
if val is None:
|
if val is None:
|
||||||
nulled_columns.add(name)
|
nulled_columns.add(col_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# add the update statements
|
# add the update statements
|
||||||
if isinstance(col, Counter):
|
if isinstance(col, Counter):
|
||||||
# TODO: implement counter updates
|
# TODO: implement counter updates
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
elif isinstance(col, BaseContainerColumn):
|
||||||
|
if isinstance(col, List): klass = ListUpdateClause
|
||||||
|
elif isinstance(col, Map): klass = MapUpdateClause
|
||||||
|
elif isinstance(col, Set): klass = SetUpdateClause
|
||||||
|
else: raise RuntimeError
|
||||||
|
us.add_assignment_clause(klass(
|
||||||
|
col_name, col.to_database(val), operation=col_op))
|
||||||
else:
|
else:
|
||||||
us.add_assignment_clause(AssignmentClause(name, col.to_database(val)))
|
us.add_assignment_clause(AssignmentClause(
|
||||||
|
col_name, col.to_database(val)))
|
||||||
|
|
||||||
if us.assignments:
|
if us.assignments:
|
||||||
self._execute(us)
|
self._execute(us)
|
||||||
|
|||||||
@@ -136,10 +136,11 @@ class AssignmentClause(BaseClause):
|
|||||||
|
|
||||||
class ContainerUpdateClause(AssignmentClause):
|
class ContainerUpdateClause(AssignmentClause):
|
||||||
|
|
||||||
def __init__(self, field, value, previous=None, column=None):
|
def __init__(self, field, value, operation=None, previous=None, column=None):
|
||||||
super(ContainerUpdateClause, self).__init__(field, value)
|
super(ContainerUpdateClause, self).__init__(field, value)
|
||||||
self.previous = previous
|
self.previous = previous
|
||||||
self._assignments = None
|
self._assignments = None
|
||||||
|
self._operation = operation
|
||||||
self._analyzed = False
|
self._analyzed = False
|
||||||
self._column = column
|
self._column = column
|
||||||
|
|
||||||
@@ -159,8 +160,8 @@ class ContainerUpdateClause(AssignmentClause):
|
|||||||
class SetUpdateClause(ContainerUpdateClause):
|
class SetUpdateClause(ContainerUpdateClause):
|
||||||
""" updates a set collection """
|
""" updates a set collection """
|
||||||
|
|
||||||
def __init__(self, field, value, previous=None, column=None):
|
def __init__(self, field, value, operation=None, previous=None, column=None):
|
||||||
super(SetUpdateClause, self).__init__(field, value, previous, column=column)
|
super(SetUpdateClause, self).__init__(field, value, operation, previous, column=column)
|
||||||
self._additions = None
|
self._additions = None
|
||||||
self._removals = None
|
self._removals = None
|
||||||
|
|
||||||
@@ -182,6 +183,8 @@ class SetUpdateClause(ContainerUpdateClause):
|
|||||||
""" works out the updates to be performed """
|
""" works out the updates to be performed """
|
||||||
if self.value is None or self.value == self.previous:
|
if self.value is None or self.value == self.previous:
|
||||||
pass
|
pass
|
||||||
|
elif self._operation == "add":
|
||||||
|
self._additions = self.value
|
||||||
elif self.previous is None:
|
elif self.previous is None:
|
||||||
self._assignments = self.value
|
self._assignments = self.value
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class TestQueryUpdateModel(Model):
|
|||||||
cluster = columns.Integer(primary_key=True)
|
cluster = columns.Integer(primary_key=True)
|
||||||
count = columns.Integer(required=False)
|
count = columns.Integer(required=False)
|
||||||
text = columns.Text(required=False, index=True)
|
text = columns.Text(required=False, index=True)
|
||||||
|
text_set = columns.Set(columns.Text, required=False)
|
||||||
|
|
||||||
class QueryUpdateTests(BaseCassEngTestCase):
|
class QueryUpdateTests(BaseCassEngTestCase):
|
||||||
|
|
||||||
@@ -115,3 +115,13 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
|||||||
|
|
||||||
def test_counter_updates(self):
|
def test_counter_updates(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_set_add_updates(self):
|
||||||
|
partition = uuid4()
|
||||||
|
cluster = 1
|
||||||
|
TestQueryUpdateModel.objects.create(
|
||||||
|
partition=partition, cluster=cluster, text_set={"foo"})
|
||||||
|
TestQueryUpdateModel.objects(
|
||||||
|
partition=partition, cluster=cluster).update(text_set__add={'bar'})
|
||||||
|
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
|
||||||
|
self.assertEqual(obj.text_set, {"foo", "bar"})
|
||||||
|
|||||||
Reference in New Issue
Block a user