Adds a map __merge update mode, to merge in map with the existing values
on the server.
This commit is contained in:
@@ -681,9 +681,10 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
if isinstance(col, Counter):
|
||||
# TODO: implement counter updates
|
||||
raise NotImplementedError
|
||||
elif isinstance(col, (List, Set)):
|
||||
elif isinstance(col, (List, Set, Map)):
|
||||
if isinstance(col, List): klass = ListUpdateClause
|
||||
elif isinstance(col, Set): klass = SetUpdateClause
|
||||
elif isinstance(col, Map): klass = MapUpdateClause
|
||||
else: raise RuntimeError
|
||||
us.add_assignment_clause(klass(
|
||||
col_name, col.to_database(val), operation=col_op))
|
||||
|
||||
@@ -310,13 +310,16 @@ class ListUpdateClause(ContainerUpdateClause):
|
||||
class MapUpdateClause(ContainerUpdateClause):
|
||||
""" updates a map collection """
|
||||
|
||||
def __init__(self, field, value, previous=None, column=None):
|
||||
super(MapUpdateClause, self).__init__(field, value, previous, column=column)
|
||||
def __init__(self, field, value, operation=None, previous=None, column=None):
|
||||
super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column)
|
||||
self._updates = None
|
||||
self.previous = self.previous or {}
|
||||
|
||||
def _analyze(self):
|
||||
self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
|
||||
if self._operation == "merge":
|
||||
self._updates = self.value.value
|
||||
else:
|
||||
self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
|
||||
self._analyzed = True
|
||||
|
||||
def get_context_size(self):
|
||||
@@ -326,8 +329,7 @@ class MapUpdateClause(ContainerUpdateClause):
|
||||
def update_context(self, ctx):
|
||||
if not self._analyzed: self._analyze()
|
||||
ctx_id = self.context_id
|
||||
for key in self._updates or []:
|
||||
val = self.value.get(key)
|
||||
for key, val in self._updates.items():
|
||||
ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key
|
||||
ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if self._column else val
|
||||
ctx_id += 2
|
||||
|
||||
@@ -15,6 +15,7 @@ class TestQueryUpdateModel(Model):
|
||||
text = columns.Text(required=False, index=True)
|
||||
text_set = columns.Set(columns.Text, required=False)
|
||||
text_list = columns.List(columns.Text, required=False)
|
||||
text_map = columns.Map(columns.Text, columns.Text, required=False)
|
||||
|
||||
class QueryUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
@@ -183,3 +184,16 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
||||
text_list__prepend=['bar', 'baz'])
|
||||
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
|
||||
self.assertEqual(obj.text_list, ["bar", "baz", "foo"])
|
||||
|
||||
def test_map_merge_updates(self):
|
||||
""" Merge a dictionary into existing value """
|
||||
partition = uuid4()
|
||||
cluster = 1
|
||||
TestQueryUpdateModel.objects.create(
|
||||
partition=partition, cluster=cluster,
|
||||
text_map={"foo": '1', "bar": '2'})
|
||||
TestQueryUpdateModel.objects(
|
||||
partition=partition, cluster=cluster).update(
|
||||
text_map__merge={"bar": '3', "baz": '4'})
|
||||
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
|
||||
self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'})
|
||||
|
||||
Reference in New Issue
Block a user