Merge pull request #284 from timmartin19/master

UPDATE transaction statements
This commit is contained in:
Jon Haddad
2014-11-14 11:51:16 -08:00
5 changed files with 263 additions and 10 deletions

View File

@@ -74,6 +74,33 @@ class QuerySetDescriptor(object):
raise NotImplementedError raise NotImplementedError
class TransactionDescriptor(object):
"""
returns a query set descriptor
"""
def __get__(self, instance, model):
if instance:
def transaction_setter(*prepared_transaction, **unprepared_transactions):
if len(prepared_transaction) > 0:
transactions = prepared_transaction[0]
else:
transactions = instance.objects.iff(**unprepared_transactions)._transaction
instance._transaction = transactions
return instance
return transaction_setter
qs = model.__queryset__(model)
def transaction_setter(**unprepared_transactions):
transactions = model.objects.iff(**unprepared_transactions)._transaction
qs._transaction = transactions
return qs
return transaction_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class TTLDescriptor(object): class TTLDescriptor(object):
""" """
returns a query set descriptor returns a query set descriptor
@@ -239,6 +266,7 @@ class BaseModel(object):
objects = QuerySetDescriptor() objects = QuerySetDescriptor()
ttl = TTLDescriptor() ttl = TTLDescriptor()
consistency = ConsistencyDescriptor() consistency = ConsistencyDescriptor()
iff = TransactionDescriptor()
# custom timestamps, see USING TIMESTAMP X # custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor() timestamp = TimestampDescriptor()
@@ -301,6 +329,7 @@ class BaseModel(object):
self._values = {} self._values = {}
self._ttl = self.__default_ttl__ self._ttl = self.__default_ttl__
self._timestamp = None self._timestamp = None
self._transaction = None
for name, column in self._columns.items(): for name, column in self._columns.items():
value = values.get(name, None) value = values.get(name, None)
@@ -550,7 +579,8 @@ class BaseModel(object):
ttl=self._ttl, ttl=self._ttl,
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__, consistency=self.__consistency__,
if_not_exists=self._if_not_exists).save() if_not_exists=self._if_not_exists,
transaction=self._transaction).save()
#reset the value managers #reset the value managers
for v in self._values.values(): for v in self._values.values():
@@ -588,7 +618,8 @@ class BaseModel(object):
batch=self._batch, batch=self._batch,
ttl=self._ttl, ttl=self._ttl,
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__).update() consistency=self.__consistency__,
transaction=self._transaction).update()
#reset the value managers #reset the value managers
for v in self._values.values(): for v in self._values.values():

View File

@@ -5,7 +5,6 @@ from cqlengine import BaseContainerColumn, Map, columns
from cqlengine.columns import Counter, List, Set from cqlengine.columns import Counter, List, Set
from cqlengine.connection import execute from cqlengine.connection import execute
from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException
from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin
@@ -13,7 +12,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix
#http://www.datastax.com/docs/1.1/references/cql/index #http://www.datastax.com/docs/1.1/references/cql/index
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause
class QueryException(CQLEngineException): pass class QueryException(CQLEngineException): pass
@@ -206,6 +205,9 @@ class AbstractQuerySet(object):
#Where clause filters #Where clause filters
self._where = [] self._where = []
# Transaction clause filters
self._transaction = []
#ordering arguments #ordering arguments
self._order = [] self._order = []
@@ -243,6 +245,8 @@ class AbstractQuerySet(object):
return self._batch.add_query(q) return self._batch.add_query(q)
else: else:
result = execute(q, consistency_level=self._consistency) result = execute(q, consistency_level=self._consistency)
if self._transaction:
check_applied(result)
return result return result
def __unicode__(self): def __unicode__(self):
@@ -407,6 +411,49 @@ class AbstractQuerySet(object):
else: else:
raise QueryException("Can't parse '{}'".format(arg)) raise QueryException("Can't parse '{}'".format(arg))
def iff(self, *args, **kwargs):
"""Adds IF statements to queryset"""
if len([x for x in kwargs.values() if x is None]):
raise CQLEngineException("None values on iff are not allowed")
clone = copy.deepcopy(self)
for operator in args:
if not isinstance(operator, TransactionClause):
raise QueryException('{} is not a valid query operator'.format(operator))
clone._transaction.append(operator)
for col_name, val in kwargs.items():
exists = False
try:
column = self.model._get_column(col_name)
except KeyError:
if col_name == 'pk__token':
if not isinstance(val, Token):
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
column = columns._PartitionKeysToken(self.model)
quote_field = False
else:
raise QueryException("Can't resolve column name: '{}'".format(col_name))
if isinstance(val, Token):
if col_name != 'pk__token':
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
partition_columns = column.partition_columns
if len(partition_columns) != len(val.value):
raise QueryException(
'Token() received {} arguments but model has {} partition keys'.format(
len(val.value), len(partition_columns)))
val.set_columns(partition_columns)
if isinstance(val, BaseQueryFunction) or exists is True:
query_val = val
else:
query_val = column.to_database(val)
clone._transaction.append(TransactionClause(col_name, query_val))
return clone
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
""" """
Adds WHERE arguments to the queryset, returning a new queryset Adds WHERE arguments to the queryset, returning a new queryset
@@ -732,7 +779,8 @@ class ModelQuerySet(AbstractQuerySet):
return return
nulled_columns = set() nulled_columns = set()
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp) us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction)
for name, val in values.items(): for name, val in values.items():
col_name, col_op = self._parse_filter_arg(name) col_name, col_op = self._parse_filter_arg(name)
col = self.model._columns.get(col_name) col = self.model._columns.get(col_name)
@@ -787,7 +835,7 @@ class DMLQuery(object):
_timestamp = None _timestamp = None
_if_not_exists = False _if_not_exists = False
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False): def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
self.model = model self.model = model
self.column_family_name = self.model.column_family_name() self.column_family_name = self.model.column_family_name()
self.instance = instance self.instance = instance
@@ -796,15 +844,15 @@ class DMLQuery(object):
self._consistency = consistency self._consistency = consistency
self._timestamp = timestamp self._timestamp = timestamp
self._if_not_exists = if_not_exists self._if_not_exists = if_not_exists
self._transaction = transaction
def _execute(self, q): def _execute(self, q):
if self._batch: if self._batch:
return self._batch.add_query(q) return self._batch.add_query(q)
else: else:
tmp = execute(q, consistency_level=self._consistency) tmp = execute(q, consistency_level=self._consistency)
if self._if_not_exists: if self._if_not_exists or self._transaction:
check_applied(tmp) check_applied(tmp)
return tmp return tmp
def batch(self, batch_obj): def batch(self, batch_obj):
@@ -850,7 +898,8 @@ class DMLQuery(object):
raise CQLEngineException("DML Query intance attribute is None") raise CQLEngineException("DML Query intance attribute is None")
assert type(self.instance) == self.model assert type(self.instance) == self.model
static_update_only = True static_update_only = True
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction)
#get defined fields and their column names #get defined fields and their column names
for name, col in self.model._columns.items(): for name, col in self.model._columns.items():
if not col.is_primary_key: if not col.is_primary_key:

View File

@@ -139,6 +139,16 @@ class AssignmentClause(BaseClause):
return self.field, self.context_id return self.field, self.context_id
class TransactionClause(BaseClause):
""" A single variable iff statement """
def __unicode__(self):
return u'"{}" = %({})s'.format(self.field, self.context_id)
def insert_tuple(self):
return self.field, self.context_id
class ContainerUpdateClause(AssignmentClause): class ContainerUpdateClause(AssignmentClause):
def __init__(self, field, value, operation=None, previous=None, column=None): def __init__(self, field, value, operation=None, previous=None, column=None):
@@ -675,6 +685,26 @@ class InsertStatement(AssignmentStatement):
class UpdateStatement(AssignmentStatement): class UpdateStatement(AssignmentStatement):
""" an cql update select statement """ """ an cql update select statement """
def __init__(self,
table,
assignments=None,
consistency=None,
where=None,
ttl=None,
timestamp=None,
transactions=None):
super(UpdateStatement, self). __init__(table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
# Add iff statements
self.transactions = []
for transaction in transactions or []:
self.add_transaction_clause(transaction)
def __unicode__(self): def __unicode__(self):
qs = ['UPDATE', self.table] qs = ['UPDATE', self.table]
@@ -695,8 +725,39 @@ class UpdateStatement(AssignmentStatement):
if self.where_clauses: if self.where_clauses:
qs += [self._where] qs += [self._where]
if len(self.transactions) > 0:
qs += [self._get_transactions()]
return ' '.join(qs) return ' '.join(qs)
def add_transaction_clause(self, clause):
"""
Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement
:type clause: TransactionClause
"""
if not isinstance(clause, TransactionClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.transactions.append(clause)
def get_context(self):
ctx = super(UpdateStatement, self).get_context()
for clause in self.transactions or []:
clause.update_context(ctx)
return ctx
def _get_transactions(self):
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
def update_context_id(self, i):
super(UpdateStatement, self).update_context_id(i)
for transaction in self.transactions:
transaction.set_context_id(self.context_counter)
self.context_counter += transaction.get_context_size()
class DeleteStatement(BaseCQLStatement): class DeleteStatement(BaseCQLStatement):
""" a cql delete statement """ """ a cql delete statement """

View File

@@ -0,0 +1,97 @@
__author__ = 'Tim Martin'
from cqlengine.management import sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from cqlengine.exceptions import LWTException
from uuid import uuid4
from cqlengine import columns, BatchQuery
import mock
from cqlengine import ALL, BatchQuery
from cqlengine.statements import TransactionClause
import six
class TestTransactionModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
class TestTransaction(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestTransaction, cls).setUpClass()
sync_table(TestTransactionModel)
@classmethod
def tearDownClass(cls):
super(TestTransaction, cls).tearDownClass()
drop_table(TestTransactionModel)
def test_update_using_transaction(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'new blah'
with mock.patch.object(self.session, 'execute') as m:
t.iff(text='blah blah').save()
args = m.call_args
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
def test_update_transaction_success(self):
t = TestTransactionModel.create(text='blah blah', count=5)
id = t.id
t.text = 'new blah'
t.iff(text='blah blah').save()
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.count, 5)
self.assertEqual(updated.text, 'new blah')
def test_update_failure(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'new blah'
t = t.iff(text='something wrong')
self.assertRaises(LWTException, t.save)
def test_blind_update(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
with mock.patch.object(self.session, 'execute') as m:
TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
args = m.call_args
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
def test_blind_update_fail(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!')
self.assertRaises(LWTException, qs.update, text='this will never work')
def test_transaction_clause(self):
tc = TransactionClause('some_value', 23)
tc.set_context_id(3)
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
self.assertEqual('"some_value" = %(3)s', str(tc))
def test_batch_update_transaction(self):
t = TestTransactionModel.create(text='something', count=5)
id = t.id
with BatchQuery() as b:
t.batch(b).iff(count=5).update(text='something else')
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')
b = BatchQuery()
updated.batch(b).iff(count=6).update(text='and another thing')
self.assertRaises(LWTException, b.execute)
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')

View File

@@ -207,6 +207,21 @@ Model Methods
This method is supported on Cassandra 2.0 or later. This method is supported on Cassandra 2.0 or later.
.. method:: iff(**values)
Checks to ensure that the values specified are correct on the Cassandra cluster.
Simply specify the column(s) and the expected value(s). As with if_not_exists,
this incurs a performance cost.
If the insertion isn't applied, a LWTException is raised
.. code-block::
t = TestTransactionModel(text='some text', count=5)
try:
t.iff(count=5).update('other text')
except LWTException as e:
# handle failure
.. method:: update(**values) .. method:: update(**values)
Performs an update on the model instance. You can pass in values to set on the model Performs an update on the model instance. You can pass in values to set on the model