diff --git a/cqlengine/models.py b/cqlengine/models.py index 77812426..880f6207 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -74,6 +74,27 @@ class QuerySetDescriptor(object): raise NotImplementedError +class TransactionDescriptor(object): + """ + returns a query set descriptor + """ + def __get__(self, instance, model): + if instance: + def transaction_setter(transaction): + instance._transaction = transaction + return instance + return transaction_setter + qs = model.__queryset__(model) + + def transaction_setter(transaction): + qs._transaction = transaction + return instance + return transaction_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + class TTLDescriptor(object): """ returns a query set descriptor @@ -239,6 +260,7 @@ class BaseModel(object): objects = QuerySetDescriptor() ttl = TTLDescriptor() consistency = ConsistencyDescriptor() + transaction = TransactionDescriptor() # custom timestamps, see USING TIMESTAMP X timestamp = TimestampDescriptor() @@ -303,7 +325,7 @@ class BaseModel(object): self._timestamp = None for name, column in self._columns.items(): - value = values.get(name, None) + value = values.get(name, None) if value is not None or isinstance(column, columns.BaseContainerColumn): value = column.to_python(value) value_mngr = column.value_manager(self, column, value) @@ -531,6 +553,10 @@ class BaseModel(object): return cls.objects.filter(*args, **kwargs) + @classmethod + def transaction(cls, *args, **kwargs): + return cls.objects.transaction(*args, **kwargs) + @classmethod def get(cls, *args, **kwargs): return cls.objects.get(*args, **kwargs) diff --git a/cqlengine/query.py b/cqlengine/query.py index fd2f0072..8ee7e621 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -13,7 +13,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix #http://www.datastax.com/docs/1.1/references/cql/index from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator -from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause +from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause class QueryException(CQLEngineException): pass @@ -206,6 +206,9 @@ class AbstractQuerySet(object): #Where clause filters self._where = [] + # Transaction clause filters + self._transaction = [] + #ordering arguments self._order = [] @@ -407,6 +410,50 @@ class AbstractQuerySet(object): else: raise QueryException("Can't parse '{}'".format(arg)) + def transaction(self, *args, **kwargs): + """Adds IF statements to queryset""" + if len([x for x in kwargs.values() if x is None]): + raise CQLEngineException("None values on transaction are not allowed") + + clone = copy.deepcopy(self) + for operator in args: + if not isinstance(operator, TransactionClause): + raise QueryException('{} is not a valid query operator'.format(operator)) + clone._transaction.append(operator) + + for col_name, val in kwargs.items(): + try: + column = self.model._get_column(col_name) + except KeyError: + if col_name in ['exists', 'not_exists']: + pass + elif col_name == 'pk__token': + if not isinstance(val, Token): + raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") + column = columns._PartitionKeysToken(self.model) + quote_field = False + else: + raise QueryException("Can't resolve column name: '{}'".format(col_name)) + + if isinstance(val, Token): + if col_name != 'pk__token': + raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") + partition_columns = column.partition_columns + if len(partition_columns) != len(val.value): + raise QueryException( + 'Token() received {} arguments but model has {} partition keys'.format( + len(val.value), len(partition_columns))) + val.set_columns(partition_columns) + + if isinstance(val, BaseQueryFunction): + query_val = val + else: + query_val = column.to_database(val) + + clone._transaction.append(TransactionClause(col_name, query_val)) + + return clone + def filter(self, *args, **kwargs): """ Adds WHERE arguments to the queryset, returning a new queryset diff --git a/cqlengine/statements.py b/cqlengine/statements.py index a6cdf4b3..5cdef38a 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -139,6 +139,20 @@ class AssignmentClause(BaseClause): return self.field, self.context_id +class TransactionClause(BaseClause): + """ A single variable transaction statement """ + + def __unicode__(self): + if self.field == 'exists': + return u'EXISTS' + if self.field == 'not_exists': + return u'NOT EXISTS' + return u'"{}" = %({})s'.format(self.field, self.context_id) + + def insert_tuple(self): + return self.field, self.context_id + + class ContainerUpdateClause(AssignmentClause): def __init__(self, field, value, operation=None, previous=None, column=None): @@ -573,7 +587,8 @@ class AssignmentStatement(BaseCQLStatement): consistency=None, where=None, ttl=None, - timestamp=None): + timestamp=None, + transactions=None): super(AssignmentStatement, self).__init__( table, consistency=consistency, @@ -587,6 +602,11 @@ class AssignmentStatement(BaseCQLStatement): for assignment in assignments or []: self.add_assignment_clause(assignment) + # Add transaction statements + self.transactions = [] + for transaction in transactions or []: + self.add_transaction_clause(transaction) + def update_context_id(self, i): super(AssignmentStatement, self).update_context_id(i) for assignment in self.assignments: @@ -605,6 +625,19 @@ class AssignmentStatement(BaseCQLStatement): self.context_counter += clause.get_context_size() self.assignments.append(clause) + def add_transaction_clause(self, clause): + """ + Adds a transaction clause to this statement + + :param clause: The clause that will be added to the transaction statement + :type clause: TransactionClause + """ + if not isinstance(clause, TransactionClause): + raise StatementException('only instances of AssignmentClause can be added to statements') + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.transactions.append(clause) + @property def is_empty(self): return len(self.assignments) == 0 @@ -615,6 +648,9 @@ class AssignmentStatement(BaseCQLStatement): clause.update_context(ctx) return ctx + def _transactions(self): + return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) + class InsertStatement(AssignmentStatement): """ an cql insert select statement """ @@ -660,6 +696,9 @@ class InsertStatement(AssignmentStatement): if self.timestamp: qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)] + if len(self.transactions) > 0: + qs += [self._transactions] + return ' '.join(qs) @@ -686,6 +725,9 @@ class UpdateStatement(AssignmentStatement): if self.where_clauses: qs += [self._where] + if len(self.transactions) > 0: + qs += [self._transactions] + return ' '.join(qs) diff --git a/cqlengine/tests/model/test_transaction_statements.py b/cqlengine/tests/model/test_transaction_statements.py new file mode 100644 index 00000000..8afabb1e --- /dev/null +++ b/cqlengine/tests/model/test_transaction_statements.py @@ -0,0 +1,36 @@ +__author__ = 'Tim Martin' +from uuid import uuid4 + +from mock import patch +from cqlengine.exceptions import ValidationError + +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from cqlengine import columns +from cqlengine.management import sync_table, drop_table + + +class TestUpdateModel(Model): + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.UUID(primary_key=True, default=uuid4) + count = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + + +class ModelUpdateTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(ModelUpdateTests, cls).setUpClass() + sync_table(TestUpdateModel) + + @classmethod + def tearDownClass(cls): + super(ModelUpdateTests, cls).tearDownClass() + drop_table(TestUpdateModel) + + def test_transaction_insertion(self): + m = TestUpdateModel(count=5, text='something').transaction(exists=True) + m.save() + x = 10 \ No newline at end of file