@@ -1,5 +1,11 @@
|
||||
CHANGELOG
|
||||
|
||||
0.9
|
||||
* adding update method
|
||||
* adding BigInt column (thanks @Lifto)
|
||||
* adding support for timezone aware time uuid functions (thanks @dokai)
|
||||
* only saving collection fields on insert if they've been modified
|
||||
|
||||
0.8.5
|
||||
* adding support for timeouts
|
||||
|
||||
|
||||
@@ -198,6 +198,10 @@ class Column(object):
|
||||
def get_cql(self):
|
||||
return '"{}"'.format(self.db_field_name)
|
||||
|
||||
def _val_is_null(self, val):
|
||||
""" determines if the given value equates to a null value for the given column type """
|
||||
return val is None
|
||||
|
||||
|
||||
class Bytes(Column):
|
||||
db_type = 'blob'
|
||||
@@ -523,6 +527,15 @@ class BaseContainerColumn(Column):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _val_is_null(self, val):
|
||||
return not val
|
||||
|
||||
|
||||
class BaseContainerQuoter(ValueQuoter):
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self.value)
|
||||
|
||||
|
||||
class Set(BaseContainerColumn):
|
||||
"""
|
||||
@@ -532,7 +545,7 @@ class Set(BaseContainerColumn):
|
||||
"""
|
||||
db_type = 'set<{}>'
|
||||
|
||||
class Quoter(ValueQuoter):
|
||||
class Quoter(BaseContainerQuoter):
|
||||
|
||||
def __str__(self):
|
||||
cq = cql_quote
|
||||
@@ -622,12 +635,15 @@ class List(BaseContainerColumn):
|
||||
"""
|
||||
db_type = 'list<{}>'
|
||||
|
||||
class Quoter(ValueQuoter):
|
||||
class Quoter(BaseContainerQuoter):
|
||||
|
||||
def __str__(self):
|
||||
cq = cql_quote
|
||||
return '[' + ', '.join([cq(v) for v in self.value]) + ']'
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __init__(self, value_type, default=set, **kwargs):
|
||||
return super(List, self).__init__(value_type=value_type, default=default, **kwargs)
|
||||
|
||||
@@ -733,7 +749,7 @@ class Map(BaseContainerColumn):
|
||||
|
||||
db_type = 'map<{}, {}>'
|
||||
|
||||
class Quoter(ValueQuoter):
|
||||
class Quoter(BaseContainerQuoter):
|
||||
|
||||
def __str__(self):
|
||||
cq = cql_quote
|
||||
|
||||
@@ -356,7 +356,6 @@ class BaseModel(object):
|
||||
return cls.objects.get(*args, **kwargs)
|
||||
|
||||
def save(self):
|
||||
|
||||
# handle polymorphic models
|
||||
if self._is_polymorphic:
|
||||
if self._is_polymorphic_base:
|
||||
@@ -375,6 +374,37 @@ class BaseModel(object):
|
||||
|
||||
return self
|
||||
|
||||
def update(self, **values):
|
||||
for k, v in values.items():
|
||||
col = self._columns.get(k)
|
||||
|
||||
# check for nonexistant columns
|
||||
if col is None:
|
||||
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.__class__.__name__, k))
|
||||
|
||||
# check for primary key update attempts
|
||||
if col.is_primary_key:
|
||||
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(k, self.__module__, self.__class__.__name__))
|
||||
|
||||
setattr(self, k, v)
|
||||
|
||||
# handle polymorphic models
|
||||
if self._is_polymorphic:
|
||||
if self._is_polymorphic_base:
|
||||
raise PolyMorphicModelException('cannot update polymorphic base model')
|
||||
else:
|
||||
setattr(self, self._polymorphic_column_name, self.__polymorphic_key__)
|
||||
|
||||
self.validate()
|
||||
self.__dmlquery__(self.__class__, self, batch=self._batch).update()
|
||||
|
||||
#reset the value managers
|
||||
for v in self._values.values():
|
||||
v.reset_previous_value()
|
||||
self._is_persisted = True
|
||||
|
||||
return self
|
||||
|
||||
def delete(self):
|
||||
""" Deletes this instance """
|
||||
self.__dmlquery__(self.__class__, self, batch=self._batch).delete()
|
||||
|
||||
@@ -9,7 +9,7 @@ from cqlengine.columns import Counter
|
||||
|
||||
from cqlengine.connection import connection_manager, execute, RowResult
|
||||
|
||||
from cqlengine.exceptions import CQLEngineException
|
||||
from cqlengine.exceptions import CQLEngineException, ValidationError
|
||||
from cqlengine.functions import QueryValue, Token
|
||||
|
||||
#CQL 3 reference:
|
||||
@@ -22,6 +22,7 @@ class MultipleObjectsReturned(QueryException): pass
|
||||
|
||||
class QueryOperatorException(QueryException): pass
|
||||
|
||||
|
||||
class QueryOperator(object):
|
||||
# The symbol that identifies this operator in filter kwargs
|
||||
# ie: colname__<symbol>
|
||||
@@ -116,10 +117,12 @@ class QueryOperator(object):
|
||||
def __hash__(self):
|
||||
return hash(self.column.db_field_name) ^ hash(self.value)
|
||||
|
||||
|
||||
class EqualsOperator(QueryOperator):
|
||||
symbol = 'EQ'
|
||||
cql_symbol = '='
|
||||
|
||||
|
||||
class IterableQueryValue(QueryValue):
|
||||
def __init__(self, value):
|
||||
try:
|
||||
@@ -133,28 +136,34 @@ class IterableQueryValue(QueryValue):
|
||||
def get_cql(self):
|
||||
return '({})'.format(', '.join(':{}'.format(i) for i in self.identifier))
|
||||
|
||||
|
||||
class InOperator(EqualsOperator):
|
||||
symbol = 'IN'
|
||||
cql_symbol = 'IN'
|
||||
|
||||
QUERY_VALUE_WRAPPER = IterableQueryValue
|
||||
|
||||
|
||||
class GreaterThanOperator(QueryOperator):
|
||||
symbol = "GT"
|
||||
cql_symbol = '>'
|
||||
|
||||
|
||||
class GreaterThanOrEqualOperator(QueryOperator):
|
||||
symbol = "GTE"
|
||||
cql_symbol = '>='
|
||||
|
||||
|
||||
class LessThanOperator(QueryOperator):
|
||||
symbol = "LT"
|
||||
cql_symbol = '<'
|
||||
|
||||
|
||||
class LessThanOrEqualOperator(QueryOperator):
|
||||
symbol = "LTE"
|
||||
cql_symbol = '<='
|
||||
|
||||
|
||||
class AbstractQueryableColumn(object):
|
||||
"""
|
||||
exposes cql query operators through pythons
|
||||
@@ -192,6 +201,7 @@ class BatchType(object):
|
||||
Unlogged = 'UNLOGGED'
|
||||
Counter = 'COUNTER'
|
||||
|
||||
|
||||
class BatchQuery(object):
|
||||
"""
|
||||
Handles the batching of queries
|
||||
@@ -630,6 +640,7 @@ class AbstractQuerySet(object):
|
||||
def __ne__(self, q):
|
||||
return not (self != q)
|
||||
|
||||
|
||||
class ResultObject(dict):
|
||||
"""
|
||||
adds attribute access to a dictionary
|
||||
@@ -641,6 +652,7 @@ class ResultObject(dict):
|
||||
except KeyError:
|
||||
raise AttributeError
|
||||
|
||||
|
||||
class SimpleQuerySet(AbstractQuerySet):
|
||||
"""
|
||||
|
||||
@@ -658,6 +670,7 @@ class SimpleQuerySet(AbstractQuerySet):
|
||||
return ResultObject(zip(names, values))
|
||||
return _construct_instance
|
||||
|
||||
|
||||
class ModelQuerySet(AbstractQuerySet):
|
||||
"""
|
||||
|
||||
@@ -740,6 +753,54 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
clone._flat_values_list = flat
|
||||
return clone
|
||||
|
||||
def update(self, **values):
|
||||
""" Updates the rows in this queryset """
|
||||
if not values:
|
||||
return
|
||||
|
||||
set_statements = []
|
||||
ctx = {}
|
||||
nulled_columns = set()
|
||||
for name, val in values.items():
|
||||
col = self.model._columns.get(name)
|
||||
# check for nonexistant columns
|
||||
if col is None:
|
||||
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, name))
|
||||
# check for primary key update attempts
|
||||
if col.is_primary_key:
|
||||
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(name, self.__module__, self.model.__name__))
|
||||
|
||||
val = col.validate(val)
|
||||
if val is None:
|
||||
nulled_columns.add(name)
|
||||
continue
|
||||
# add the update statements
|
||||
if isinstance(col, (BaseContainerColumn, Counter)):
|
||||
val_mgr = self.instance._values[name]
|
||||
set_statements += col.get_update_statement(val, val_mgr.previous_value, ctx)
|
||||
|
||||
else:
|
||||
field_id = uuid4().hex
|
||||
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)]
|
||||
ctx[field_id] = val
|
||||
|
||||
if set_statements:
|
||||
qs = "UPDATE {} SET {} WHERE {}".format(
|
||||
self.column_family_name,
|
||||
', '.join(set_statements),
|
||||
self._where_clause()
|
||||
)
|
||||
ctx.update(self._where_values())
|
||||
execute(qs, ctx)
|
||||
|
||||
if nulled_columns:
|
||||
qs = "DELETE {} FROM {} WHERE {}".format(
|
||||
', '.join(nulled_columns),
|
||||
self.column_family_name,
|
||||
self._where_clause()
|
||||
)
|
||||
execute(qs, self._where_values())
|
||||
|
||||
|
||||
class DMLQuery(object):
|
||||
"""
|
||||
@@ -763,84 +824,11 @@ class DMLQuery(object):
|
||||
self._batch = batch_obj
|
||||
return self
|
||||
|
||||
def save(self):
|
||||
def _delete_null_columns(self):
|
||||
"""
|
||||
Creates / updates a row.
|
||||
This is a blind insert call.
|
||||
All validation and cleaning needs to happen
|
||||
prior to calling this.
|
||||
executes a delete query to remove columns that have changed to null
|
||||
"""
|
||||
if self.instance is None:
|
||||
raise CQLEngineException("DML Query intance attribute is None")
|
||||
assert type(self.instance) == self.model
|
||||
|
||||
#organize data
|
||||
value_pairs = []
|
||||
values = self.instance._as_dict()
|
||||
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
val = values.get(name)
|
||||
if val is None: continue
|
||||
value_pairs += [(col.db_field_name, val)]
|
||||
|
||||
#construct query string
|
||||
field_names = zip(*value_pairs)[0]
|
||||
field_ids = {n:uuid4().hex for n in field_names}
|
||||
field_values = dict(value_pairs)
|
||||
query_values = {field_ids[n]:field_values[n] for n in field_names}
|
||||
|
||||
qs = []
|
||||
if self.instance._has_counter or self.instance._can_update():
|
||||
qs += ["UPDATE {}".format(self.column_family_name)]
|
||||
qs += ["SET"]
|
||||
|
||||
set_statements = []
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
if not col.is_primary_key:
|
||||
val = values.get(name)
|
||||
if val is None:
|
||||
continue
|
||||
if isinstance(col, (BaseContainerColumn, Counter)):
|
||||
#remove value from query values, the column will handle it
|
||||
query_values.pop(field_ids.get(name), None)
|
||||
|
||||
val_mgr = self.instance._values[name]
|
||||
set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values)
|
||||
|
||||
else:
|
||||
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
|
||||
qs += [', '.join(set_statements)]
|
||||
|
||||
qs += ['WHERE']
|
||||
|
||||
where_statements = []
|
||||
for name, col in self.model._primary_keys.items():
|
||||
where_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
|
||||
|
||||
qs += [' AND '.join(where_statements)]
|
||||
|
||||
# clear the qs if there are no set statements and this is not a counter model
|
||||
if not set_statements and not self.instance._has_counter:
|
||||
qs = []
|
||||
|
||||
else:
|
||||
qs += ["INSERT INTO {}".format(self.column_family_name)]
|
||||
qs += ["({})".format(', '.join(['"{}"'.format(f) for f in field_names]))]
|
||||
qs += ['VALUES']
|
||||
qs += ["({})".format(', '.join([':'+field_ids[f] for f in field_names]))]
|
||||
|
||||
qs = ' '.join(qs)
|
||||
|
||||
# skip query execution if it's empty
|
||||
# caused by pointless update queries
|
||||
if qs:
|
||||
if self._batch:
|
||||
self._batch.add_query(qs, query_values)
|
||||
else:
|
||||
execute(qs, query_values)
|
||||
|
||||
values, field_names, field_ids, field_values, query_values = self._get_query_values()
|
||||
|
||||
# delete nulled columns and removed map keys
|
||||
qs = ['DELETE']
|
||||
@@ -874,6 +862,128 @@ class DMLQuery(object):
|
||||
else:
|
||||
execute(qs, query_values)
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
updates a row.
|
||||
This is a blind update call.
|
||||
All validation and cleaning needs to happen
|
||||
prior to calling this.
|
||||
"""
|
||||
if self.instance is None:
|
||||
raise CQLEngineException("DML Query intance attribute is None")
|
||||
assert type(self.instance) == self.model
|
||||
|
||||
values, field_names, field_ids, field_values, query_values = self._get_query_values()
|
||||
|
||||
qs = []
|
||||
qs += ["UPDATE {}".format(self.column_family_name)]
|
||||
qs += ["SET"]
|
||||
|
||||
set_statements = []
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
if not col.is_primary_key:
|
||||
val = values.get(name)
|
||||
|
||||
# don't update something that is null
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
# don't update something if it hasn't changed
|
||||
if not self.instance._values[name].changed and not isinstance(col, Counter):
|
||||
continue
|
||||
|
||||
# add the update statements
|
||||
if isinstance(col, (BaseContainerColumn, Counter)):
|
||||
#remove value from query values, the column will handle it
|
||||
query_values.pop(field_ids.get(name), None)
|
||||
|
||||
val_mgr = self.instance._values[name]
|
||||
set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values)
|
||||
|
||||
else:
|
||||
set_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
|
||||
qs += [', '.join(set_statements)]
|
||||
|
||||
qs += ['WHERE']
|
||||
|
||||
where_statements = []
|
||||
for name, col in self.model._primary_keys.items():
|
||||
where_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])]
|
||||
|
||||
qs += [' AND '.join(where_statements)]
|
||||
|
||||
# clear the qs if there are no set statements and this is not a counter model
|
||||
if not set_statements and not self.instance._has_counter:
|
||||
qs = []
|
||||
|
||||
qs = ' '.join(qs)
|
||||
# skip query execution if it's empty
|
||||
# caused by pointless update queries
|
||||
if qs:
|
||||
if self._batch:
|
||||
self._batch.add_query(qs, query_values)
|
||||
else:
|
||||
execute(qs, query_values)
|
||||
|
||||
self._delete_null_columns()
|
||||
|
||||
def _get_query_values(self):
|
||||
"""
|
||||
returns all the data needed to do queries
|
||||
"""
|
||||
#organize data
|
||||
value_pairs = []
|
||||
values = self.instance._as_dict()
|
||||
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
val = values.get(name)
|
||||
if col._val_is_null(val): continue
|
||||
value_pairs += [(col.db_field_name, val)]
|
||||
|
||||
#construct query string
|
||||
field_names = zip(*value_pairs)[0]
|
||||
field_ids = {n:uuid4().hex for n in field_names}
|
||||
field_values = dict(value_pairs)
|
||||
query_values = {field_ids[n]:field_values[n] for n in field_names}
|
||||
return values, field_names, field_ids, field_values, query_values
|
||||
|
||||
def save(self):
|
||||
"""
|
||||
Creates / updates a row.
|
||||
This is a blind insert call.
|
||||
All validation and cleaning needs to happen
|
||||
prior to calling this.
|
||||
"""
|
||||
if self.instance is None:
|
||||
raise CQLEngineException("DML Query intance attribute is None")
|
||||
assert type(self.instance) == self.model
|
||||
|
||||
values, field_names, field_ids, field_values, query_values = self._get_query_values()
|
||||
|
||||
qs = []
|
||||
if self.instance._has_counter or self.instance._can_update():
|
||||
return self.update()
|
||||
else:
|
||||
qs += ["INSERT INTO {}".format(self.column_family_name)]
|
||||
qs += ["({})".format(', '.join(['"{}"'.format(f) for f in field_names]))]
|
||||
qs += ['VALUES']
|
||||
qs += ["({})".format(', '.join([':'+field_ids[f] for f in field_names]))]
|
||||
|
||||
qs = ' '.join(qs)
|
||||
|
||||
# skip query execution if it's empty
|
||||
# caused by pointless update queries
|
||||
if qs:
|
||||
if self._batch:
|
||||
self._batch.add_query(qs, query_values)
|
||||
else:
|
||||
execute(qs, query_values)
|
||||
|
||||
# delete any nulled columns
|
||||
self._delete_null_columns()
|
||||
|
||||
def delete(self):
|
||||
""" Deletes one instance """
|
||||
if self.instance is None:
|
||||
|
||||
@@ -154,6 +154,17 @@ class TestSetColumn(BaseCassEngTestCase):
|
||||
py_val = column.to_python(db_val.value)
|
||||
assert py_val == val
|
||||
|
||||
def test_default_empty_container_saving(self):
|
||||
""" tests that the default empty container is not saved if it hasn't been updated """
|
||||
pkey = uuid4()
|
||||
# create a row with set data
|
||||
TestSetModel.create(partition=pkey, int_set={3, 4})
|
||||
# create another with no set data
|
||||
TestSetModel.create(partition=pkey)
|
||||
|
||||
m = TestSetModel.get(partition=pkey)
|
||||
self.assertEqual(m.int_set, {3, 4})
|
||||
|
||||
|
||||
class TestListModel(Model):
|
||||
partition = columns.UUID(primary_key=True, default=uuid4)
|
||||
@@ -282,6 +293,18 @@ class TestListColumn(BaseCassEngTestCase):
|
||||
py_val = column.to_python(db_val.value)
|
||||
assert py_val == val
|
||||
|
||||
def test_default_empty_container_saving(self):
|
||||
""" tests that the default empty container is not saved if it hasn't been updated """
|
||||
pkey = uuid4()
|
||||
# create a row with list data
|
||||
TestListModel.create(partition=pkey, int_list=[1,2,3,4])
|
||||
# create another with no list data
|
||||
TestListModel.create(partition=pkey)
|
||||
|
||||
m = TestListModel.get(partition=pkey)
|
||||
self.assertEqual(m.int_list, [1,2,3,4])
|
||||
|
||||
|
||||
class TestMapModel(Model):
|
||||
partition = columns.UUID(primary_key=True, default=uuid4)
|
||||
int_map = columns.Map(columns.Integer, columns.UUID, required=False)
|
||||
@@ -309,8 +332,6 @@ class TestMapColumn(BaseCassEngTestCase):
|
||||
tmp2 = TestMapModel.get(partition=tmp.partition)
|
||||
tmp2.int_map['blah'] = 1
|
||||
|
||||
|
||||
|
||||
def test_io_success(self):
|
||||
""" Tests that a basic usage works as expected """
|
||||
k1 = uuid4()
|
||||
@@ -370,7 +391,6 @@ class TestMapColumn(BaseCassEngTestCase):
|
||||
m2 = TestMapModel.get(partition=m.partition)
|
||||
assert m2.int_map == expected
|
||||
|
||||
|
||||
def test_updates_to_none(self):
|
||||
""" Tests that setting the field to None works as expected """
|
||||
m = TestMapModel.create(int_map={1: uuid4()})
|
||||
@@ -406,6 +426,18 @@ class TestMapColumn(BaseCassEngTestCase):
|
||||
py_val = column.to_python(db_val.value)
|
||||
assert py_val == val
|
||||
|
||||
def test_default_empty_container_saving(self):
|
||||
""" tests that the default empty container is not saved if it hasn't been updated """
|
||||
pkey = uuid4()
|
||||
tmap = {1: uuid4(), 2: uuid4()}
|
||||
# create a row with set data
|
||||
TestMapModel.create(partition=pkey, int_map=tmap)
|
||||
# create another with no set data
|
||||
TestMapModel.create(partition=pkey)
|
||||
|
||||
m = TestMapModel.get(partition=pkey)
|
||||
self.assertEqual(m.int_map, tmap)
|
||||
|
||||
# def test_partial_update_creation(self):
|
||||
# """
|
||||
# Tests that proper update statements are created for a partial list update
|
||||
|
||||
91
cqlengine/tests/model/test_updates.py
Normal file
91
cqlengine/tests/model/test_updates.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from mock import patch
|
||||
from cqlengine.exceptions import ValidationError
|
||||
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
from cqlengine.models import Model
|
||||
from cqlengine import columns
|
||||
from cqlengine.management import sync_table, drop_table
|
||||
from cqlengine.connection import ConnectionPool
|
||||
|
||||
|
||||
class TestUpdateModel(Model):
|
||||
partition = columns.UUID(primary_key=True, default=uuid4)
|
||||
cluster = columns.UUID(primary_key=True, default=uuid4)
|
||||
count = columns.Integer(required=False)
|
||||
text = columns.Text(required=False, index=True)
|
||||
|
||||
|
||||
class ModelUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ModelUpdateTests, cls).setUpClass()
|
||||
sync_table(TestUpdateModel)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(ModelUpdateTests, cls).tearDownClass()
|
||||
drop_table(TestUpdateModel)
|
||||
|
||||
def test_update_model(self):
|
||||
""" tests calling udpate on models with no values passed in """
|
||||
m0 = TestUpdateModel.create(count=5, text='monkey')
|
||||
|
||||
# independently save over a new count value, unknown to original instance
|
||||
m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
|
||||
m1.count = 6
|
||||
m1.save()
|
||||
|
||||
# update the text, and call update
|
||||
m0.text = 'monkey land'
|
||||
m0.update()
|
||||
|
||||
# database should reflect both updates
|
||||
m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
|
||||
self.assertEqual(m2.count, m1.count)
|
||||
self.assertEqual(m2.text, m0.text)
|
||||
|
||||
def test_update_values(self):
|
||||
""" tests calling update on models with values passed in """
|
||||
m0 = TestUpdateModel.create(count=5, text='monkey')
|
||||
|
||||
# independently save over a new count value, unknown to original instance
|
||||
m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
|
||||
m1.count = 6
|
||||
m1.save()
|
||||
|
||||
# update the text, and call update
|
||||
m0.update(text='monkey land')
|
||||
self.assertEqual(m0.text, 'monkey land')
|
||||
|
||||
# database should reflect both updates
|
||||
m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
|
||||
self.assertEqual(m2.count, m1.count)
|
||||
self.assertEqual(m2.text, m0.text)
|
||||
|
||||
def test_noop_model_update(self):
|
||||
""" tests that calling update on a model with no changes will do nothing. """
|
||||
m0 = TestUpdateModel.create(count=5, text='monkey')
|
||||
|
||||
with patch.object(ConnectionPool, 'execute') as execute:
|
||||
m0.update()
|
||||
assert execute.call_count == 0
|
||||
|
||||
with patch.object(ConnectionPool, 'execute') as execute:
|
||||
m0.update(count=5)
|
||||
assert execute.call_count == 0
|
||||
|
||||
def test_invalid_update_kwarg(self):
|
||||
""" tests that passing in a kwarg to the update method that isn't a column will fail """
|
||||
m0 = TestUpdateModel.create(count=5, text='monkey')
|
||||
with self.assertRaises(ValidationError):
|
||||
m0.update(numbers=20)
|
||||
|
||||
def test_primary_key_update_failure(self):
|
||||
""" tests that attempting to update the value of a primary key will fail """
|
||||
m0 = TestUpdateModel.create(count=5, text='monkey')
|
||||
with self.assertRaises(ValidationError):
|
||||
m0.update(partition=uuid4())
|
||||
|
||||
114
cqlengine/tests/query/test_updates.py
Normal file
114
cqlengine/tests/query/test_updates.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from uuid import uuid4
|
||||
from cqlengine.exceptions import ValidationError
|
||||
from cqlengine.query import QueryException
|
||||
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
from cqlengine.models import Model
|
||||
from cqlengine.management import sync_table, drop_table
|
||||
from cqlengine import columns
|
||||
|
||||
|
||||
class TestQueryUpdateModel(Model):
|
||||
partition = columns.UUID(primary_key=True, default=uuid4)
|
||||
cluster = columns.Integer(primary_key=True)
|
||||
count = columns.Integer(required=False)
|
||||
text = columns.Text(required=False, index=True)
|
||||
|
||||
|
||||
class QueryUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(QueryUpdateTests, cls).setUpClass()
|
||||
sync_table(TestQueryUpdateModel)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(QueryUpdateTests, cls).tearDownClass()
|
||||
drop_table(TestQueryUpdateModel)
|
||||
|
||||
def test_update_values(self):
|
||||
""" tests calling udpate on a queryset """
|
||||
partition = uuid4()
|
||||
for i in range(5):
|
||||
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
|
||||
# perform update
|
||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)
|
||||
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == (6 if i == 3 else i)
|
||||
assert row.text == str(i)
|
||||
|
||||
def test_update_values_validation(self):
|
||||
""" tests calling udpate on models with values passed in """
|
||||
partition = uuid4()
|
||||
for i in range(5):
|
||||
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
|
||||
# perform update
|
||||
with self.assertRaises(ValidationError):
|
||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count='asdf')
|
||||
|
||||
def test_invalid_update_kwarg(self):
|
||||
""" tests that passing in a kwarg to the update method that isn't a column will fail """
|
||||
with self.assertRaises(ValidationError):
|
||||
TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(bacon=5000)
|
||||
|
||||
def test_primary_key_update_failure(self):
|
||||
""" tests that attempting to update the value of a primary key will fail """
|
||||
with self.assertRaises(ValidationError):
|
||||
TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(cluster=5000)
|
||||
|
||||
def test_null_update_deletes_column(self):
|
||||
""" setting a field to null in the update should issue a delete statement """
|
||||
partition = uuid4()
|
||||
for i in range(5):
|
||||
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
|
||||
# perform update
|
||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None)
|
||||
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == (None if i == 3 else str(i))
|
||||
|
||||
def test_mixed_value_and_null_update(self):
|
||||
""" tests that updating a columns value, and removing another works properly """
|
||||
partition = uuid4()
|
||||
for i in range(5):
|
||||
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
|
||||
|
||||
# sanity check
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == i
|
||||
assert row.text == str(i)
|
||||
|
||||
# perform update
|
||||
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None)
|
||||
|
||||
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
|
||||
assert row.cluster == i
|
||||
assert row.count == (6 if i == 3 else i)
|
||||
assert row.text == (None if i == 3 else str(i))
|
||||
@@ -137,6 +137,13 @@ Model Methods
|
||||
.. method:: delete()
|
||||
|
||||
Deletes the object from the database.
|
||||
-- method:: update(**values)
|
||||
|
||||
|
||||
Performs an update on the model instance. You can pass in values to set on the model
|
||||
for updating, or you can call without values to execute an update against any modified
|
||||
fields. If no fields on the model have been modified since loading, no query will be
|
||||
performed. Model validation is performed normally.
|
||||
|
||||
Model Attributes
|
||||
================
|
||||
|
||||
@@ -353,3 +353,16 @@ QuerySet method reference
|
||||
.. method:: allow_filtering()
|
||||
|
||||
Enables the (usually) unwise practive of querying on a clustering key without also defining a partition key
|
||||
|
||||
-- method:: update(**values)
|
||||
|
||||
Performs an update on the row selected by the queryset. Include values to update in the
|
||||
update like so:
|
||||
|
||||
.. code-block:: python
|
||||
Model.objects(key=n).update(value='x')
|
||||
|
||||
Passing in updates for columns which are not part of the model will raise a ValidationError.
|
||||
Per column validation will be performed, but instance level validation will not
|
||||
(`Model.validate` is not called).
|
||||
|
||||
|
||||
Reference in New Issue
Block a user