adding support for partial updates for set columns

This commit is contained in:
Blake Eggleston
2013-03-10 14:11:16 -07:00
parent 66ca79f777
commit af50ced2df
4 changed files with 102 additions and 6 deletions

View File

@@ -29,6 +29,9 @@ class BaseValueManager(object):
""" """
return self.value != self.previous_value return self.value != self.previous_value
def reset_previous_value(self):
self.previous_value = copy(self.value)
def getval(self): def getval(self):
return self.value return self.value
@@ -283,9 +286,6 @@ class Counter(Column):
super(Counter, self).__init__(**kwargs) super(Counter, self).__init__(**kwargs)
raise NotImplementedError raise NotImplementedError
class ContainerValueManager(BaseValueManager):
pass
class ContainerQuoter(object): class ContainerQuoter(object):
""" """
contains a single value, which will quote itself for CQL insertion statements contains a single value, which will quote itself for CQL insertion statements
@@ -323,6 +323,12 @@ class BaseContainerColumn(Column):
db_type = self.db_type.format(self.value_type.db_type) db_type = self.db_type.format(self.value_type.db_type)
return '{} {}'.format(self.db_field_name, db_type) return '{} {}'.format(self.db_field_name, db_type)
def get_update_statement(self, val, prev, ctx):
"""
Used to add partial update statements
"""
raise NotImplementedError
class Set(BaseContainerColumn): class Set(BaseContainerColumn):
""" """
Stores a set of unordered, unique values Stores a set of unordered, unique values
@@ -359,8 +365,50 @@ class Set(BaseContainerColumn):
return {self.value_col.validate(v) for v in val} return {self.value_col.validate(v) for v in val}
def to_database(self, value): def to_database(self, value):
if value is None: return None
return self.Quoter({self.value_col.to_database(v) for v in value}) return self.Quoter({self.value_col.to_database(v) for v in value})
def get_update_statement(self, val, prev, ctx):
"""
Returns statements that will be added to an object's update statement
also updates the query context
:param val: the current column value
:param prev: the previous column value
:param ctx: the values that will be passed to the query
:rtype: list
"""
# remove from Quoter containers, if applicable
if isinstance(val, self.Quoter): val = val.value
if isinstance(prev, self.Quoter): prev = prev.value
if val is None or val == prev:
# don't return anything if the new value is the same as
# the old one, or if the new value is none
return []
elif prev is None or not any({v in prev for v in val}):
field = uuid1().hex
ctx[field] = self.Quoter(val)
return ['"{}" = :{}'.format(self.db_field_name, field)]
else:
# partial update time
to_create = val - prev
to_delete = prev - val
statements = []
if to_create:
field_id = uuid1().hex
ctx[field_id] = self.Quoter(to_create)
statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)]
if to_delete:
field_id = uuid1().hex
ctx[field_id] = self.Quoter(to_delete)
statements += ['"{0}" = "{0}" - :{1}'.format(self.db_field_name, field_id)]
return statements
class List(BaseContainerColumn): class List(BaseContainerColumn):
""" """
Stores a list of ordered values Stores a list of ordered values
@@ -383,8 +431,15 @@ class List(BaseContainerColumn):
return [self.value_col.validate(v) for v in val] return [self.value_col.validate(v) for v in val]
def to_database(self, value): def to_database(self, value):
if value is None: return None
return self.Quoter([self.value_col.to_database(v) for v in value]) return self.Quoter([self.value_col.to_database(v) for v in value])
def get_update_statement(self, val, prev, values):
"""
http://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string_search_algorithm
"""
pass
class Map(BaseContainerColumn): class Map(BaseContainerColumn):
""" """
Stores a key -> value map (dictionary) Stores a key -> value map (dictionary)
@@ -438,6 +493,13 @@ class Map(BaseContainerColumn):
return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()} return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()}
def to_database(self, value): def to_database(self, value):
if value is None: return None
return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()}) return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()})
def get_update_statement(self, val, prev, ctx):
"""
http://www.datastax.com/docs/1.2/cql_cli/using/collections_map#deletion
"""
pass

View File

@@ -136,8 +136,7 @@ class BaseModel(object):
#reset the value managers #reset the value managers
for v in self._values.values(): for v in self._values.values():
v.previous_value = v.value v.reset_previous_value()
self._is_persisted = True self._is_persisted = True
return self return self

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from hashlib import md5 from hashlib import md5
from time import time from time import time
from uuid import uuid1 from uuid import uuid1
from cqlengine import BaseContainerColumn
from cqlengine.connection import connection_manager from cqlengine.connection import connection_manager
from cqlengine.exceptions import CQLEngineException from cqlengine.exceptions import CQLEngineException
@@ -641,6 +642,14 @@ class DMLQuery(object):
if not col.is_primary_key: if not col.is_primary_key:
val = values.get(name) val = values.get(name)
if val is None: continue if val is None: continue
if isinstance(col, BaseContainerColumn):
#remove value from query values, the column will handle it
query_values.pop(field_ids.get(name), None)
val_mgr = self.instance._values[name]
set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values)
pass
else:
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])] set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
qs += [', '.join(set_statements)] qs += [', '.join(set_statements)]

View File

@@ -45,6 +45,32 @@ class TestSetColumn(BaseCassEngTestCase):
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
TestSetModel.create(int_set={'string', True}, text_set={1, 3.0}) TestSetModel.create(int_set={'string', True}, text_set={1, 3.0})
def test_partial_updates(self):
""" Tests that partial udpates work as expected """
m1 = TestSetModel.create(int_set={1,2,3,4})
m1.int_set.add(5)
m1.int_set.remove(1)
assert m1.int_set == {2,3,4,5}
m1.save()
m2 = TestSetModel.get(partition=m1.partition)
assert m2.int_set == {2,3,4,5}
def test_partial_update_creation(self):
"""
Tests that proper update statements are created for a partial set update
:return:
"""
ctx = {}
col = columns.Set(columns.Integer, db_field="TEST")
statements = col.get_update_statement({1,2,3,4}, {2,3,4,5}, ctx)
assert len([v for v in ctx.values() if {1} == v.value]) == 1
assert len([v for v in ctx.values() if {5} == v.value]) == 1
assert len([s for s in statements if '"TEST" = "TEST" -' in s]) == 1
assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1
class TestListModel(Model): class TestListModel(Model):
partition = columns.UUID(primary_key=True, default=uuid4) partition = columns.UUID(primary_key=True, default=uuid4)