Merge pull request #284 from timmartin19/master

UPDATE transaction statements
This commit is contained in:
Jon Haddad
2014-11-14 11:51:16 -08:00
5 changed files with 263 additions and 10 deletions

View File

@@ -74,6 +74,33 @@ class QuerySetDescriptor(object):
raise NotImplementedError
class TransactionDescriptor(object):
"""
returns a query set descriptor
"""
def __get__(self, instance, model):
if instance:
def transaction_setter(*prepared_transaction, **unprepared_transactions):
if len(prepared_transaction) > 0:
transactions = prepared_transaction[0]
else:
transactions = instance.objects.iff(**unprepared_transactions)._transaction
instance._transaction = transactions
return instance
return transaction_setter
qs = model.__queryset__(model)
def transaction_setter(**unprepared_transactions):
transactions = model.objects.iff(**unprepared_transactions)._transaction
qs._transaction = transactions
return qs
return transaction_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class TTLDescriptor(object):
"""
returns a query set descriptor
@@ -239,6 +266,7 @@ class BaseModel(object):
objects = QuerySetDescriptor()
ttl = TTLDescriptor()
consistency = ConsistencyDescriptor()
iff = TransactionDescriptor()
# custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor()
@@ -301,9 +329,10 @@ class BaseModel(object):
self._values = {}
self._ttl = self.__default_ttl__
self._timestamp = None
self._transaction = None
for name, column in self._columns.items():
value = values.get(name, None)
value = values.get(name, None)
if value is not None or isinstance(column, columns.BaseContainerColumn):
value = column.to_python(value)
value_mngr = column.value_manager(self, column, value)
@@ -550,7 +579,8 @@ class BaseModel(object):
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__,
if_not_exists=self._if_not_exists).save()
if_not_exists=self._if_not_exists,
transaction=self._transaction).save()
#reset the value managers
for v in self._values.values():
@@ -588,7 +618,8 @@ class BaseModel(object):
batch=self._batch,
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__).update()
consistency=self.__consistency__,
transaction=self._transaction).update()
#reset the value managers
for v in self._values.values():

View File

@@ -5,7 +5,6 @@ from cqlengine import BaseContainerColumn, Map, columns
from cqlengine.columns import Counter, List, Set
from cqlengine.connection import execute
from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException
from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin
@@ -13,7 +12,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix
#http://www.datastax.com/docs/1.1/references/cql/index
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause
class QueryException(CQLEngineException): pass
@@ -206,6 +205,9 @@ class AbstractQuerySet(object):
#Where clause filters
self._where = []
# Transaction clause filters
self._transaction = []
#ordering arguments
self._order = []
@@ -243,6 +245,8 @@ class AbstractQuerySet(object):
return self._batch.add_query(q)
else:
result = execute(q, consistency_level=self._consistency)
if self._transaction:
check_applied(result)
return result
def __unicode__(self):
@@ -407,6 +411,49 @@ class AbstractQuerySet(object):
else:
raise QueryException("Can't parse '{}'".format(arg))
def iff(self, *args, **kwargs):
"""Adds IF statements to queryset"""
if len([x for x in kwargs.values() if x is None]):
raise CQLEngineException("None values on iff are not allowed")
clone = copy.deepcopy(self)
for operator in args:
if not isinstance(operator, TransactionClause):
raise QueryException('{} is not a valid query operator'.format(operator))
clone._transaction.append(operator)
for col_name, val in kwargs.items():
exists = False
try:
column = self.model._get_column(col_name)
except KeyError:
if col_name == 'pk__token':
if not isinstance(val, Token):
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
column = columns._PartitionKeysToken(self.model)
quote_field = False
else:
raise QueryException("Can't resolve column name: '{}'".format(col_name))
if isinstance(val, Token):
if col_name != 'pk__token':
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
partition_columns = column.partition_columns
if len(partition_columns) != len(val.value):
raise QueryException(
'Token() received {} arguments but model has {} partition keys'.format(
len(val.value), len(partition_columns)))
val.set_columns(partition_columns)
if isinstance(val, BaseQueryFunction) or exists is True:
query_val = val
else:
query_val = column.to_database(val)
clone._transaction.append(TransactionClause(col_name, query_val))
return clone
def filter(self, *args, **kwargs):
"""
Adds WHERE arguments to the queryset, returning a new queryset
@@ -732,7 +779,8 @@ class ModelQuerySet(AbstractQuerySet):
return
nulled_columns = set()
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp)
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction)
for name, val in values.items():
col_name, col_op = self._parse_filter_arg(name)
col = self.model._columns.get(col_name)
@@ -787,7 +835,7 @@ class DMLQuery(object):
_timestamp = None
_if_not_exists = False
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
self.model = model
self.column_family_name = self.model.column_family_name()
self.instance = instance
@@ -796,15 +844,15 @@ class DMLQuery(object):
self._consistency = consistency
self._timestamp = timestamp
self._if_not_exists = if_not_exists
self._transaction = transaction
def _execute(self, q):
if self._batch:
return self._batch.add_query(q)
else:
tmp = execute(q, consistency_level=self._consistency)
if self._if_not_exists:
if self._if_not_exists or self._transaction:
check_applied(tmp)
return tmp
def batch(self, batch_obj):
@@ -850,7 +898,8 @@ class DMLQuery(object):
raise CQLEngineException("DML Query intance attribute is None")
assert type(self.instance) == self.model
static_update_only = True
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp)
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction)
#get defined fields and their column names
for name, col in self.model._columns.items():
if not col.is_primary_key:

View File

@@ -139,6 +139,16 @@ class AssignmentClause(BaseClause):
return self.field, self.context_id
class TransactionClause(BaseClause):
""" A single variable iff statement """
def __unicode__(self):
return u'"{}" = %({})s'.format(self.field, self.context_id)
def insert_tuple(self):
return self.field, self.context_id
class ContainerUpdateClause(AssignmentClause):
def __init__(self, field, value, operation=None, previous=None, column=None):
@@ -675,6 +685,26 @@ class InsertStatement(AssignmentStatement):
class UpdateStatement(AssignmentStatement):
""" an cql update select statement """
def __init__(self,
table,
assignments=None,
consistency=None,
where=None,
ttl=None,
timestamp=None,
transactions=None):
super(UpdateStatement, self). __init__(table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
# Add iff statements
self.transactions = []
for transaction in transactions or []:
self.add_transaction_clause(transaction)
def __unicode__(self):
qs = ['UPDATE', self.table]
@@ -695,8 +725,39 @@ class UpdateStatement(AssignmentStatement):
if self.where_clauses:
qs += [self._where]
if len(self.transactions) > 0:
qs += [self._get_transactions()]
return ' '.join(qs)
def add_transaction_clause(self, clause):
"""
Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement
:type clause: TransactionClause
"""
if not isinstance(clause, TransactionClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.transactions.append(clause)
def get_context(self):
ctx = super(UpdateStatement, self).get_context()
for clause in self.transactions or []:
clause.update_context(ctx)
return ctx
def _get_transactions(self):
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
def update_context_id(self, i):
super(UpdateStatement, self).update_context_id(i)
for transaction in self.transactions:
transaction.set_context_id(self.context_counter)
self.context_counter += transaction.get_context_size()
class DeleteStatement(BaseCQLStatement):
""" a cql delete statement """

View File

@@ -0,0 +1,97 @@
__author__ = 'Tim Martin'
from cqlengine.management import sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from cqlengine.exceptions import LWTException
from uuid import uuid4
from cqlengine import columns, BatchQuery
import mock
from cqlengine import ALL, BatchQuery
from cqlengine.statements import TransactionClause
import six
class TestTransactionModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
class TestTransaction(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestTransaction, cls).setUpClass()
sync_table(TestTransactionModel)
@classmethod
def tearDownClass(cls):
super(TestTransaction, cls).tearDownClass()
drop_table(TestTransactionModel)
def test_update_using_transaction(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'new blah'
with mock.patch.object(self.session, 'execute') as m:
t.iff(text='blah blah').save()
args = m.call_args
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
def test_update_transaction_success(self):
t = TestTransactionModel.create(text='blah blah', count=5)
id = t.id
t.text = 'new blah'
t.iff(text='blah blah').save()
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.count, 5)
self.assertEqual(updated.text, 'new blah')
def test_update_failure(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'new blah'
t = t.iff(text='something wrong')
self.assertRaises(LWTException, t.save)
def test_blind_update(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
with mock.patch.object(self.session, 'execute') as m:
TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
args = m.call_args
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
def test_blind_update_fail(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!')
self.assertRaises(LWTException, qs.update, text='this will never work')
def test_transaction_clause(self):
tc = TransactionClause('some_value', 23)
tc.set_context_id(3)
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
self.assertEqual('"some_value" = %(3)s', str(tc))
def test_batch_update_transaction(self):
t = TestTransactionModel.create(text='something', count=5)
id = t.id
with BatchQuery() as b:
t.batch(b).iff(count=5).update(text='something else')
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')
b = BatchQuery()
updated.batch(b).iff(count=6).update(text='and another thing')
self.assertRaises(LWTException, b.execute)
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')

View File

@@ -207,6 +207,21 @@ Model Methods
This method is supported on Cassandra 2.0 or later.
.. method:: iff(**values)
Checks to ensure that the values specified are correct on the Cassandra cluster.
Simply specify the column(s) and the expected value(s). As with if_not_exists,
this incurs a performance cost.
If the insertion isn't applied, a LWTException is raised
.. code-block::
t = TestTransactionModel(text='some text', count=5)
try:
t.iff(count=5).update('other text')
except LWTException as e:
# handle failure
.. method:: update(**values)
Performs an update on the model instance. You can pass in values to set on the model