From af50ced2df9ddd3d0dbcfc41910d79080a65dff0 Mon Sep 17 00:00:00 2001 From: Blake Eggleston Date: Sun, 10 Mar 2013 14:11:16 -0700 Subject: [PATCH] adding support for partial updates for set columns --- cqlengine/columns.py | 68 ++++++++++++++++++- cqlengine/models.py | 3 +- cqlengine/query.py | 11 ++- .../tests/columns/test_container_columns.py | 26 +++++++ 4 files changed, 102 insertions(+), 6 deletions(-) diff --git a/cqlengine/columns.py b/cqlengine/columns.py index 67d9212a..84b245da 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -29,6 +29,9 @@ class BaseValueManager(object): """ return self.value != self.previous_value + def reset_previous_value(self): + self.previous_value = copy(self.value) + def getval(self): return self.value @@ -283,9 +286,6 @@ class Counter(Column): super(Counter, self).__init__(**kwargs) raise NotImplementedError -class ContainerValueManager(BaseValueManager): - pass - class ContainerQuoter(object): """ contains a single value, which will quote itself for CQL insertion statements @@ -323,6 +323,12 @@ class BaseContainerColumn(Column): db_type = self.db_type.format(self.value_type.db_type) return '{} {}'.format(self.db_field_name, db_type) + def get_update_statement(self, val, prev, ctx): + """ + Used to add partial update statements + """ + raise NotImplementedError + class Set(BaseContainerColumn): """ Stores a set of unordered, unique values @@ -359,8 +365,50 @@ class Set(BaseContainerColumn): return {self.value_col.validate(v) for v in val} def to_database(self, value): + if value is None: return None return self.Quoter({self.value_col.to_database(v) for v in value}) + def get_update_statement(self, val, prev, ctx): + """ + Returns statements that will be added to an object's update statement + also updates the query context + + :param val: the current column value + :param prev: the previous column value + :param ctx: the values that will be passed to the query + :rtype: list + """ + + # remove from Quoter containers, if applicable + if isinstance(val, self.Quoter): val = val.value + if isinstance(prev, self.Quoter): prev = prev.value + + if val is None or val == prev: + # don't return anything if the new value is the same as + # the old one, or if the new value is none + return [] + elif prev is None or not any({v in prev for v in val}): + field = uuid1().hex + ctx[field] = self.Quoter(val) + return ['"{}" = :{}'.format(self.db_field_name, field)] + else: + # partial update time + to_create = val - prev + to_delete = prev - val + statements = [] + + if to_create: + field_id = uuid1().hex + ctx[field_id] = self.Quoter(to_create) + statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)] + + if to_delete: + field_id = uuid1().hex + ctx[field_id] = self.Quoter(to_delete) + statements += ['"{0}" = "{0}" - :{1}'.format(self.db_field_name, field_id)] + + return statements + class List(BaseContainerColumn): """ Stores a list of ordered values @@ -383,8 +431,15 @@ class List(BaseContainerColumn): return [self.value_col.validate(v) for v in val] def to_database(self, value): + if value is None: return None return self.Quoter([self.value_col.to_database(v) for v in value]) + def get_update_statement(self, val, prev, values): + """ + http://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string_search_algorithm + """ + pass + class Map(BaseContainerColumn): """ Stores a key -> value map (dictionary) @@ -438,6 +493,13 @@ class Map(BaseContainerColumn): return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()} def to_database(self, value): + if value is None: return None return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()}) + def get_update_statement(self, val, prev, ctx): + """ + http://www.datastax.com/docs/1.2/cql_cli/using/collections_map#deletion + """ + pass + diff --git a/cqlengine/models.py b/cqlengine/models.py index da4e18fc..9b6ec3b0 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -136,8 +136,7 @@ class BaseModel(object): #reset the value managers for v in self._values.values(): - v.previous_value = v.value - + v.reset_previous_value() self._is_persisted = True return self diff --git a/cqlengine/query.py b/cqlengine/query.py index 0b605399..fd91e45a 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -4,6 +4,7 @@ from datetime import datetime from hashlib import md5 from time import time from uuid import uuid1 +from cqlengine import BaseContainerColumn from cqlengine.connection import connection_manager from cqlengine.exceptions import CQLEngineException @@ -641,7 +642,15 @@ class DMLQuery(object): if not col.is_primary_key: val = values.get(name) if val is None: continue - set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])] + if isinstance(col, BaseContainerColumn): + #remove value from query values, the column will handle it + query_values.pop(field_ids.get(name), None) + + val_mgr = self.instance._values[name] + set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values) + pass + else: + set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])] qs += [', '.join(set_statements)] qs += ['WHERE'] diff --git a/cqlengine/tests/columns/test_container_columns.py b/cqlengine/tests/columns/test_container_columns.py index 4a1721e0..15f89192 100644 --- a/cqlengine/tests/columns/test_container_columns.py +++ b/cqlengine/tests/columns/test_container_columns.py @@ -45,6 +45,32 @@ class TestSetColumn(BaseCassEngTestCase): with self.assertRaises(ValidationError): TestSetModel.create(int_set={'string', True}, text_set={1, 3.0}) + def test_partial_updates(self): + """ Tests that partial udpates work as expected """ + m1 = TestSetModel.create(int_set={1,2,3,4}) + + m1.int_set.add(5) + m1.int_set.remove(1) + assert m1.int_set == {2,3,4,5} + + m1.save() + + m2 = TestSetModel.get(partition=m1.partition) + assert m2.int_set == {2,3,4,5} + + def test_partial_update_creation(self): + """ + Tests that proper update statements are created for a partial set update + :return: + """ + ctx = {} + col = columns.Set(columns.Integer, db_field="TEST") + statements = col.get_update_statement({1,2,3,4}, {2,3,4,5}, ctx) + + assert len([v for v in ctx.values() if {1} == v.value]) == 1 + assert len([v for v in ctx.values() if {5} == v.value]) == 1 + assert len([s for s in statements if '"TEST" = "TEST" -' in s]) == 1 + assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1 class TestListModel(Model): partition = columns.UUID(primary_key=True, default=uuid4)