1) fix save/update/delete for static column changes only when clustering key is null, and 2) allow null clustering key for create() when setting static column values only
This commit is contained in:
@@ -26,6 +26,7 @@ class BaseValueManager(object):
|
||||
self.column = column
|
||||
self.previous_value = deepcopy(value)
|
||||
self.value = value
|
||||
self.explicit = False
|
||||
|
||||
@property
|
||||
def deleted(self):
|
||||
@@ -140,9 +141,7 @@ class Column(object):
|
||||
if there's a problem
|
||||
"""
|
||||
if value is None:
|
||||
if self.has_default:
|
||||
return self.get_default()
|
||||
elif self.required:
|
||||
if self.required:
|
||||
raise ValidationError('{} - None values are not allowed'.format(self.column_name or self.db_field))
|
||||
return value
|
||||
|
||||
|
@@ -336,6 +336,8 @@ class BaseModel(object):
|
||||
if value is not None or isinstance(column, columns.BaseContainerColumn):
|
||||
value = column.to_python(value)
|
||||
value_mngr = column.value_manager(self, column, value)
|
||||
if name in values:
|
||||
value_mngr.explicit = True
|
||||
self._values[name] = value_mngr
|
||||
|
||||
# a flag set by the deserializer to indicate
|
||||
@@ -490,7 +492,10 @@ class BaseModel(object):
|
||||
def validate(self):
|
||||
""" Cleans and validates the field values """
|
||||
for name, col in self._columns.items():
|
||||
val = col.validate(getattr(self, name))
|
||||
v = getattr(self, name)
|
||||
if v is None and not self._values[name].explicit and col.has_default:
|
||||
v = col.get_default()
|
||||
val = col.validate(v)
|
||||
setattr(self, name, val)
|
||||
|
||||
### Let an instance be used like a dict of its columns keys/values
|
||||
|
@@ -806,7 +806,9 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
if col.is_primary_key:
|
||||
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(col_name, self.__module__, self.model.__name__))
|
||||
|
||||
# we should not provide default values in this use case.
|
||||
val = col.validate(val)
|
||||
|
||||
if val is None:
|
||||
nulled_columns.add(col_name)
|
||||
continue
|
||||
@@ -914,11 +916,17 @@ class DMLQuery(object):
|
||||
if self.instance is None:
|
||||
raise CQLEngineException("DML Query intance attribute is None")
|
||||
assert type(self.instance) == self.model
|
||||
static_update_only = True
|
||||
null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True
|
||||
static_changed_only = True
|
||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
|
||||
timestamp=self._timestamp, transactions=self._transaction)
|
||||
for name, col in self.instance._clustering_keys.items():
|
||||
null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
# if clustering key is null, don't include non static columns
|
||||
if null_clustering_key and not col.static and not col.partition_key:
|
||||
continue
|
||||
if not col.is_primary_key:
|
||||
val = getattr(self.instance, name, None)
|
||||
val_mgr = self.instance._values[name]
|
||||
@@ -931,7 +939,7 @@ class DMLQuery(object):
|
||||
if not val_mgr.changed and not isinstance(col, Counter):
|
||||
continue
|
||||
|
||||
static_update_only = (static_update_only and col.static)
|
||||
static_changed_only = static_changed_only and col.static
|
||||
if isinstance(col, (BaseContainerColumn, Counter)):
|
||||
# get appropriate clause
|
||||
if isinstance(col, List): klass = ListUpdateClause
|
||||
@@ -953,7 +961,8 @@ class DMLQuery(object):
|
||||
|
||||
if statement.get_context_size() > 0 or self.instance._has_counter:
|
||||
for name, col in self.model._primary_keys.items():
|
||||
if static_update_only and (not col.partition_key):
|
||||
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
|
||||
if (null_clustering_key or static_changed_only) and (not col.partition_key):
|
||||
continue
|
||||
statement.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
@@ -962,7 +971,8 @@ class DMLQuery(object):
|
||||
))
|
||||
self._execute(statement)
|
||||
|
||||
self._delete_null_columns()
|
||||
if not null_clustering_key:
|
||||
self._delete_null_columns()
|
||||
|
||||
def save(self):
|
||||
"""
|
||||
@@ -980,7 +990,12 @@ class DMLQuery(object):
|
||||
return self.update()
|
||||
else:
|
||||
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
|
||||
static_save_only = False if len(self.instance._clustering_keys) == 0 else True
|
||||
for name, col in self.instance._clustering_keys.items():
|
||||
static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None))
|
||||
for name, col in self.instance._columns.items():
|
||||
if static_save_only and not col.static and not col.partition_key:
|
||||
continue
|
||||
val = getattr(self.instance, name, None)
|
||||
if col._val_is_null(val):
|
||||
if self.instance._values[name].changed:
|
||||
@@ -996,7 +1011,8 @@ class DMLQuery(object):
|
||||
if not insert.is_empty:
|
||||
self._execute(insert)
|
||||
# delete any nulled columns
|
||||
self._delete_null_columns()
|
||||
if not static_save_only:
|
||||
self._delete_null_columns()
|
||||
|
||||
def delete(self):
|
||||
""" Deletes one instance """
|
||||
@@ -1005,6 +1021,7 @@ class DMLQuery(object):
|
||||
|
||||
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp)
|
||||
for name, col in self.model._primary_keys.items():
|
||||
if (not col.partition_key) and (getattr(self.instance, name) is None): continue
|
||||
ds.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
EqualsOperator(),
|
||||
|
@@ -529,4 +529,4 @@ class TestCamelMapColumn(BaseCassEngTestCase):
|
||||
drop_table(TestCamelMapModel)
|
||||
|
||||
def test_camelcase_column(self):
|
||||
TestCamelMapModel.create(partition=None, camelMap={'blah': 1})
|
||||
TestCamelMapModel.create(camelMap={'blah': 1})
|
||||
|
@@ -1,3 +1,6 @@
|
||||
#import sys, nose
|
||||
#sys.path.insert(0, '/Users/andy/projects/cqlengine')
|
||||
|
||||
from uuid import uuid4
|
||||
from unittest import skipUnless
|
||||
from cqlengine import Model
|
||||
@@ -9,9 +12,8 @@ from cqlengine.tests.base import CASSANDRA_VERSION, PROTOCOL_VERSION
|
||||
|
||||
|
||||
class TestStaticModel(Model):
|
||||
|
||||
partition = columns.UUID(primary_key=True, default=uuid4)
|
||||
cluster = columns.UUID(primary_key=True, default=uuid4)
|
||||
cluster = columns.UUID(primary_key=True)
|
||||
static = columns.Text(static=True)
|
||||
text = columns.Text()
|
||||
|
||||
@@ -33,7 +35,7 @@ class TestStaticColumn(BaseCassEngTestCase):
|
||||
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
|
||||
def test_mixed_updates(self):
|
||||
""" Tests that updates on both static and non-static columns work as intended """
|
||||
instance = TestStaticModel.create()
|
||||
instance = TestStaticModel.create(cluster=uuid4())
|
||||
instance.static = "it's shared"
|
||||
instance.text = "some text"
|
||||
instance.save()
|
||||
@@ -49,7 +51,7 @@ class TestStaticColumn(BaseCassEngTestCase):
|
||||
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
|
||||
def test_static_only_updates(self):
|
||||
""" Tests that updates on static only column work as intended """
|
||||
instance = TestStaticModel.create()
|
||||
instance = TestStaticModel.create(cluster=uuid4())
|
||||
instance.static = "it's shared"
|
||||
instance.text = "some text"
|
||||
instance.save()
|
||||
@@ -60,3 +62,17 @@ class TestStaticColumn(BaseCassEngTestCase):
|
||||
actual = TestStaticModel.get(partition=u.partition)
|
||||
assert actual.static == "it's still shared"
|
||||
|
||||
@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
|
||||
def test_static_with_null_cluster_key(self):
|
||||
""" Tests that save/update/delete works for static column works when clustering key is null"""
|
||||
instance = TestStaticModel.create(cluster=None, static = "it's shared")
|
||||
instance.save()
|
||||
|
||||
u = TestStaticModel.get(partition=instance.partition)
|
||||
u.static = "it's still shared"
|
||||
u.update()
|
||||
actual = TestStaticModel.get(partition=u.partition)
|
||||
assert actual.static == "it's still shared"
|
||||
|
||||
#if __name__ == '__main__':
|
||||
# nose.main()
|
Reference in New Issue
Block a user