Merge pull request #284 from timmartin19/master
UPDATE transaction statements
This commit is contained in:
@@ -74,6 +74,33 @@ class QuerySetDescriptor(object):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionDescriptor(object):
|
||||||
|
"""
|
||||||
|
returns a query set descriptor
|
||||||
|
"""
|
||||||
|
def __get__(self, instance, model):
|
||||||
|
if instance:
|
||||||
|
def transaction_setter(*prepared_transaction, **unprepared_transactions):
|
||||||
|
if len(prepared_transaction) > 0:
|
||||||
|
transactions = prepared_transaction[0]
|
||||||
|
else:
|
||||||
|
transactions = instance.objects.iff(**unprepared_transactions)._transaction
|
||||||
|
instance._transaction = transactions
|
||||||
|
return instance
|
||||||
|
|
||||||
|
return transaction_setter
|
||||||
|
qs = model.__queryset__(model)
|
||||||
|
|
||||||
|
def transaction_setter(**unprepared_transactions):
|
||||||
|
transactions = model.objects.iff(**unprepared_transactions)._transaction
|
||||||
|
qs._transaction = transactions
|
||||||
|
return qs
|
||||||
|
return transaction_setter
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class TTLDescriptor(object):
|
class TTLDescriptor(object):
|
||||||
"""
|
"""
|
||||||
returns a query set descriptor
|
returns a query set descriptor
|
||||||
@@ -239,6 +266,7 @@ class BaseModel(object):
|
|||||||
objects = QuerySetDescriptor()
|
objects = QuerySetDescriptor()
|
||||||
ttl = TTLDescriptor()
|
ttl = TTLDescriptor()
|
||||||
consistency = ConsistencyDescriptor()
|
consistency = ConsistencyDescriptor()
|
||||||
|
iff = TransactionDescriptor()
|
||||||
|
|
||||||
# custom timestamps, see USING TIMESTAMP X
|
# custom timestamps, see USING TIMESTAMP X
|
||||||
timestamp = TimestampDescriptor()
|
timestamp = TimestampDescriptor()
|
||||||
@@ -301,6 +329,7 @@ class BaseModel(object):
|
|||||||
self._values = {}
|
self._values = {}
|
||||||
self._ttl = self.__default_ttl__
|
self._ttl = self.__default_ttl__
|
||||||
self._timestamp = None
|
self._timestamp = None
|
||||||
|
self._transaction = None
|
||||||
|
|
||||||
for name, column in self._columns.items():
|
for name, column in self._columns.items():
|
||||||
value = values.get(name, None)
|
value = values.get(name, None)
|
||||||
@@ -550,7 +579,8 @@ class BaseModel(object):
|
|||||||
ttl=self._ttl,
|
ttl=self._ttl,
|
||||||
timestamp=self._timestamp,
|
timestamp=self._timestamp,
|
||||||
consistency=self.__consistency__,
|
consistency=self.__consistency__,
|
||||||
if_not_exists=self._if_not_exists).save()
|
if_not_exists=self._if_not_exists,
|
||||||
|
transaction=self._transaction).save()
|
||||||
|
|
||||||
#reset the value managers
|
#reset the value managers
|
||||||
for v in self._values.values():
|
for v in self._values.values():
|
||||||
@@ -588,7 +618,8 @@ class BaseModel(object):
|
|||||||
batch=self._batch,
|
batch=self._batch,
|
||||||
ttl=self._ttl,
|
ttl=self._ttl,
|
||||||
timestamp=self._timestamp,
|
timestamp=self._timestamp,
|
||||||
consistency=self.__consistency__).update()
|
consistency=self.__consistency__,
|
||||||
|
transaction=self._transaction).update()
|
||||||
|
|
||||||
#reset the value managers
|
#reset the value managers
|
||||||
for v in self._values.values():
|
for v in self._values.values():
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from cqlengine import BaseContainerColumn, Map, columns
|
|||||||
from cqlengine.columns import Counter, List, Set
|
from cqlengine.columns import Counter, List, Set
|
||||||
|
|
||||||
from cqlengine.connection import execute
|
from cqlengine.connection import execute
|
||||||
|
|
||||||
from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException
|
from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException
|
||||||
from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin
|
from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin
|
||||||
|
|
||||||
@@ -13,7 +12,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix
|
|||||||
#http://www.datastax.com/docs/1.1/references/cql/index
|
#http://www.datastax.com/docs/1.1/references/cql/index
|
||||||
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
|
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
|
||||||
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
|
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
|
||||||
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause
|
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause
|
||||||
|
|
||||||
|
|
||||||
class QueryException(CQLEngineException): pass
|
class QueryException(CQLEngineException): pass
|
||||||
@@ -206,6 +205,9 @@ class AbstractQuerySet(object):
|
|||||||
#Where clause filters
|
#Where clause filters
|
||||||
self._where = []
|
self._where = []
|
||||||
|
|
||||||
|
# Transaction clause filters
|
||||||
|
self._transaction = []
|
||||||
|
|
||||||
#ordering arguments
|
#ordering arguments
|
||||||
self._order = []
|
self._order = []
|
||||||
|
|
||||||
@@ -243,6 +245,8 @@ class AbstractQuerySet(object):
|
|||||||
return self._batch.add_query(q)
|
return self._batch.add_query(q)
|
||||||
else:
|
else:
|
||||||
result = execute(q, consistency_level=self._consistency)
|
result = execute(q, consistency_level=self._consistency)
|
||||||
|
if self._transaction:
|
||||||
|
check_applied(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
@@ -407,6 +411,49 @@ class AbstractQuerySet(object):
|
|||||||
else:
|
else:
|
||||||
raise QueryException("Can't parse '{}'".format(arg))
|
raise QueryException("Can't parse '{}'".format(arg))
|
||||||
|
|
||||||
|
def iff(self, *args, **kwargs):
|
||||||
|
"""Adds IF statements to queryset"""
|
||||||
|
if len([x for x in kwargs.values() if x is None]):
|
||||||
|
raise CQLEngineException("None values on iff are not allowed")
|
||||||
|
|
||||||
|
clone = copy.deepcopy(self)
|
||||||
|
for operator in args:
|
||||||
|
if not isinstance(operator, TransactionClause):
|
||||||
|
raise QueryException('{} is not a valid query operator'.format(operator))
|
||||||
|
clone._transaction.append(operator)
|
||||||
|
|
||||||
|
for col_name, val in kwargs.items():
|
||||||
|
exists = False
|
||||||
|
try:
|
||||||
|
column = self.model._get_column(col_name)
|
||||||
|
except KeyError:
|
||||||
|
if col_name == 'pk__token':
|
||||||
|
if not isinstance(val, Token):
|
||||||
|
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
||||||
|
column = columns._PartitionKeysToken(self.model)
|
||||||
|
quote_field = False
|
||||||
|
else:
|
||||||
|
raise QueryException("Can't resolve column name: '{}'".format(col_name))
|
||||||
|
|
||||||
|
if isinstance(val, Token):
|
||||||
|
if col_name != 'pk__token':
|
||||||
|
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
|
||||||
|
partition_columns = column.partition_columns
|
||||||
|
if len(partition_columns) != len(val.value):
|
||||||
|
raise QueryException(
|
||||||
|
'Token() received {} arguments but model has {} partition keys'.format(
|
||||||
|
len(val.value), len(partition_columns)))
|
||||||
|
val.set_columns(partition_columns)
|
||||||
|
|
||||||
|
if isinstance(val, BaseQueryFunction) or exists is True:
|
||||||
|
query_val = val
|
||||||
|
else:
|
||||||
|
query_val = column.to_database(val)
|
||||||
|
|
||||||
|
clone._transaction.append(TransactionClause(col_name, query_val))
|
||||||
|
|
||||||
|
return clone
|
||||||
|
|
||||||
def filter(self, *args, **kwargs):
|
def filter(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Adds WHERE arguments to the queryset, returning a new queryset
|
Adds WHERE arguments to the queryset, returning a new queryset
|
||||||
@@ -732,7 +779,8 @@ class ModelQuerySet(AbstractQuerySet):
|
|||||||
return
|
return
|
||||||
|
|
||||||
nulled_columns = set()
|
nulled_columns = set()
|
||||||
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp)
|
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
|
||||||
|
timestamp=self._timestamp, transactions=self._transaction)
|
||||||
for name, val in values.items():
|
for name, val in values.items():
|
||||||
col_name, col_op = self._parse_filter_arg(name)
|
col_name, col_op = self._parse_filter_arg(name)
|
||||||
col = self.model._columns.get(col_name)
|
col = self.model._columns.get(col_name)
|
||||||
@@ -787,7 +835,7 @@ class DMLQuery(object):
|
|||||||
_timestamp = None
|
_timestamp = None
|
||||||
_if_not_exists = False
|
_if_not_exists = False
|
||||||
|
|
||||||
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
|
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.column_family_name = self.model.column_family_name()
|
self.column_family_name = self.model.column_family_name()
|
||||||
self.instance = instance
|
self.instance = instance
|
||||||
@@ -796,15 +844,15 @@ class DMLQuery(object):
|
|||||||
self._consistency = consistency
|
self._consistency = consistency
|
||||||
self._timestamp = timestamp
|
self._timestamp = timestamp
|
||||||
self._if_not_exists = if_not_exists
|
self._if_not_exists = if_not_exists
|
||||||
|
self._transaction = transaction
|
||||||
|
|
||||||
def _execute(self, q):
|
def _execute(self, q):
|
||||||
if self._batch:
|
if self._batch:
|
||||||
return self._batch.add_query(q)
|
return self._batch.add_query(q)
|
||||||
else:
|
else:
|
||||||
tmp = execute(q, consistency_level=self._consistency)
|
tmp = execute(q, consistency_level=self._consistency)
|
||||||
if self._if_not_exists:
|
if self._if_not_exists or self._transaction:
|
||||||
check_applied(tmp)
|
check_applied(tmp)
|
||||||
|
|
||||||
return tmp
|
return tmp
|
||||||
|
|
||||||
def batch(self, batch_obj):
|
def batch(self, batch_obj):
|
||||||
@@ -850,7 +898,8 @@ class DMLQuery(object):
|
|||||||
raise CQLEngineException("DML Query intance attribute is None")
|
raise CQLEngineException("DML Query intance attribute is None")
|
||||||
assert type(self.instance) == self.model
|
assert type(self.instance) == self.model
|
||||||
static_update_only = True
|
static_update_only = True
|
||||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp)
|
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
|
||||||
|
timestamp=self._timestamp, transactions=self._transaction)
|
||||||
#get defined fields and their column names
|
#get defined fields and their column names
|
||||||
for name, col in self.model._columns.items():
|
for name, col in self.model._columns.items():
|
||||||
if not col.is_primary_key:
|
if not col.is_primary_key:
|
||||||
|
|||||||
@@ -139,6 +139,16 @@ class AssignmentClause(BaseClause):
|
|||||||
return self.field, self.context_id
|
return self.field, self.context_id
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionClause(BaseClause):
|
||||||
|
""" A single variable iff statement """
|
||||||
|
|
||||||
|
def __unicode__(self):
|
||||||
|
return u'"{}" = %({})s'.format(self.field, self.context_id)
|
||||||
|
|
||||||
|
def insert_tuple(self):
|
||||||
|
return self.field, self.context_id
|
||||||
|
|
||||||
|
|
||||||
class ContainerUpdateClause(AssignmentClause):
|
class ContainerUpdateClause(AssignmentClause):
|
||||||
|
|
||||||
def __init__(self, field, value, operation=None, previous=None, column=None):
|
def __init__(self, field, value, operation=None, previous=None, column=None):
|
||||||
@@ -675,6 +685,26 @@ class InsertStatement(AssignmentStatement):
|
|||||||
class UpdateStatement(AssignmentStatement):
|
class UpdateStatement(AssignmentStatement):
|
||||||
""" an cql update select statement """
|
""" an cql update select statement """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
table,
|
||||||
|
assignments=None,
|
||||||
|
consistency=None,
|
||||||
|
where=None,
|
||||||
|
ttl=None,
|
||||||
|
timestamp=None,
|
||||||
|
transactions=None):
|
||||||
|
super(UpdateStatement, self). __init__(table,
|
||||||
|
assignments=assignments,
|
||||||
|
consistency=consistency,
|
||||||
|
where=where,
|
||||||
|
ttl=ttl,
|
||||||
|
timestamp=timestamp)
|
||||||
|
|
||||||
|
# Add iff statements
|
||||||
|
self.transactions = []
|
||||||
|
for transaction in transactions or []:
|
||||||
|
self.add_transaction_clause(transaction)
|
||||||
|
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
qs = ['UPDATE', self.table]
|
qs = ['UPDATE', self.table]
|
||||||
|
|
||||||
@@ -695,8 +725,39 @@ class UpdateStatement(AssignmentStatement):
|
|||||||
if self.where_clauses:
|
if self.where_clauses:
|
||||||
qs += [self._where]
|
qs += [self._where]
|
||||||
|
|
||||||
|
if len(self.transactions) > 0:
|
||||||
|
qs += [self._get_transactions()]
|
||||||
|
|
||||||
return ' '.join(qs)
|
return ' '.join(qs)
|
||||||
|
|
||||||
|
def add_transaction_clause(self, clause):
|
||||||
|
"""
|
||||||
|
Adds a iff clause to this statement
|
||||||
|
|
||||||
|
:param clause: The clause that will be added to the iff statement
|
||||||
|
:type clause: TransactionClause
|
||||||
|
"""
|
||||||
|
if not isinstance(clause, TransactionClause):
|
||||||
|
raise StatementException('only instances of AssignmentClause can be added to statements')
|
||||||
|
clause.set_context_id(self.context_counter)
|
||||||
|
self.context_counter += clause.get_context_size()
|
||||||
|
self.transactions.append(clause)
|
||||||
|
|
||||||
|
def get_context(self):
|
||||||
|
ctx = super(UpdateStatement, self).get_context()
|
||||||
|
for clause in self.transactions or []:
|
||||||
|
clause.update_context(ctx)
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
def _get_transactions(self):
|
||||||
|
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
|
||||||
|
|
||||||
|
def update_context_id(self, i):
|
||||||
|
super(UpdateStatement, self).update_context_id(i)
|
||||||
|
for transaction in self.transactions:
|
||||||
|
transaction.set_context_id(self.context_counter)
|
||||||
|
self.context_counter += transaction.get_context_size()
|
||||||
|
|
||||||
|
|
||||||
class DeleteStatement(BaseCQLStatement):
|
class DeleteStatement(BaseCQLStatement):
|
||||||
""" a cql delete statement """
|
""" a cql delete statement """
|
||||||
|
|||||||
97
cqlengine/tests/test_transaction.py
Normal file
97
cqlengine/tests/test_transaction.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
__author__ = 'Tim Martin'
|
||||||
|
from cqlengine.management import sync_table, drop_table
|
||||||
|
from cqlengine.tests.base import BaseCassEngTestCase
|
||||||
|
from cqlengine.models import Model
|
||||||
|
from cqlengine.exceptions import LWTException
|
||||||
|
from uuid import uuid4
|
||||||
|
from cqlengine import columns, BatchQuery
|
||||||
|
import mock
|
||||||
|
from cqlengine import ALL, BatchQuery
|
||||||
|
from cqlengine.statements import TransactionClause
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransactionModel(Model):
|
||||||
|
__keyspace__ = 'test'
|
||||||
|
id = columns.UUID(primary_key=True, default=lambda:uuid4())
|
||||||
|
count = columns.Integer()
|
||||||
|
text = columns.Text(required=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransaction(BaseCassEngTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super(TestTransaction, cls).setUpClass()
|
||||||
|
sync_table(TestTransactionModel)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
super(TestTransaction, cls).tearDownClass()
|
||||||
|
drop_table(TestTransactionModel)
|
||||||
|
|
||||||
|
def test_update_using_transaction(self):
|
||||||
|
t = TestTransactionModel.create(text='blah blah')
|
||||||
|
t.text = 'new blah'
|
||||||
|
with mock.patch.object(self.session, 'execute') as m:
|
||||||
|
t.iff(text='blah blah').save()
|
||||||
|
|
||||||
|
args = m.call_args
|
||||||
|
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
|
||||||
|
|
||||||
|
def test_update_transaction_success(self):
|
||||||
|
t = TestTransactionModel.create(text='blah blah', count=5)
|
||||||
|
id = t.id
|
||||||
|
t.text = 'new blah'
|
||||||
|
t.iff(text='blah blah').save()
|
||||||
|
|
||||||
|
updated = TestTransactionModel.objects(id=id).first()
|
||||||
|
self.assertEqual(updated.count, 5)
|
||||||
|
self.assertEqual(updated.text, 'new blah')
|
||||||
|
|
||||||
|
def test_update_failure(self):
|
||||||
|
t = TestTransactionModel.create(text='blah blah')
|
||||||
|
t.text = 'new blah'
|
||||||
|
t = t.iff(text='something wrong')
|
||||||
|
self.assertRaises(LWTException, t.save)
|
||||||
|
|
||||||
|
def test_blind_update(self):
|
||||||
|
t = TestTransactionModel.create(text='blah blah')
|
||||||
|
t.text = 'something else'
|
||||||
|
uid = t.id
|
||||||
|
|
||||||
|
with mock.patch.object(self.session, 'execute') as m:
|
||||||
|
TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
|
||||||
|
|
||||||
|
args = m.call_args
|
||||||
|
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
|
||||||
|
|
||||||
|
def test_blind_update_fail(self):
|
||||||
|
t = TestTransactionModel.create(text='blah blah')
|
||||||
|
t.text = 'something else'
|
||||||
|
uid = t.id
|
||||||
|
qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!')
|
||||||
|
self.assertRaises(LWTException, qs.update, text='this will never work')
|
||||||
|
|
||||||
|
def test_transaction_clause(self):
|
||||||
|
tc = TransactionClause('some_value', 23)
|
||||||
|
tc.set_context_id(3)
|
||||||
|
|
||||||
|
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
|
||||||
|
self.assertEqual('"some_value" = %(3)s', str(tc))
|
||||||
|
|
||||||
|
def test_batch_update_transaction(self):
|
||||||
|
t = TestTransactionModel.create(text='something', count=5)
|
||||||
|
id = t.id
|
||||||
|
with BatchQuery() as b:
|
||||||
|
t.batch(b).iff(count=5).update(text='something else')
|
||||||
|
|
||||||
|
updated = TestTransactionModel.objects(id=id).first()
|
||||||
|
self.assertEqual(updated.text, 'something else')
|
||||||
|
|
||||||
|
b = BatchQuery()
|
||||||
|
updated.batch(b).iff(count=6).update(text='and another thing')
|
||||||
|
self.assertRaises(LWTException, b.execute)
|
||||||
|
|
||||||
|
updated = TestTransactionModel.objects(id=id).first()
|
||||||
|
self.assertEqual(updated.text, 'something else')
|
||||||
@@ -207,6 +207,21 @@ Model Methods
|
|||||||
|
|
||||||
This method is supported on Cassandra 2.0 or later.
|
This method is supported on Cassandra 2.0 or later.
|
||||||
|
|
||||||
|
.. method:: iff(**values)
|
||||||
|
|
||||||
|
Checks to ensure that the values specified are correct on the Cassandra cluster.
|
||||||
|
Simply specify the column(s) and the expected value(s). As with if_not_exists,
|
||||||
|
this incurs a performance cost.
|
||||||
|
|
||||||
|
If the insertion isn't applied, a LWTException is raised
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
t = TestTransactionModel(text='some text', count=5)
|
||||||
|
try:
|
||||||
|
t.iff(count=5).update('other text')
|
||||||
|
except LWTException as e:
|
||||||
|
# handle failure
|
||||||
|
|
||||||
.. method:: update(**values)
|
.. method:: update(**values)
|
||||||
|
|
||||||
Performs an update on the model instance. You can pass in values to set on the model
|
Performs an update on the model instance. You can pass in values to set on the model
|
||||||
|
|||||||
Reference in New Issue
Block a user