Merge pull request #284 from timmartin19/master
UPDATE transaction statements
This commit is contained in:
@@ -74,6 +74,33 @@ class QuerySetDescriptor(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransactionDescriptor(object):
|
||||
"""
|
||||
returns a query set descriptor
|
||||
"""
|
||||
def __get__(self, instance, model):
|
||||
if instance:
|
||||
def transaction_setter(*prepared_transaction, **unprepared_transactions):
|
||||
if len(prepared_transaction) > 0:
|
||||
transactions = prepared_transaction[0]
|
||||
else:
|
||||
transactions = instance.objects.iff(**unprepared_transactions)._transaction
|
||||
instance._transaction = transactions
|
||||
return instance
|
||||
|
||||
return transaction_setter
|
||||
qs = model.__queryset__(model)
|
||||
|
||||
def transaction_setter(**unprepared_transactions):
|
||||
transactions = model.objects.iff(**unprepared_transactions)._transaction
|
||||
qs._transaction = transactions
|
||||
return qs
|
||||
return transaction_setter
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TTLDescriptor(object):
|
||||
"""
|
||||
returns a query set descriptor
|
||||
@@ -239,6 +266,7 @@ class BaseModel(object):
|
||||
objects = QuerySetDescriptor()
|
||||
ttl = TTLDescriptor()
|
||||
consistency = ConsistencyDescriptor()
|
||||
iff = TransactionDescriptor()
|
||||
|
||||
# custom timestamps, see USING TIMESTAMP X
|
||||
timestamp = TimestampDescriptor()
|
||||
@@ -301,9 +329,10 @@ class BaseModel(object):
|
||||
self._values = {}
|
||||
self._ttl = self.__default_ttl__
|
||||
self._timestamp = None
|
||||
self._transaction = None
|
||||
|
||||
for name, column in self._columns.items():
|
||||
value = values.get(name, None)
|
||||
value = values.get(name, None)
|
||||
if value is not None or isinstance(column, columns.BaseContainerColumn):
|
||||
value = column.to_python(value)
|
||||
value_mngr = column.value_manager(self, column, value)
|
||||
@@ -550,7 +579,8 @@ class BaseModel(object):
|
||||
ttl=self._ttl,
|
||||
timestamp=self._timestamp,
|
||||
consistency=self.__consistency__,
|
||||
if_not_exists=self._if_not_exists).save()
|
||||
if_not_exists=self._if_not_exists,
|
||||
transaction=self._transaction).save()
|
||||
|
||||
#reset the value managers
|
||||
for v in self._values.values():
|
||||
@@ -588,7 +618,8 @@ class BaseModel(object):
|
||||
batch=self._batch,
|
||||
ttl=self._ttl,
|
||||
timestamp=self._timestamp,
|
||||
consistency=self.__consistency__).update()
|
||||
consistency=self.__consistency__,
|
||||
transaction=self._transaction).update()
|
||||
|
||||
#reset the value managers
|
||||
for v in self._values.values():
|
||||
|
@@ -5,7 +5,6 @@ from cqlengine import BaseContainerColumn, Map, columns
|
||||
from cqlengine.columns import Counter, List, Set
|
||||
|
||||
from cqlengine.connection import execute
|
||||
|
||||
from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException
|
||||
from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin
|
||||
|
||||
@@ -13,7 +12,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix
|
||||
#http://www.datastax.com/docs/1.1/references/cql/index
|
||||
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
|
||||
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
|
||||
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause
|
||||
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause
|
||||
|
||||
|
||||
class QueryException(CQLEngineException): pass
|
||||
@@ -206,6 +205,9 @@ class AbstractQuerySet(object):
|
||||
#Where clause filters
|
||||
self._where = []
|
||||
|
||||
# Transaction clause filters
|
||||
self._transaction = []
|
||||
|
||||
#ordering arguments
|
||||
self._order = []
|
||||
|
||||
@@ -243,6 +245,8 @@ class AbstractQuerySet(object):
|
||||
return self._batch.add_query(q)
|
||||
else:
|
||||
result = execute(q, consistency_level=self._consistency)
|
||||
if self._transaction:
|
||||
check_applied(result)
|
||||
return result
|
||||
|
||||
def __unicode__(self):
|
||||
@@ -407,6 +411,49 @@ class AbstractQuerySet(object):
|
||||
else:
|
||||
raise QueryException("Can't parse '{}'".format(arg))
|
||||
|
||||
def iff(self, *args, **kwargs):
|
||||
"""Adds IF statements to queryset"""
|
||||
if len([x for x in kwargs.values() if x is None]):
|
||||
raise CQLEngineException("None values on iff are not allowed")
|
||||
|
||||
clone = copy.deepcopy(self)
|
||||
for operator in args:
|
||||
if not isinstance(operator, TransactionClause):
|
||||
raise QueryException('{} is not a valid query operator'.format(operator))
|
||||
clone._transaction.append(operator)
|
||||
|
||||
for col_name, val in kwargs.items():
|
||||
exists = False
|
||||
try:
|
||||
column = self.model._get_column(col_name)
|
||||
except KeyError:
|
||||
if col_name == 'pk__token':
|
||||
if not isinstance(val, Token):
|
||||
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
||||
column = columns._PartitionKeysToken(self.model)
|
||||
quote_field = False
|
||||
else:
|
||||
raise QueryException("Can't resolve column name: '{}'".format(col_name))
|
||||
|
||||
if isinstance(val, Token):
|
||||
if col_name != 'pk__token':
|
||||
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
|
||||
partition_columns = column.partition_columns
|
||||
if len(partition_columns) != len(val.value):
|
||||
raise QueryException(
|
||||
'Token() received {} arguments but model has {} partition keys'.format(
|
||||
len(val.value), len(partition_columns)))
|
||||
val.set_columns(partition_columns)
|
||||
|
||||
if isinstance(val, BaseQueryFunction) or exists is True:
|
||||
query_val = val
|
||||
else:
|
||||
query_val = column.to_database(val)
|
||||
|
||||
clone._transaction.append(TransactionClause(col_name, query_val))
|
||||
|
||||
return clone
|
||||
|
||||
def filter(self, *args, **kwargs):
|
||||
"""
|
||||
Adds WHERE arguments to the queryset, returning a new queryset
|
||||
@@ -732,7 +779,8 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
return
|
||||
|
||||
nulled_columns = set()
|
||||
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp)
|
||||
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
|
||||
timestamp=self._timestamp, transactions=self._transaction)
|
||||
for name, val in values.items():
|
||||
col_name, col_op = self._parse_filter_arg(name)
|
||||
col = self.model._columns.get(col_name)
|
||||
@@ -787,7 +835,7 @@ class DMLQuery(object):
|
||||
_timestamp = None
|
||||
_if_not_exists = False
|
||||
|
||||
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
|
||||
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
|
||||
self.model = model
|
||||
self.column_family_name = self.model.column_family_name()
|
||||
self.instance = instance
|
||||
@@ -796,15 +844,15 @@ class DMLQuery(object):
|
||||
self._consistency = consistency
|
||||
self._timestamp = timestamp
|
||||
self._if_not_exists = if_not_exists
|
||||
self._transaction = transaction
|
||||
|
||||
def _execute(self, q):
|
||||
if self._batch:
|
||||
return self._batch.add_query(q)
|
||||
else:
|
||||
tmp = execute(q, consistency_level=self._consistency)
|
||||
if self._if_not_exists:
|
||||
if self._if_not_exists or self._transaction:
|
||||
check_applied(tmp)
|
||||
|
||||
return tmp
|
||||
|
||||
def batch(self, batch_obj):
|
||||
@@ -850,7 +898,8 @@ class DMLQuery(object):
|
||||
raise CQLEngineException("DML Query intance attribute is None")
|
||||
assert type(self.instance) == self.model
|
||||
static_update_only = True
|
||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp)
|
||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
|
||||
timestamp=self._timestamp, transactions=self._transaction)
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
if not col.is_primary_key:
|
||||
|
@@ -139,6 +139,16 @@ class AssignmentClause(BaseClause):
|
||||
return self.field, self.context_id
|
||||
|
||||
|
||||
class TransactionClause(BaseClause):
|
||||
""" A single variable iff statement """
|
||||
|
||||
def __unicode__(self):
|
||||
return u'"{}" = %({})s'.format(self.field, self.context_id)
|
||||
|
||||
def insert_tuple(self):
|
||||
return self.field, self.context_id
|
||||
|
||||
|
||||
class ContainerUpdateClause(AssignmentClause):
|
||||
|
||||
def __init__(self, field, value, operation=None, previous=None, column=None):
|
||||
@@ -675,6 +685,26 @@ class InsertStatement(AssignmentStatement):
|
||||
class UpdateStatement(AssignmentStatement):
|
||||
""" an cql update select statement """
|
||||
|
||||
def __init__(self,
|
||||
table,
|
||||
assignments=None,
|
||||
consistency=None,
|
||||
where=None,
|
||||
ttl=None,
|
||||
timestamp=None,
|
||||
transactions=None):
|
||||
super(UpdateStatement, self). __init__(table,
|
||||
assignments=assignments,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
ttl=ttl,
|
||||
timestamp=timestamp)
|
||||
|
||||
# Add iff statements
|
||||
self.transactions = []
|
||||
for transaction in transactions or []:
|
||||
self.add_transaction_clause(transaction)
|
||||
|
||||
def __unicode__(self):
|
||||
qs = ['UPDATE', self.table]
|
||||
|
||||
@@ -695,8 +725,39 @@ class UpdateStatement(AssignmentStatement):
|
||||
if self.where_clauses:
|
||||
qs += [self._where]
|
||||
|
||||
if len(self.transactions) > 0:
|
||||
qs += [self._get_transactions()]
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
def add_transaction_clause(self, clause):
|
||||
"""
|
||||
Adds a iff clause to this statement
|
||||
|
||||
:param clause: The clause that will be added to the iff statement
|
||||
:type clause: TransactionClause
|
||||
"""
|
||||
if not isinstance(clause, TransactionClause):
|
||||
raise StatementException('only instances of AssignmentClause can be added to statements')
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.transactions.append(clause)
|
||||
|
||||
def get_context(self):
|
||||
ctx = super(UpdateStatement, self).get_context()
|
||||
for clause in self.transactions or []:
|
||||
clause.update_context(ctx)
|
||||
return ctx
|
||||
|
||||
def _get_transactions(self):
|
||||
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
|
||||
|
||||
def update_context_id(self, i):
|
||||
super(UpdateStatement, self).update_context_id(i)
|
||||
for transaction in self.transactions:
|
||||
transaction.set_context_id(self.context_counter)
|
||||
self.context_counter += transaction.get_context_size()
|
||||
|
||||
|
||||
class DeleteStatement(BaseCQLStatement):
|
||||
""" a cql delete statement """
|
||||
|
97
cqlengine/tests/test_transaction.py
Normal file
97
cqlengine/tests/test_transaction.py
Normal file
@@ -0,0 +1,97 @@
|
||||
__author__ = 'Tim Martin'
|
||||
from cqlengine.management import sync_table, drop_table
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
from cqlengine.models import Model
|
||||
from cqlengine.exceptions import LWTException
|
||||
from uuid import uuid4
|
||||
from cqlengine import columns, BatchQuery
|
||||
import mock
|
||||
from cqlengine import ALL, BatchQuery
|
||||
from cqlengine.statements import TransactionClause
|
||||
import six
|
||||
|
||||
|
||||
class TestTransactionModel(Model):
|
||||
__keyspace__ = 'test'
|
||||
id = columns.UUID(primary_key=True, default=lambda:uuid4())
|
||||
count = columns.Integer()
|
||||
text = columns.Text(required=False)
|
||||
|
||||
|
||||
class TestTransaction(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestTransaction, cls).setUpClass()
|
||||
sync_table(TestTransactionModel)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(TestTransaction, cls).tearDownClass()
|
||||
drop_table(TestTransactionModel)
|
||||
|
||||
def test_update_using_transaction(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'new blah'
|
||||
with mock.patch.object(self.session, 'execute') as m:
|
||||
t.iff(text='blah blah').save()
|
||||
|
||||
args = m.call_args
|
||||
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
|
||||
|
||||
def test_update_transaction_success(self):
|
||||
t = TestTransactionModel.create(text='blah blah', count=5)
|
||||
id = t.id
|
||||
t.text = 'new blah'
|
||||
t.iff(text='blah blah').save()
|
||||
|
||||
updated = TestTransactionModel.objects(id=id).first()
|
||||
self.assertEqual(updated.count, 5)
|
||||
self.assertEqual(updated.text, 'new blah')
|
||||
|
||||
def test_update_failure(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'new blah'
|
||||
t = t.iff(text='something wrong')
|
||||
self.assertRaises(LWTException, t.save)
|
||||
|
||||
def test_blind_update(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'something else'
|
||||
uid = t.id
|
||||
|
||||
with mock.patch.object(self.session, 'execute') as m:
|
||||
TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
|
||||
|
||||
args = m.call_args
|
||||
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
|
||||
|
||||
def test_blind_update_fail(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'something else'
|
||||
uid = t.id
|
||||
qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!')
|
||||
self.assertRaises(LWTException, qs.update, text='this will never work')
|
||||
|
||||
def test_transaction_clause(self):
|
||||
tc = TransactionClause('some_value', 23)
|
||||
tc.set_context_id(3)
|
||||
|
||||
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
|
||||
self.assertEqual('"some_value" = %(3)s', str(tc))
|
||||
|
||||
def test_batch_update_transaction(self):
|
||||
t = TestTransactionModel.create(text='something', count=5)
|
||||
id = t.id
|
||||
with BatchQuery() as b:
|
||||
t.batch(b).iff(count=5).update(text='something else')
|
||||
|
||||
updated = TestTransactionModel.objects(id=id).first()
|
||||
self.assertEqual(updated.text, 'something else')
|
||||
|
||||
b = BatchQuery()
|
||||
updated.batch(b).iff(count=6).update(text='and another thing')
|
||||
self.assertRaises(LWTException, b.execute)
|
||||
|
||||
updated = TestTransactionModel.objects(id=id).first()
|
||||
self.assertEqual(updated.text, 'something else')
|
@@ -207,6 +207,21 @@ Model Methods
|
||||
|
||||
This method is supported on Cassandra 2.0 or later.
|
||||
|
||||
.. method:: iff(**values)
|
||||
|
||||
Checks to ensure that the values specified are correct on the Cassandra cluster.
|
||||
Simply specify the column(s) and the expected value(s). As with if_not_exists,
|
||||
this incurs a performance cost.
|
||||
|
||||
If the insertion isn't applied, a LWTException is raised
|
||||
|
||||
.. code-block::
|
||||
t = TestTransactionModel(text='some text', count=5)
|
||||
try:
|
||||
t.iff(count=5).update('other text')
|
||||
except LWTException as e:
|
||||
# handle failure
|
||||
|
||||
.. method:: update(**values)
|
||||
|
||||
Performs an update on the model instance. You can pass in values to set on the model
|
||||
|
Reference in New Issue
Block a user