adding support for partial updates for set columns
This commit is contained in:
@@ -29,6 +29,9 @@ class BaseValueManager(object):
|
||||
"""
|
||||
return self.value != self.previous_value
|
||||
|
||||
def reset_previous_value(self):
|
||||
self.previous_value = copy(self.value)
|
||||
|
||||
def getval(self):
|
||||
return self.value
|
||||
|
||||
@@ -283,9 +286,6 @@ class Counter(Column):
|
||||
super(Counter, self).__init__(**kwargs)
|
||||
raise NotImplementedError
|
||||
|
||||
class ContainerValueManager(BaseValueManager):
|
||||
pass
|
||||
|
||||
class ContainerQuoter(object):
|
||||
"""
|
||||
contains a single value, which will quote itself for CQL insertion statements
|
||||
@@ -323,6 +323,12 @@ class BaseContainerColumn(Column):
|
||||
db_type = self.db_type.format(self.value_type.db_type)
|
||||
return '{} {}'.format(self.db_field_name, db_type)
|
||||
|
||||
def get_update_statement(self, val, prev, ctx):
|
||||
"""
|
||||
Used to add partial update statements
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class Set(BaseContainerColumn):
|
||||
"""
|
||||
Stores a set of unordered, unique values
|
||||
@@ -359,8 +365,50 @@ class Set(BaseContainerColumn):
|
||||
return {self.value_col.validate(v) for v in val}
|
||||
|
||||
def to_database(self, value):
|
||||
if value is None: return None
|
||||
return self.Quoter({self.value_col.to_database(v) for v in value})
|
||||
|
||||
def get_update_statement(self, val, prev, ctx):
|
||||
"""
|
||||
Returns statements that will be added to an object's update statement
|
||||
also updates the query context
|
||||
|
||||
:param val: the current column value
|
||||
:param prev: the previous column value
|
||||
:param ctx: the values that will be passed to the query
|
||||
:rtype: list
|
||||
"""
|
||||
|
||||
# remove from Quoter containers, if applicable
|
||||
if isinstance(val, self.Quoter): val = val.value
|
||||
if isinstance(prev, self.Quoter): prev = prev.value
|
||||
|
||||
if val is None or val == prev:
|
||||
# don't return anything if the new value is the same as
|
||||
# the old one, or if the new value is none
|
||||
return []
|
||||
elif prev is None or not any({v in prev for v in val}):
|
||||
field = uuid1().hex
|
||||
ctx[field] = self.Quoter(val)
|
||||
return ['"{}" = :{}'.format(self.db_field_name, field)]
|
||||
else:
|
||||
# partial update time
|
||||
to_create = val - prev
|
||||
to_delete = prev - val
|
||||
statements = []
|
||||
|
||||
if to_create:
|
||||
field_id = uuid1().hex
|
||||
ctx[field_id] = self.Quoter(to_create)
|
||||
statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)]
|
||||
|
||||
if to_delete:
|
||||
field_id = uuid1().hex
|
||||
ctx[field_id] = self.Quoter(to_delete)
|
||||
statements += ['"{0}" = "{0}" - :{1}'.format(self.db_field_name, field_id)]
|
||||
|
||||
return statements
|
||||
|
||||
class List(BaseContainerColumn):
|
||||
"""
|
||||
Stores a list of ordered values
|
||||
@@ -383,8 +431,15 @@ class List(BaseContainerColumn):
|
||||
return [self.value_col.validate(v) for v in val]
|
||||
|
||||
def to_database(self, value):
|
||||
if value is None: return None
|
||||
return self.Quoter([self.value_col.to_database(v) for v in value])
|
||||
|
||||
def get_update_statement(self, val, prev, values):
|
||||
"""
|
||||
http://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string_search_algorithm
|
||||
"""
|
||||
pass
|
||||
|
||||
class Map(BaseContainerColumn):
|
||||
"""
|
||||
Stores a key -> value map (dictionary)
|
||||
@@ -438,6 +493,13 @@ class Map(BaseContainerColumn):
|
||||
return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()}
|
||||
|
||||
def to_database(self, value):
|
||||
if value is None: return None
|
||||
return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()})
|
||||
|
||||
def get_update_statement(self, val, prev, ctx):
|
||||
"""
|
||||
http://www.datastax.com/docs/1.2/cql_cli/using/collections_map#deletion
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -136,8 +136,7 @@ class BaseModel(object):
|
||||
|
||||
#reset the value managers
|
||||
for v in self._values.values():
|
||||
v.previous_value = v.value
|
||||
|
||||
v.reset_previous_value()
|
||||
self._is_persisted = True
|
||||
|
||||
return self
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime
|
||||
from hashlib import md5
|
||||
from time import time
|
||||
from uuid import uuid1
|
||||
from cqlengine import BaseContainerColumn
|
||||
|
||||
from cqlengine.connection import connection_manager
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
@@ -641,7 +642,15 @@ class DMLQuery(object):
|
||||
if not col.is_primary_key:
|
||||
val = values.get(name)
|
||||
if val is None: continue
|
||||
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
|
||||
if isinstance(col, BaseContainerColumn):
|
||||
#remove value from query values, the column will handle it
|
||||
query_values.pop(field_ids.get(name), None)
|
||||
|
||||
val_mgr = self.instance._values[name]
|
||||
set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values)
|
||||
pass
|
||||
else:
|
||||
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
|
||||
qs += [', '.join(set_statements)]
|
||||
|
||||
qs += ['WHERE']
|
||||
|
||||
@@ -45,6 +45,32 @@ class TestSetColumn(BaseCassEngTestCase):
|
||||
with self.assertRaises(ValidationError):
|
||||
TestSetModel.create(int_set={'string', True}, text_set={1, 3.0})
|
||||
|
||||
def test_partial_updates(self):
|
||||
""" Tests that partial udpates work as expected """
|
||||
m1 = TestSetModel.create(int_set={1,2,3,4})
|
||||
|
||||
m1.int_set.add(5)
|
||||
m1.int_set.remove(1)
|
||||
assert m1.int_set == {2,3,4,5}
|
||||
|
||||
m1.save()
|
||||
|
||||
m2 = TestSetModel.get(partition=m1.partition)
|
||||
assert m2.int_set == {2,3,4,5}
|
||||
|
||||
def test_partial_update_creation(self):
|
||||
"""
|
||||
Tests that proper update statements are created for a partial set update
|
||||
:return:
|
||||
"""
|
||||
ctx = {}
|
||||
col = columns.Set(columns.Integer, db_field="TEST")
|
||||
statements = col.get_update_statement({1,2,3,4}, {2,3,4,5}, ctx)
|
||||
|
||||
assert len([v for v in ctx.values() if {1} == v.value]) == 1
|
||||
assert len([v for v in ctx.values() if {5} == v.value]) == 1
|
||||
assert len([s for s in statements if '"TEST" = "TEST" -' in s]) == 1
|
||||
assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1
|
||||
|
||||
class TestListModel(Model):
|
||||
partition = columns.UUID(primary_key=True, default=uuid4)
|
||||
|
||||
Reference in New Issue
Block a user