From 9dc55ec54bef7c218ad2d9036228a2555b2b6aa5 Mon Sep 17 00:00:00 2001 From: Danny Cosson Date: Mon, 10 Feb 2014 11:28:03 -0500 Subject: [PATCH] Adds a map __merge update mode, to merge in map with the existing values on the server. --- cqlengine/query.py | 3 ++- cqlengine/statements.py | 12 +++++++----- cqlengine/tests/query/test_updates.py | 14 ++++++++++++++ 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/cqlengine/query.py b/cqlengine/query.py index 5f32a87c..ce703655 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -681,9 +681,10 @@ class ModelQuerySet(AbstractQuerySet): if isinstance(col, Counter): # TODO: implement counter updates raise NotImplementedError - elif isinstance(col, (List, Set)): + elif isinstance(col, (List, Set, Map)): if isinstance(col, List): klass = ListUpdateClause elif isinstance(col, Set): klass = SetUpdateClause + elif isinstance(col, Map): klass = MapUpdateClause else: raise RuntimeError us.add_assignment_clause(klass( col_name, col.to_database(val), operation=col_op)) diff --git a/cqlengine/statements.py b/cqlengine/statements.py index d53e0cd0..5ae1d213 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -310,13 +310,16 @@ class ListUpdateClause(ContainerUpdateClause): class MapUpdateClause(ContainerUpdateClause): """ updates a map collection """ - def __init__(self, field, value, previous=None, column=None): - super(MapUpdateClause, self).__init__(field, value, previous, column=column) + def __init__(self, field, value, operation=None, previous=None, column=None): + super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column) self._updates = None self.previous = self.previous or {} def _analyze(self): - self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None + if self._operation == "merge": + self._updates = self.value.value + else: + self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None self._analyzed = True def get_context_size(self): @@ -326,8 +329,7 @@ class MapUpdateClause(ContainerUpdateClause): def update_context(self, ctx): if not self._analyzed: self._analyze() ctx_id = self.context_id - for key in self._updates or []: - val = self.value.get(key) + for key, val in self._updates.items(): ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if self._column else val ctx_id += 2 diff --git a/cqlengine/tests/query/test_updates.py b/cqlengine/tests/query/test_updates.py index 1541f699..723a8e87 100644 --- a/cqlengine/tests/query/test_updates.py +++ b/cqlengine/tests/query/test_updates.py @@ -15,6 +15,7 @@ class TestQueryUpdateModel(Model): text = columns.Text(required=False, index=True) text_set = columns.Set(columns.Text, required=False) text_list = columns.List(columns.Text, required=False) + text_map = columns.Map(columns.Text, columns.Text, required=False) class QueryUpdateTests(BaseCassEngTestCase): @@ -183,3 +184,16 @@ class QueryUpdateTests(BaseCassEngTestCase): text_list__prepend=['bar', 'baz']) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_list, ["bar", "baz", "foo"]) + + def test_map_merge_updates(self): + """ Merge a dictionary into existing value """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, + text_map={"foo": '1', "bar": '2'}) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_map__merge={"bar": '3', "baz": '4'}) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'})