merged in updates branch again, got consistency blind updates working
This commit is contained in:
		@@ -2,6 +2,8 @@ CHANGELOG
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
0.9
 | 
					0.9
 | 
				
			||||||
* adding update method
 | 
					* adding update method
 | 
				
			||||||
 | 
					* adding BigInt column (thanks @Lifto)
 | 
				
			||||||
 | 
					* adding support for timezone aware time uuid functions (thanks @dokai)
 | 
				
			||||||
* only saving collection fields on insert if they've been modified
 | 
					* only saving collection fields on insert if they've been modified
 | 
				
			||||||
 | 
					
 | 
				
			||||||
0.8.5
 | 
					0.8.5
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,7 +9,7 @@ from cqlengine.columns import Counter
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from cqlengine.connection import connection_manager, execute, RowResult
 | 
					from cqlengine.connection import connection_manager, execute, RowResult
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cqlengine.exceptions import CQLEngineException
 | 
					from cqlengine.exceptions import CQLEngineException, ValidationError
 | 
				
			||||||
from cqlengine.functions import QueryValue, Token
 | 
					from cqlengine.functions import QueryValue, Token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#CQL 3 reference:
 | 
					#CQL 3 reference:
 | 
				
			||||||
@@ -641,13 +641,6 @@ class AbstractQuerySet(object):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            execute(qs, self._where_values())
 | 
					            execute(qs, self._where_values())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, **values):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        updates the contents of the query
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        qs = ['UPDATE {}'.format(self.column_family_name)]
 | 
					 | 
				
			||||||
        qs += ['SET']
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __eq__(self, q):
 | 
					    def __eq__(self, q):
 | 
				
			||||||
        return set(self._where) == set(q._where)
 | 
					        return set(self._where) == set(q._where)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -772,6 +765,62 @@ class ModelQuerySet(AbstractQuerySet):
 | 
				
			|||||||
        clone._flat_values_list = flat
 | 
					        clone._flat_values_list = flat
 | 
				
			||||||
        return clone
 | 
					        return clone
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def consistency(self, consistency):
 | 
				
			||||||
 | 
					        self._consistency = consistency
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def ttl(self, ttl):
 | 
				
			||||||
 | 
					        self._ttl = ttl
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update(self, **values):
 | 
				
			||||||
 | 
					        """ Updates the rows in this queryset """
 | 
				
			||||||
 | 
					        if not values:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        set_statements = []
 | 
				
			||||||
 | 
					        ctx = {}
 | 
				
			||||||
 | 
					        nulled_columns = set()
 | 
				
			||||||
 | 
					        for name, val in values.items():
 | 
				
			||||||
 | 
					            col = self.model._columns.get(name)
 | 
				
			||||||
 | 
					            # check for nonexistant columns
 | 
				
			||||||
 | 
					            if col is None:
 | 
				
			||||||
 | 
					                raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, name))
 | 
				
			||||||
 | 
					            # check for primary key update attempts
 | 
				
			||||||
 | 
					            if col.is_primary_key:
 | 
				
			||||||
 | 
					                raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(name, self.__module__, self.model.__name__))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            val = col.validate(val)
 | 
				
			||||||
 | 
					            if val is None:
 | 
				
			||||||
 | 
					                nulled_columns.add(name)
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            # add the update statements
 | 
				
			||||||
 | 
					            if isinstance(col, (BaseContainerColumn, Counter)):
 | 
				
			||||||
 | 
					                val_mgr = self.instance._values[name]
 | 
				
			||||||
 | 
					                set_statements += col.get_update_statement(val, val_mgr.previous_value, ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                field_id = uuid4().hex
 | 
				
			||||||
 | 
					                set_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)]
 | 
				
			||||||
 | 
					                ctx[field_id] = val
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if set_statements:
 | 
				
			||||||
 | 
					            qs = "UPDATE {} SET {} WHERE {}".format(
 | 
				
			||||||
 | 
					                self.column_family_name,
 | 
				
			||||||
 | 
					                ', '.join(set_statements),
 | 
				
			||||||
 | 
					                self._where_clause()
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            ctx.update(self._where_values())
 | 
				
			||||||
 | 
					            execute(qs, ctx, self._consistency)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if nulled_columns:
 | 
				
			||||||
 | 
					            qs = "DELETE {} FROM {} WHERE {}".format(
 | 
				
			||||||
 | 
					                ', '.join(nulled_columns),
 | 
				
			||||||
 | 
					                self.column_family_name,
 | 
				
			||||||
 | 
					                self._where_clause()
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            execute(qs, self._where_values(), self._consistency)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DMLQuery(object):
 | 
					class DMLQuery(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										114
									
								
								cqlengine/tests/query/test_updates.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								cqlengine/tests/query/test_updates.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,114 @@
 | 
				
			|||||||
 | 
					from uuid import uuid4
 | 
				
			||||||
 | 
					from cqlengine.exceptions import ValidationError
 | 
				
			||||||
 | 
					from cqlengine.query import QueryException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cqlengine.tests.base import BaseCassEngTestCase
 | 
				
			||||||
 | 
					from cqlengine.models import Model
 | 
				
			||||||
 | 
					from cqlengine.management import sync_table, drop_table
 | 
				
			||||||
 | 
					from cqlengine import columns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestQueryUpdateModel(Model):
 | 
				
			||||||
 | 
					    partition   = columns.UUID(primary_key=True, default=uuid4)
 | 
				
			||||||
 | 
					    cluster     = columns.Integer(primary_key=True)
 | 
				
			||||||
 | 
					    count       = columns.Integer(required=False)
 | 
				
			||||||
 | 
					    text        = columns.Text(required=False, index=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class QueryUpdateTests(BaseCassEngTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def setUpClass(cls):
 | 
				
			||||||
 | 
					        super(QueryUpdateTests, cls).setUpClass()
 | 
				
			||||||
 | 
					        sync_table(TestQueryUpdateModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def tearDownClass(cls):
 | 
				
			||||||
 | 
					        super(QueryUpdateTests, cls).tearDownClass()
 | 
				
			||||||
 | 
					        drop_table(TestQueryUpdateModel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_update_values(self):
 | 
				
			||||||
 | 
					        """ tests calling udpate on a queryset """
 | 
				
			||||||
 | 
					        partition = uuid4()
 | 
				
			||||||
 | 
					        for i in range(5):
 | 
				
			||||||
 | 
					            TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # sanity check
 | 
				
			||||||
 | 
					        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
 | 
				
			||||||
 | 
					            assert row.cluster == i
 | 
				
			||||||
 | 
					            assert row.count == i
 | 
				
			||||||
 | 
					            assert row.text == str(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # perform update
 | 
				
			||||||
 | 
					        TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
 | 
				
			||||||
 | 
					            assert row.cluster == i
 | 
				
			||||||
 | 
					            assert row.count == (6 if i == 3 else i)
 | 
				
			||||||
 | 
					            assert row.text == str(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_update_values_validation(self):
 | 
				
			||||||
 | 
					        """ tests calling udpate on models with values passed in """
 | 
				
			||||||
 | 
					        partition = uuid4()
 | 
				
			||||||
 | 
					        for i in range(5):
 | 
				
			||||||
 | 
					            TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # sanity check
 | 
				
			||||||
 | 
					        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
 | 
				
			||||||
 | 
					            assert row.cluster == i
 | 
				
			||||||
 | 
					            assert row.count == i
 | 
				
			||||||
 | 
					            assert row.text == str(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # perform update
 | 
				
			||||||
 | 
					        with self.assertRaises(ValidationError):
 | 
				
			||||||
 | 
					            TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count='asdf')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_update_kwarg(self):
 | 
				
			||||||
 | 
					        """ tests that passing in a kwarg to the update method that isn't a column will fail """
 | 
				
			||||||
 | 
					        with self.assertRaises(ValidationError):
 | 
				
			||||||
 | 
					            TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(bacon=5000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_primary_key_update_failure(self):
 | 
				
			||||||
 | 
					        """ tests that attempting to update the value of a primary key will fail """
 | 
				
			||||||
 | 
					        with self.assertRaises(ValidationError):
 | 
				
			||||||
 | 
					            TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(cluster=5000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_null_update_deletes_column(self):
 | 
				
			||||||
 | 
					        """ setting a field to null in the update should issue a delete statement """
 | 
				
			||||||
 | 
					        partition = uuid4()
 | 
				
			||||||
 | 
					        for i in range(5):
 | 
				
			||||||
 | 
					            TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # sanity check
 | 
				
			||||||
 | 
					        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
 | 
				
			||||||
 | 
					            assert row.cluster == i
 | 
				
			||||||
 | 
					            assert row.count == i
 | 
				
			||||||
 | 
					            assert row.text == str(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # perform update
 | 
				
			||||||
 | 
					        TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
 | 
				
			||||||
 | 
					            assert row.cluster == i
 | 
				
			||||||
 | 
					            assert row.count == i
 | 
				
			||||||
 | 
					            assert row.text == (None if i == 3 else str(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_mixed_value_and_null_update(self):
 | 
				
			||||||
 | 
					        """ tests that updating a columns value, and removing another works properly """
 | 
				
			||||||
 | 
					        partition = uuid4()
 | 
				
			||||||
 | 
					        for i in range(5):
 | 
				
			||||||
 | 
					            TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # sanity check
 | 
				
			||||||
 | 
					        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
 | 
				
			||||||
 | 
					            assert row.cluster == i
 | 
				
			||||||
 | 
					            assert row.count == i
 | 
				
			||||||
 | 
					            assert row.text == str(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # perform update
 | 
				
			||||||
 | 
					        TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
 | 
				
			||||||
 | 
					            assert row.cluster == i
 | 
				
			||||||
 | 
					            assert row.count == (6 if i == 3 else i)
 | 
				
			||||||
 | 
					            assert row.text == (None if i == 3 else str(i))
 | 
				
			||||||
@@ -72,7 +72,7 @@ class TestConsistency(BaseConsistencyTest):
 | 
				
			|||||||
        uid = t.id
 | 
					        uid = t.id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        with mock.patch.object(ConnectionPool, 'execute') as m:
 | 
					        with mock.patch.object(ConnectionPool, 'execute') as m:
 | 
				
			||||||
            TestConsistencyModel.objects(id=uid).update(text="grilled cheese")
 | 
					            TestConsistencyModel.objects(id=uid).consistency(ALL).update(text="grilled cheese")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        args = m.call_args
 | 
					        args = m.call_args
 | 
				
			||||||
        self.assertEqual(ALL, args[0][2])
 | 
					        self.assertEqual(ALL, args[0][2])
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -53,6 +53,7 @@ class TTLModelTests(BaseTTLTest):
 | 
				
			|||||||
        self.assertTrue(isinstance(qs, TestTTLModel.__queryset__), type(qs))
 | 
					        self.assertTrue(isinstance(qs, TestTTLModel.__queryset__), type(qs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TTLInstanceUpdateTest(BaseTTLTest):
 | 
					class TTLInstanceUpdateTest(BaseTTLTest):
 | 
				
			||||||
    def test_update_includes_ttl(self):
 | 
					    def test_update_includes_ttl(self):
 | 
				
			||||||
        model = TestTTLModel.create(text="goodbye blake")
 | 
					        model = TestTTLModel.create(text="goodbye blake")
 | 
				
			||||||
@@ -87,6 +88,17 @@ class TTLInstanceTest(BaseTTLTest):
 | 
				
			|||||||
        self.assertIn("USING TTL", query)
 | 
					        self.assertIn("USING TTL", query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TTLBlindUpdateTest(BaseTTLTest):
 | 
				
			||||||
 | 
					    def test_ttl_included_with_blind_update(self):
 | 
				
			||||||
 | 
					        o = TestTTLModel.create(text="whatever")
 | 
				
			||||||
 | 
					        tid = o.id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with mock.patch.object(ConnectionPool, 'execute') as m:
 | 
				
			||||||
 | 
					            TestTTLModel.objects(id=tid).ttl(60).update(text="bacon")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query = m.call_args[0][0]
 | 
				
			||||||
 | 
					        self.assertIn("USING TTL", query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class QuerySetTTLFragmentTest(BaseTTLTest):
 | 
					 | 
				
			||||||
    pass
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -149,6 +149,7 @@ Model Methods
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    -- method:: update(**values)
 | 
					    -- method:: update(**values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Performs an update on the model instance. You can pass in values to set on the model
 | 
					        Performs an update on the model instance. You can pass in values to set on the model
 | 
				
			||||||
        for updating, or you can call without values to execute an update against any modified
 | 
					        for updating, or you can call without values to execute an update against any modified
 | 
				
			||||||
        fields. If no fields on the model have been modified since loading, no query will be
 | 
					        fields. If no fields on the model have been modified since loading, no query will be
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -358,7 +358,19 @@ QuerySet method reference
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        Sets the batch object to run the query on. Note that running a select query with a batch object will raise an exception
 | 
					        Sets the batch object to run the query on. Note that running a select query with a batch object will raise an exception
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    .. method:: ttl(batch_object)
 | 
					    .. method:: ttl(ttl_in_seconds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Sets the ttl to run the query query with. Note that running a select query with a ttl value will raise an exception
 | 
					        Sets the ttl to run the query query with. Note that running a select query with a ttl value will raise an exception
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    -- method:: update(**values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Performs an update on the row selected by the queryset. Include values to update in the
 | 
				
			||||||
 | 
					        update like so:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        .. code-block:: python
 | 
				
			||||||
 | 
					            Model.objects(key=n).update(value='x')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Passing in updates for columns which are not part of the model will raise a ValidationError.
 | 
				
			||||||
 | 
					        Per column validation will be performed, but instance level validation will not
 | 
				
			||||||
 | 
					        (`Model.validate` is not called).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user