diff --git a/cqlengine/query.py b/cqlengine/query.py index d9bef1aa..38762a8a 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -245,6 +245,8 @@ class AbstractQuerySet(object): return self._batch.add_query(q) else: result = execute(q, consistency_level=self._consistency) + if self._transaction: + check_applied(result) return result def __unicode__(self): diff --git a/cqlengine/statements.py b/cqlengine/statements.py index 445df8f2..ab8b4667 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -148,9 +148,6 @@ class TransactionClause(BaseClause): def insert_tuple(self): return self.field, self.context_id - def update_context(self, ctx): - return super(TransactionClause, self).update_context(ctx) - class ContainerUpdateClause(AssignmentClause): @@ -746,6 +743,12 @@ class UpdateStatement(AssignmentStatement): def _get_transactions(self): return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) + def update_context_id(self, i): + super(UpdateStatement, self).update_context_id(i) + for transaction in self.transactions: + transaction.set_context_id(self.context_counter) + self.context_counter += transaction.get_context_size() + class DeleteStatement(BaseCQLStatement): """ a cql delete statement """ diff --git a/cqlengine/tests/statements/test_transaction_statement.py b/cqlengine/tests/statements/test_transaction_statement.py deleted file mode 100644 index 0477cda6..00000000 --- a/cqlengine/tests/statements/test_transaction_statement.py +++ /dev/null @@ -1,14 +0,0 @@ -__author__ = 'Tim Martin' -from unittest import TestCase -from cqlengine.statements import TransactionClause -import six - - -class TestTransactionClause(TestCase): - - def test_normal_transaction(self): - tc = TransactionClause('some_value', 23) - tc.set_context_id(3) - - self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) - self.assertEqual('"some_value" = %(3)s', str(tc)) \ No newline at end of file diff --git a/cqlengine/tests/test_transaction.py b/cqlengine/tests/test_transaction.py index 69e6cb60..43eb488f 100644 --- a/cqlengine/tests/test_transaction.py +++ b/cqlengine/tests/test_transaction.py @@ -4,9 +4,11 @@ from cqlengine.tests.base import BaseCassEngTestCase from cqlengine.models import Model from cqlengine.exceptions import LWTException from uuid import uuid4 -from cqlengine import columns +from cqlengine import columns, BatchQuery import mock from cqlengine import ALL, BatchQuery +from cqlengine.statements import TransactionClause +import six class TestTransactionModel(Model): @@ -37,6 +39,16 @@ class TestTransaction(BaseCassEngTestCase): args = m.call_args self.assertIn('IF "text" = %(0)s', args[0][0].query_string) + def test_update_transaction_success(self): + t = TestTransactionModel.create(text='blah blah', count=5) + id = t.id + t.text = 'new blah' + t.iff(text='blah blah').save() + + updated = TestTransactionModel.objects(id=id).first() + self.assertEqual(updated.count, 5) + self.assertEqual(updated.text, 'new blah') + def test_update_failure(self): t = TestTransactionModel.create(text='blah blah') t.text = 'new blah' @@ -59,4 +71,27 @@ class TestTransaction(BaseCassEngTestCase): t.text = 'something else' uid = t.id qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!') - self.assertRaises(LWTException, qs.update, text='this will never work') \ No newline at end of file + self.assertRaises(LWTException, qs.update, text='this will never work') + + def test_transaction_clause(self): + tc = TransactionClause('some_value', 23) + tc.set_context_id(3) + + self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) + self.assertEqual('"some_value" = %(3)s', str(tc)) + + def test_batch_update_transaction(self): + t = TestTransactionModel.create(text='something', count=5) + id = t.id + with BatchQuery() as b: + t.batch(b).iff(count=5).update(text='something else') + + updated = TestTransactionModel.objects(id=id).first() + self.assertEqual(updated.text, 'something else') + + b = BatchQuery() + updated.batch(b).iff(count=6).update(text='and another thing') + self.assertRaises(LWTException, b.execute) + + updated = TestTransactionModel.objects(id=id).first() + self.assertEqual(updated.text, 'something else') \ No newline at end of file