Adds a map __merge update mode, to merge in map with the existing values

on the server.
This commit is contained in:
Danny Cosson
2014-02-10 11:28:03 -05:00
parent e11f53e370
commit 9dc55ec54b
3 changed files with 23 additions and 6 deletions

View File

@@ -681,9 +681,10 @@ class ModelQuerySet(AbstractQuerySet):
if isinstance(col, Counter):
# TODO: implement counter updates
raise NotImplementedError
elif isinstance(col, (List, Set)):
elif isinstance(col, (List, Set, Map)):
if isinstance(col, List): klass = ListUpdateClause
elif isinstance(col, Set): klass = SetUpdateClause
elif isinstance(col, Map): klass = MapUpdateClause
else: raise RuntimeError
us.add_assignment_clause(klass(
col_name, col.to_database(val), operation=col_op))

View File

@@ -310,13 +310,16 @@ class ListUpdateClause(ContainerUpdateClause):
class MapUpdateClause(ContainerUpdateClause):
""" updates a map collection """
def __init__(self, field, value, previous=None, column=None):
super(MapUpdateClause, self).__init__(field, value, previous, column=column)
def __init__(self, field, value, operation=None, previous=None, column=None):
super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column)
self._updates = None
self.previous = self.previous or {}
def _analyze(self):
self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
if self._operation == "merge":
self._updates = self.value.value
else:
self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
self._analyzed = True
def get_context_size(self):
@@ -326,8 +329,7 @@ class MapUpdateClause(ContainerUpdateClause):
def update_context(self, ctx):
if not self._analyzed: self._analyze()
ctx_id = self.context_id
for key in self._updates or []:
val = self.value.get(key)
for key, val in self._updates.items():
ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key
ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if self._column else val
ctx_id += 2

View File

@@ -15,6 +15,7 @@ class TestQueryUpdateModel(Model):
text = columns.Text(required=False, index=True)
text_set = columns.Set(columns.Text, required=False)
text_list = columns.List(columns.Text, required=False)
text_map = columns.Map(columns.Text, columns.Text, required=False)
class QueryUpdateTests(BaseCassEngTestCase):
@@ -183,3 +184,16 @@ class QueryUpdateTests(BaseCassEngTestCase):
text_list__prepend=['bar', 'baz'])
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_list, ["bar", "baz", "foo"])
def test_map_merge_updates(self):
""" Merge a dictionary into existing value """
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster,
text_map={"foo": '1', "bar": '2'})
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(
text_map__merge={"bar": '3', "baz": '4'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'})