Adds a map __merge update mode, to merge in map with the existing values
on the server.
This commit is contained in:
@@ -681,9 +681,10 @@ class ModelQuerySet(AbstractQuerySet):
|
|||||||
if isinstance(col, Counter):
|
if isinstance(col, Counter):
|
||||||
# TODO: implement counter updates
|
# TODO: implement counter updates
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
elif isinstance(col, (List, Set)):
|
elif isinstance(col, (List, Set, Map)):
|
||||||
if isinstance(col, List): klass = ListUpdateClause
|
if isinstance(col, List): klass = ListUpdateClause
|
||||||
elif isinstance(col, Set): klass = SetUpdateClause
|
elif isinstance(col, Set): klass = SetUpdateClause
|
||||||
|
elif isinstance(col, Map): klass = MapUpdateClause
|
||||||
else: raise RuntimeError
|
else: raise RuntimeError
|
||||||
us.add_assignment_clause(klass(
|
us.add_assignment_clause(klass(
|
||||||
col_name, col.to_database(val), operation=col_op))
|
col_name, col.to_database(val), operation=col_op))
|
||||||
|
|||||||
@@ -310,13 +310,16 @@ class ListUpdateClause(ContainerUpdateClause):
|
|||||||
class MapUpdateClause(ContainerUpdateClause):
|
class MapUpdateClause(ContainerUpdateClause):
|
||||||
""" updates a map collection """
|
""" updates a map collection """
|
||||||
|
|
||||||
def __init__(self, field, value, previous=None, column=None):
|
def __init__(self, field, value, operation=None, previous=None, column=None):
|
||||||
super(MapUpdateClause, self).__init__(field, value, previous, column=column)
|
super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column)
|
||||||
self._updates = None
|
self._updates = None
|
||||||
self.previous = self.previous or {}
|
self.previous = self.previous or {}
|
||||||
|
|
||||||
def _analyze(self):
|
def _analyze(self):
|
||||||
self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
|
if self._operation == "merge":
|
||||||
|
self._updates = self.value.value
|
||||||
|
else:
|
||||||
|
self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
|
||||||
self._analyzed = True
|
self._analyzed = True
|
||||||
|
|
||||||
def get_context_size(self):
|
def get_context_size(self):
|
||||||
@@ -326,8 +329,7 @@ class MapUpdateClause(ContainerUpdateClause):
|
|||||||
def update_context(self, ctx):
|
def update_context(self, ctx):
|
||||||
if not self._analyzed: self._analyze()
|
if not self._analyzed: self._analyze()
|
||||||
ctx_id = self.context_id
|
ctx_id = self.context_id
|
||||||
for key in self._updates or []:
|
for key, val in self._updates.items():
|
||||||
val = self.value.get(key)
|
|
||||||
ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key
|
ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key
|
||||||
ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if self._column else val
|
ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if self._column else val
|
||||||
ctx_id += 2
|
ctx_id += 2
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class TestQueryUpdateModel(Model):
|
|||||||
text = columns.Text(required=False, index=True)
|
text = columns.Text(required=False, index=True)
|
||||||
text_set = columns.Set(columns.Text, required=False)
|
text_set = columns.Set(columns.Text, required=False)
|
||||||
text_list = columns.List(columns.Text, required=False)
|
text_list = columns.List(columns.Text, required=False)
|
||||||
|
text_map = columns.Map(columns.Text, columns.Text, required=False)
|
||||||
|
|
||||||
class QueryUpdateTests(BaseCassEngTestCase):
|
class QueryUpdateTests(BaseCassEngTestCase):
|
||||||
|
|
||||||
@@ -183,3 +184,16 @@ class QueryUpdateTests(BaseCassEngTestCase):
|
|||||||
text_list__prepend=['bar', 'baz'])
|
text_list__prepend=['bar', 'baz'])
|
||||||
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
|
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
|
||||||
self.assertEqual(obj.text_list, ["bar", "baz", "foo"])
|
self.assertEqual(obj.text_list, ["bar", "baz", "foo"])
|
||||||
|
|
||||||
|
def test_map_merge_updates(self):
|
||||||
|
""" Merge a dictionary into existing value """
|
||||||
|
partition = uuid4()
|
||||||
|
cluster = 1
|
||||||
|
TestQueryUpdateModel.objects.create(
|
||||||
|
partition=partition, cluster=cluster,
|
||||||
|
text_map={"foo": '1', "bar": '2'})
|
||||||
|
TestQueryUpdateModel.objects(
|
||||||
|
partition=partition, cluster=cluster).update(
|
||||||
|
text_map__merge={"bar": '3', "baz": '4'})
|
||||||
|
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
|
||||||
|
self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'})
|
||||||
|
|||||||
Reference in New Issue
Block a user