adding support for partial updates for set columns

This commit is contained in:
Blake Eggleston
2013-03-10 14:11:16 -07:00
parent 66ca79f777
commit af50ced2df
4 changed files with 102 additions and 6 deletions

View File

@@ -29,6 +29,9 @@ class BaseValueManager(object):
"""
return self.value != self.previous_value
def reset_previous_value(self):
self.previous_value = copy(self.value)
def getval(self):
return self.value
@@ -283,9 +286,6 @@ class Counter(Column):
super(Counter, self).__init__(**kwargs)
raise NotImplementedError
class ContainerValueManager(BaseValueManager):
pass
class ContainerQuoter(object):
"""
contains a single value, which will quote itself for CQL insertion statements
@@ -323,6 +323,12 @@ class BaseContainerColumn(Column):
db_type = self.db_type.format(self.value_type.db_type)
return '{} {}'.format(self.db_field_name, db_type)
def get_update_statement(self, val, prev, ctx):
"""
Used to add partial update statements
"""
raise NotImplementedError
class Set(BaseContainerColumn):
"""
Stores a set of unordered, unique values
@@ -359,8 +365,50 @@ class Set(BaseContainerColumn):
return {self.value_col.validate(v) for v in val}
def to_database(self, value):
if value is None: return None
return self.Quoter({self.value_col.to_database(v) for v in value})
def get_update_statement(self, val, prev, ctx):
"""
Returns statements that will be added to an object's update statement
also updates the query context
:param val: the current column value
:param prev: the previous column value
:param ctx: the values that will be passed to the query
:rtype: list
"""
# remove from Quoter containers, if applicable
if isinstance(val, self.Quoter): val = val.value
if isinstance(prev, self.Quoter): prev = prev.value
if val is None or val == prev:
# don't return anything if the new value is the same as
# the old one, or if the new value is none
return []
elif prev is None or not any({v in prev for v in val}):
field = uuid1().hex
ctx[field] = self.Quoter(val)
return ['"{}" = :{}'.format(self.db_field_name, field)]
else:
# partial update time
to_create = val - prev
to_delete = prev - val
statements = []
if to_create:
field_id = uuid1().hex
ctx[field_id] = self.Quoter(to_create)
statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)]
if to_delete:
field_id = uuid1().hex
ctx[field_id] = self.Quoter(to_delete)
statements += ['"{0}" = "{0}" - :{1}'.format(self.db_field_name, field_id)]
return statements
class List(BaseContainerColumn):
"""
Stores a list of ordered values
@@ -383,8 +431,15 @@ class List(BaseContainerColumn):
return [self.value_col.validate(v) for v in val]
def to_database(self, value):
if value is None: return None
return self.Quoter([self.value_col.to_database(v) for v in value])
def get_update_statement(self, val, prev, values):
"""
http://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string_search_algorithm
"""
pass
class Map(BaseContainerColumn):
"""
Stores a key -> value map (dictionary)
@@ -438,6 +493,13 @@ class Map(BaseContainerColumn):
return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()}
def to_database(self, value):
if value is None: return None
return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()})
def get_update_statement(self, val, prev, ctx):
"""
http://www.datastax.com/docs/1.2/cql_cli/using/collections_map#deletion
"""
pass

View File

@@ -136,8 +136,7 @@ class BaseModel(object):
#reset the value managers
for v in self._values.values():
v.previous_value = v.value
v.reset_previous_value()
self._is_persisted = True
return self

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from hashlib import md5
from time import time
from uuid import uuid1
from cqlengine import BaseContainerColumn
from cqlengine.connection import connection_manager
from cqlengine.exceptions import CQLEngineException
@@ -641,7 +642,15 @@ class DMLQuery(object):
if not col.is_primary_key:
val = values.get(name)
if val is None: continue
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
if isinstance(col, BaseContainerColumn):
#remove value from query values, the column will handle it
query_values.pop(field_ids.get(name), None)
val_mgr = self.instance._values[name]
set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values)
pass
else:
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
qs += [', '.join(set_statements)]
qs += ['WHERE']

View File

@@ -45,6 +45,32 @@ class TestSetColumn(BaseCassEngTestCase):
with self.assertRaises(ValidationError):
TestSetModel.create(int_set={'string', True}, text_set={1, 3.0})
def test_partial_updates(self):
""" Tests that partial udpates work as expected """
m1 = TestSetModel.create(int_set={1,2,3,4})
m1.int_set.add(5)
m1.int_set.remove(1)
assert m1.int_set == {2,3,4,5}
m1.save()
m2 = TestSetModel.get(partition=m1.partition)
assert m2.int_set == {2,3,4,5}
def test_partial_update_creation(self):
"""
Tests that proper update statements are created for a partial set update
:return:
"""
ctx = {}
col = columns.Set(columns.Integer, db_field="TEST")
statements = col.get_update_statement({1,2,3,4}, {2,3,4,5}, ctx)
assert len([v for v in ctx.values() if {1} == v.value]) == 1
assert len([v for v in ctx.values() if {5} == v.value]) == 1
assert len([s for s in statements if '"TEST" = "TEST" -' in s]) == 1
assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1
class TestListModel(Model):
partition = columns.UUID(primary_key=True, default=uuid4)