Merge pull request #520 from datastax/249
PYTHON-249 - Fixing conditional deletes in cqlengine
This commit is contained in:
		| @@ -101,28 +101,28 @@ class QuerySetDescriptor(object): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class TransactionDescriptor(object): | ||||
| class ConditionalDescriptor(object): | ||||
|     """ | ||||
|     returns a query set descriptor | ||||
|     """ | ||||
|     def __get__(self, instance, model): | ||||
|         if instance: | ||||
|             def transaction_setter(*prepared_transaction, **unprepared_transactions): | ||||
|                 if len(prepared_transaction) > 0: | ||||
|                     transactions = prepared_transaction[0] | ||||
|             def conditional_setter(*prepared_conditional, **unprepared_conditionals): | ||||
|                 if len(prepared_conditional) > 0: | ||||
|                     conditionals = prepared_conditional[0] | ||||
|                 else: | ||||
|                     transactions = instance.objects.iff(**unprepared_transactions)._transaction | ||||
|                 instance._transaction = transactions | ||||
|                     conditionals = instance.objects.iff(**unprepared_conditionals)._conditional | ||||
|                 instance._conditional = conditionals | ||||
|                 return instance | ||||
|  | ||||
|             return transaction_setter | ||||
|             return conditional_setter | ||||
|         qs = model.__queryset__(model) | ||||
|  | ||||
|         def transaction_setter(**unprepared_transactions): | ||||
|             transactions = model.objects.iff(**unprepared_transactions)._transaction | ||||
|             qs._transaction = transactions | ||||
|         def conditional_setter(**unprepared_conditionals): | ||||
|             conditionals = model.objects.iff(**unprepared_conditionals)._conditional | ||||
|             qs._conditional = conditionals | ||||
|             return qs | ||||
|         return transaction_setter | ||||
|         return conditional_setter | ||||
|  | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         raise NotImplementedError | ||||
| @@ -314,7 +314,7 @@ class BaseModel(object): | ||||
|     objects = QuerySetDescriptor() | ||||
|     ttl = TTLDescriptor() | ||||
|     consistency = ConsistencyDescriptor() | ||||
|     iff = TransactionDescriptor() | ||||
|     iff = ConditionalDescriptor() | ||||
|  | ||||
|     # custom timestamps, see USING TIMESTAMP X | ||||
|     timestamp = TimestampDescriptor() | ||||
| @@ -352,7 +352,7 @@ class BaseModel(object): | ||||
|     def __init__(self, **values): | ||||
|         self._ttl = self.__default_ttl__ | ||||
|         self._timestamp = None | ||||
|         self._transaction = None | ||||
|         self._conditional = None | ||||
|         self._batch = None | ||||
|         self._timeout = connection.NOT_SET | ||||
|         self._is_persisted = False | ||||
| @@ -684,7 +684,7 @@ class BaseModel(object): | ||||
|                           timestamp=self._timestamp, | ||||
|                           consistency=self.__consistency__, | ||||
|                           if_not_exists=self._if_not_exists, | ||||
|                           transaction=self._transaction, | ||||
|                           conditional=self._conditional, | ||||
|                           timeout=self._timeout, | ||||
|                           if_exists=self._if_exists).save() | ||||
|  | ||||
| @@ -731,7 +731,7 @@ class BaseModel(object): | ||||
|                           ttl=self._ttl, | ||||
|                           timestamp=self._timestamp, | ||||
|                           consistency=self.__consistency__, | ||||
|                           transaction=self._transaction, | ||||
|                           conditional=self._conditional, | ||||
|                           timeout=self._timeout, | ||||
|                           if_exists=self._if_exists).update() | ||||
|  | ||||
| @@ -751,6 +751,7 @@ class BaseModel(object): | ||||
|                           timestamp=self._timestamp, | ||||
|                           consistency=self.__consistency__, | ||||
|                           timeout=self._timeout, | ||||
|                           conditional=self._conditional, | ||||
|                           if_exists=self._if_exists).delete() | ||||
|  | ||||
|     def get_changed_columns(self): | ||||
|   | ||||
| @@ -28,7 +28,7 @@ from cassandra.cqlengine.statements import (WhereClause, SelectStatement, Delete | ||||
|                                             UpdateStatement, AssignmentClause, InsertStatement, | ||||
|                                             BaseCQLStatement, MapUpdateClause, MapDeleteClause, | ||||
|                                             ListUpdateClause, SetUpdateClause, CounterUpdateClause, | ||||
|                                             TransactionClause) | ||||
|                                             ConditionalClause) | ||||
|  | ||||
|  | ||||
| class QueryException(CQLEngineException): | ||||
| @@ -43,7 +43,7 @@ class IfExistsWithCounterColumn(CQLEngineException): | ||||
|  | ||||
|  | ||||
| class LWTException(CQLEngineException): | ||||
|     """Lightweight transaction exception. | ||||
|     """Lightweight conditional exception. | ||||
|  | ||||
|     This exception will be raised when a write using an `IF` clause could not be | ||||
|     applied due to existing data violating the condition. The existing data is | ||||
| @@ -146,7 +146,7 @@ class BatchQuery(object): | ||||
|         :param batch_type: (optional) One of batch type values available through BatchType enum | ||||
|         :type batch_type: str or None | ||||
|         :param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied | ||||
|             to the batch transaction. | ||||
|             to the batch conditional. | ||||
|         :type timestamp: datetime or timedelta or None | ||||
|         :param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc) | ||||
|         :type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None. | ||||
| @@ -267,8 +267,8 @@ class AbstractQuerySet(object): | ||||
|         # Where clause filters | ||||
|         self._where = [] | ||||
|  | ||||
|         # Transaction clause filters | ||||
|         self._transaction = [] | ||||
|         # Conditional clause filters | ||||
|         self._conditional = [] | ||||
|  | ||||
|         # ordering arguments | ||||
|         self._order = [] | ||||
| @@ -314,7 +314,7 @@ class AbstractQuerySet(object): | ||||
|             return self._batch.add_query(q) | ||||
|         else: | ||||
|             result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) | ||||
|             if self._if_not_exists or self._if_exists or self._transaction: | ||||
|             if self._if_not_exists or self._if_exists or self._conditional: | ||||
|                 check_applied(result) | ||||
|             return result | ||||
|  | ||||
| @@ -545,9 +545,9 @@ class AbstractQuerySet(object): | ||||
|  | ||||
|         clone = copy.deepcopy(self) | ||||
|         for operator in args: | ||||
|             if not isinstance(operator, TransactionClause): | ||||
|             if not isinstance(operator, ConditionalClause): | ||||
|                 raise QueryException('{0} is not a valid query operator'.format(operator)) | ||||
|             clone._transaction.append(operator) | ||||
|             clone._conditional.append(operator) | ||||
|  | ||||
|         for col_name, val in kwargs.items(): | ||||
|             exists = False | ||||
| @@ -576,7 +576,7 @@ class AbstractQuerySet(object): | ||||
|             else: | ||||
|                 query_val = column.to_database(val) | ||||
|  | ||||
|             clone._transaction.append(TransactionClause(col_name, query_val)) | ||||
|             clone._conditional.append(ConditionalClause(col_name, query_val)) | ||||
|  | ||||
|         return clone | ||||
|  | ||||
| @@ -898,6 +898,7 @@ class AbstractQuerySet(object): | ||||
|             self.column_family_name, | ||||
|             where=self._where, | ||||
|             timestamp=self._timestamp, | ||||
|             conditionals=self._conditional, | ||||
|             if_exists=self._if_exists | ||||
|         ) | ||||
|         self._execute(dq) | ||||
| @@ -1155,7 +1156,7 @@ class ModelQuerySet(AbstractQuerySet): | ||||
|  | ||||
|         nulled_columns = set() | ||||
|         us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, | ||||
|                              timestamp=self._timestamp, transactions=self._transaction, if_exists=self._if_exists) | ||||
|                              timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) | ||||
|         for name, val in values.items(): | ||||
|             col_name, col_op = self._parse_filter_arg(name) | ||||
|             col = self.model._columns.get(col_name) | ||||
| @@ -1196,7 +1197,7 @@ class ModelQuerySet(AbstractQuerySet): | ||||
|  | ||||
|         if nulled_columns: | ||||
|             ds = DeleteStatement(self.column_family_name, fields=nulled_columns, | ||||
|                                  where=self._where, if_exists=self._if_exists) | ||||
|                                  where=self._where, conditionals=self._conditional, if_exists=self._if_exists) | ||||
|             self._execute(ds) | ||||
|  | ||||
|  | ||||
| @@ -1215,7 +1216,7 @@ class DMLQuery(object): | ||||
|     _if_exists = False | ||||
|  | ||||
|     def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, | ||||
|                  if_not_exists=False, transaction=None, timeout=connection.NOT_SET, if_exists=False): | ||||
|                  if_not_exists=False, conditional=None, timeout=connection.NOT_SET, if_exists=False): | ||||
|         self.model = model | ||||
|         self.column_family_name = self.model.column_family_name() | ||||
|         self.instance = instance | ||||
| @@ -1225,7 +1226,7 @@ class DMLQuery(object): | ||||
|         self._timestamp = timestamp | ||||
|         self._if_not_exists = if_not_exists | ||||
|         self._if_exists = if_exists | ||||
|         self._transaction = transaction | ||||
|         self._conditional = conditional | ||||
|         self._timeout = timeout | ||||
|  | ||||
|     def _execute(self, q): | ||||
| @@ -1233,7 +1234,7 @@ class DMLQuery(object): | ||||
|             return self._batch.add_query(q) | ||||
|         else: | ||||
|             tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) | ||||
|             if self._if_not_exists or self._if_exists or self._transaction: | ||||
|             if self._if_not_exists or self._if_exists or self._conditional: | ||||
|                 check_applied(tmp) | ||||
|             return tmp | ||||
|  | ||||
| @@ -1247,7 +1248,7 @@ class DMLQuery(object): | ||||
|         """ | ||||
|         executes a delete query to remove columns that have changed to null | ||||
|         """ | ||||
|         ds = DeleteStatement(self.column_family_name, if_exists=self._if_exists) | ||||
|         ds = DeleteStatement(self.column_family_name, conditionals=self._conditional, if_exists=self._if_exists) | ||||
|         deleted_fields = False | ||||
|         for _, v in self.instance._values.items(): | ||||
|             col = v.column | ||||
| @@ -1282,7 +1283,7 @@ class DMLQuery(object): | ||||
|         null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True | ||||
|         static_changed_only = True | ||||
|         statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, | ||||
|                                     transactions=self._transaction, if_exists=self._if_exists) | ||||
|                                     conditionals=self._conditional, if_exists=self._if_exists) | ||||
|         for name, col in self.instance._clustering_keys.items(): | ||||
|             null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) | ||||
|         # get defined fields and their column names | ||||
| @@ -1324,7 +1325,7 @@ class DMLQuery(object): | ||||
|                         col.to_database(val) | ||||
|                     )) | ||||
|  | ||||
|         if statement.get_context_size() > 0 or self.instance._has_counter: | ||||
|         if statement.assignments: | ||||
|             for name, col in self.model._primary_keys.items(): | ||||
|                 # only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error | ||||
|                 if (null_clustering_key or static_changed_only) and (not col.partition_key): | ||||
| @@ -1386,7 +1387,7 @@ class DMLQuery(object): | ||||
|         if self.instance is None: | ||||
|             raise CQLEngineException("DML Query instance attribute is None") | ||||
|  | ||||
|         ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, if_exists=self._if_exists) | ||||
|         ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) | ||||
|         for name, col in self.model._primary_keys.items(): | ||||
|             if (not col.partition_key) and (getattr(self.instance, name) is None): | ||||
|                 continue | ||||
|   | ||||
| @@ -148,7 +148,7 @@ class AssignmentClause(BaseClause): | ||||
|         return self.field, self.context_id | ||||
|  | ||||
|  | ||||
| class TransactionClause(BaseClause): | ||||
| class ConditionalClause(BaseClause): | ||||
|     """ A single variable iff statement """ | ||||
|  | ||||
|     def __unicode__(self): | ||||
| @@ -471,7 +471,7 @@ class MapDeleteClause(BaseDeleteClause): | ||||
| class BaseCQLStatement(UnicodeMixin): | ||||
|     """ The base cql statement class """ | ||||
|  | ||||
|     def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None): | ||||
|     def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None): | ||||
|         super(BaseCQLStatement, self).__init__() | ||||
|         self.table = table | ||||
|         self.consistency = consistency | ||||
| @@ -484,6 +484,10 @@ class BaseCQLStatement(UnicodeMixin): | ||||
|         for clause in where or []: | ||||
|             self.add_where_clause(clause) | ||||
|  | ||||
|         self.conditionals = [] | ||||
|         for conditional in conditionals or []: | ||||
|             self.add_conditional_clause(conditional) | ||||
|  | ||||
|     def add_where_clause(self, clause): | ||||
|         """ | ||||
|         adds a where clause to this statement | ||||
| @@ -506,6 +510,22 @@ class BaseCQLStatement(UnicodeMixin): | ||||
|             clause.update_context(ctx) | ||||
|         return ctx | ||||
|  | ||||
|     def add_conditional_clause(self, clause): | ||||
|         """ | ||||
|         Adds a iff clause to this statement | ||||
|  | ||||
|         :param clause: The clause that will be added to the iff statement | ||||
|         :type clause: ConditionalClause | ||||
|         """ | ||||
|         if not isinstance(clause, ConditionalClause): | ||||
|             raise StatementException('only instances of AssignmentClause can be added to statements') | ||||
|         clause.set_context_id(self.context_counter) | ||||
|         self.context_counter += clause.get_context_size() | ||||
|         self.conditionals.append(clause) | ||||
|  | ||||
|     def _get_conditionals(self): | ||||
|         return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals])) | ||||
|  | ||||
|     def get_context_size(self): | ||||
|         return len(self.get_context()) | ||||
|  | ||||
| @@ -616,11 +636,13 @@ class AssignmentStatement(BaseCQLStatement): | ||||
|                  consistency=None, | ||||
|                  where=None, | ||||
|                  ttl=None, | ||||
|                  timestamp=None): | ||||
|                  timestamp=None, | ||||
|                  conditionals=None): | ||||
|         super(AssignmentStatement, self).__init__( | ||||
|             table, | ||||
|             consistency=consistency, | ||||
|             where=where, | ||||
|             conditionals=conditionals | ||||
|         ) | ||||
|         self.ttl = ttl | ||||
|         self.timestamp = timestamp | ||||
| @@ -715,19 +737,15 @@ class UpdateStatement(AssignmentStatement): | ||||
|                  where=None, | ||||
|                  ttl=None, | ||||
|                  timestamp=None, | ||||
|                  transactions=None, | ||||
|                  conditionals=None, | ||||
|                  if_exists=False): | ||||
|         super(UpdateStatement, self). __init__(table, | ||||
|                                                assignments=assignments, | ||||
|                                                consistency=consistency, | ||||
|                                                where=where, | ||||
|                                                ttl=ttl, | ||||
|                                                timestamp=timestamp) | ||||
|  | ||||
|         # Add iff statements | ||||
|         self.transactions = [] | ||||
|         for transaction in transactions or []: | ||||
|             self.add_transaction_clause(transaction) | ||||
|                                                timestamp=timestamp, | ||||
|                                                conditionals=conditionals) | ||||
|  | ||||
|         self.if_exists = if_exists | ||||
|  | ||||
| @@ -751,58 +769,44 @@ class UpdateStatement(AssignmentStatement): | ||||
|         if self.where_clauses: | ||||
|             qs += [self._where] | ||||
|  | ||||
|         if len(self.transactions) > 0: | ||||
|             qs += [self._get_transactions()] | ||||
|         if len(self.conditionals) > 0: | ||||
|             qs += [self._get_conditionals()] | ||||
|  | ||||
|         if self.if_exists: | ||||
|             qs += ["IF EXISTS"] | ||||
|  | ||||
|         return ' '.join(qs) | ||||
|  | ||||
|     def add_transaction_clause(self, clause): | ||||
|         """ | ||||
|         Adds a iff clause to this statement | ||||
|  | ||||
|         :param clause: The clause that will be added to the iff statement | ||||
|         :type clause: TransactionClause | ||||
|         """ | ||||
|         if not isinstance(clause, TransactionClause): | ||||
|             raise StatementException('only instances of AssignmentClause can be added to statements') | ||||
|         clause.set_context_id(self.context_counter) | ||||
|         self.context_counter += clause.get_context_size() | ||||
|         self.transactions.append(clause) | ||||
|  | ||||
|     def get_context(self): | ||||
|         ctx = super(UpdateStatement, self).get_context() | ||||
|         for clause in self.transactions or []: | ||||
|         for clause in self.conditionals: | ||||
|             clause.update_context(ctx) | ||||
|         return ctx | ||||
|  | ||||
|     def _get_transactions(self): | ||||
|         return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) | ||||
|  | ||||
|     def update_context_id(self, i): | ||||
|         super(UpdateStatement, self).update_context_id(i) | ||||
|         for transaction in self.transactions: | ||||
|             transaction.set_context_id(self.context_counter) | ||||
|             self.context_counter += transaction.get_context_size() | ||||
|         for conditional in self.conditionals: | ||||
|             conditional.set_context_id(self.context_counter) | ||||
|             self.context_counter += conditional.get_context_size() | ||||
|  | ||||
|  | ||||
| class DeleteStatement(BaseCQLStatement): | ||||
|     """ a cql delete statement """ | ||||
|  | ||||
|     def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, if_exists=False): | ||||
|     def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, conditionals=None, if_exists=False): | ||||
|         super(DeleteStatement, self).__init__( | ||||
|             table, | ||||
|             consistency=consistency, | ||||
|             where=where, | ||||
|             timestamp=timestamp | ||||
|             timestamp=timestamp, | ||||
|             conditionals=conditionals | ||||
|         ) | ||||
|         self.fields = [] | ||||
|         if isinstance(fields, six.string_types): | ||||
|             fields = [fields] | ||||
|         for field in fields or []: | ||||
|             self.add_field(field) | ||||
|  | ||||
|         self.if_exists = if_exists | ||||
|  | ||||
|     def update_context_id(self, i): | ||||
| @@ -810,11 +814,16 @@ class DeleteStatement(BaseCQLStatement): | ||||
|         for field in self.fields: | ||||
|             field.set_context_id(self.context_counter) | ||||
|             self.context_counter += field.get_context_size() | ||||
|         for t in self.conditionals: | ||||
|             t.set_context_id(self.context_counter) | ||||
|             self.context_counter += t.get_context_size() | ||||
|  | ||||
|     def get_context(self): | ||||
|         ctx = super(DeleteStatement, self).get_context() | ||||
|         for field in self.fields: | ||||
|             field.update_context(ctx) | ||||
|         for clause in self.conditionals: | ||||
|             clause.update_context(ctx) | ||||
|         return ctx | ||||
|  | ||||
|     def add_field(self, field): | ||||
| @@ -843,6 +852,9 @@ class DeleteStatement(BaseCQLStatement): | ||||
|         if self.where_clauses: | ||||
|             qs += [self._where] | ||||
|  | ||||
|         if self.conditionals: | ||||
|             qs += [self._get_conditionals()] | ||||
|  | ||||
|         if self.if_exists: | ||||
|             qs += ["IF EXISTS"] | ||||
|  | ||||
|   | ||||
| @@ -13,10 +13,11 @@ | ||||
| # limitations under the License. | ||||
|  | ||||
| from unittest import TestCase | ||||
| from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause | ||||
| from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause | ||||
| from cassandra.cqlengine.operators import * | ||||
| import six | ||||
|  | ||||
|  | ||||
| class DeleteStatementTests(TestCase): | ||||
|  | ||||
|     def test_single_field_is_listified(self): | ||||
| @@ -49,7 +50,7 @@ class DeleteStatementTests(TestCase): | ||||
|  | ||||
|     def test_context_update(self): | ||||
|         ds = DeleteStatement('table', None) | ||||
|         ds.add_field(MapDeleteClause('d', {1: 2}, {1:2, 3: 4})) | ||||
|         ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4})) | ||||
|         ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) | ||||
|  | ||||
|         ds.update_context_id(7) | ||||
| @@ -72,3 +73,13 @@ class DeleteStatementTests(TestCase): | ||||
|         ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) | ||||
|         ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20'])) | ||||
|         self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds)) | ||||
|  | ||||
|     def test_delete_conditional(self): | ||||
|         where = [WhereClause('id', EqualsOperator(), 1)] | ||||
|         conditionals = [ConditionalClause('f0', 'value0'), ConditionalClause('f1', 'value1')] | ||||
|         ds = DeleteStatement('table', where=where, conditionals=conditionals) | ||||
|         self.assertEqual(len(ds.conditionals), len(conditionals)) | ||||
|         self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) | ||||
|         fields = ['one', 'two'] | ||||
|         ds = DeleteStatement('table', fields=fields, where=where, conditionals=conditionals) | ||||
|         self.assertEqual(six.text_type(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) | ||||
|   | ||||
							
								
								
									
										178
									
								
								tests/integration/cqlengine/test_lwt_conditional.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								tests/integration/cqlengine/test_lwt_conditional.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,178 @@ | ||||
| # Copyright 2013-2016 DataStax, Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| try: | ||||
|     import unittest2 as unittest | ||||
| except ImportError: | ||||
|     import unittest  # noqa | ||||
|  | ||||
| import mock | ||||
| import six | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from cassandra.cqlengine import columns | ||||
| from cassandra.cqlengine.management import sync_table, drop_table | ||||
| from cassandra.cqlengine.models import Model | ||||
| from cassandra.cqlengine.query import BatchQuery, LWTException | ||||
| from cassandra.cqlengine.statements import ConditionalClause | ||||
|  | ||||
| from tests.integration.cqlengine.base import BaseCassEngTestCase | ||||
| from tests.integration import CASSANDRA_VERSION | ||||
|  | ||||
|  | ||||
| class TestConditionalModel(Model): | ||||
|     id = columns.UUID(primary_key=True, default=uuid4) | ||||
|     count = columns.Integer() | ||||
|     text = columns.Text(required=False) | ||||
|  | ||||
|  | ||||
| @unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "conditionals only supported on cassandra 2.0 or higher") | ||||
| class TestConditional(BaseCassEngTestCase): | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         super(TestConditional, cls).setUpClass() | ||||
|         sync_table(TestConditionalModel) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         super(TestConditional, cls).tearDownClass() | ||||
|         drop_table(TestConditionalModel) | ||||
|  | ||||
|     def test_update_using_conditional(self): | ||||
|         t = TestConditionalModel.create(text='blah blah') | ||||
|         t.text = 'new blah' | ||||
|         with mock.patch.object(self.session, 'execute') as m: | ||||
|             t.iff(text='blah blah').save() | ||||
|  | ||||
|         args = m.call_args | ||||
|         self.assertIn('IF "text" = %(0)s', args[0][0].query_string) | ||||
|  | ||||
|     def test_update_conditional_success(self): | ||||
|         t = TestConditionalModel.create(text='blah blah', count=5) | ||||
|         id = t.id | ||||
|         t.text = 'new blah' | ||||
|         t.iff(text='blah blah').save() | ||||
|  | ||||
|         updated = TestConditionalModel.objects(id=id).first() | ||||
|         self.assertEqual(updated.count, 5) | ||||
|         self.assertEqual(updated.text, 'new blah') | ||||
|  | ||||
|     def test_update_failure(self): | ||||
|         t = TestConditionalModel.create(text='blah blah') | ||||
|         t.text = 'new blah' | ||||
|         t = t.iff(text='something wrong') | ||||
|  | ||||
|         with self.assertRaises(LWTException) as assertion: | ||||
|             t.save() | ||||
|  | ||||
|         self.assertEqual(assertion.exception.existing, { | ||||
|             'text': 'blah blah', | ||||
|             '[applied]': False, | ||||
|         }) | ||||
|  | ||||
|     def test_blind_update(self): | ||||
|         t = TestConditionalModel.create(text='blah blah') | ||||
|         t.text = 'something else' | ||||
|         uid = t.id | ||||
|  | ||||
|         with mock.patch.object(self.session, 'execute') as m: | ||||
|             TestConditionalModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der') | ||||
|  | ||||
|         args = m.call_args | ||||
|         self.assertIn('IF "text" = %(1)s', args[0][0].query_string) | ||||
|  | ||||
|     def test_blind_update_fail(self): | ||||
|         t = TestConditionalModel.create(text='blah blah') | ||||
|         t.text = 'something else' | ||||
|         uid = t.id | ||||
|         qs = TestConditionalModel.objects(id=uid).iff(text='Not dis!') | ||||
|         with self.assertRaises(LWTException) as assertion: | ||||
|             qs.update(text='this will never work') | ||||
|  | ||||
|         self.assertEqual(assertion.exception.existing, { | ||||
|             'text': 'blah blah', | ||||
|             '[applied]': False, | ||||
|         }) | ||||
|  | ||||
|     def test_conditional_clause(self): | ||||
|         tc = ConditionalClause('some_value', 23) | ||||
|         tc.set_context_id(3) | ||||
|  | ||||
|         self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) | ||||
|         self.assertEqual('"some_value" = %(3)s', str(tc)) | ||||
|  | ||||
|     def test_batch_update_conditional(self): | ||||
|         t = TestConditionalModel.create(text='something', count=5) | ||||
|         id = t.id | ||||
|         with BatchQuery() as b: | ||||
|             t.batch(b).iff(count=5).update(text='something else') | ||||
|  | ||||
|         updated = TestConditionalModel.objects(id=id).first() | ||||
|         self.assertEqual(updated.text, 'something else') | ||||
|  | ||||
|         b = BatchQuery() | ||||
|         updated.batch(b).iff(count=6).update(text='and another thing') | ||||
|         with self.assertRaises(LWTException) as assertion: | ||||
|             b.execute() | ||||
|  | ||||
|         self.assertEqual(assertion.exception.existing, { | ||||
|             'id': id, | ||||
|             'count': 5, | ||||
|             '[applied]': False, | ||||
|         }) | ||||
|  | ||||
|         updated = TestConditionalModel.objects(id=id).first() | ||||
|         self.assertEqual(updated.text, 'something else') | ||||
|  | ||||
|     def test_delete_conditional(self): | ||||
|         # DML path | ||||
|         t = TestConditionalModel.create(text='something', count=5) | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) | ||||
|         with self.assertRaises(LWTException): | ||||
|             t.iff(count=9999).delete() | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) | ||||
|         t.iff(count=5).delete() | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) | ||||
|  | ||||
|         # QuerySet path | ||||
|         t = TestConditionalModel.create(text='something', count=5) | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) | ||||
|         with self.assertRaises(LWTException): | ||||
|             TestConditionalModel.objects(id=t.id).iff(count=9999).delete() | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) | ||||
|         TestConditionalModel.objects(id=t.id).iff(count=5).delete() | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) | ||||
|  | ||||
|     def test_update_to_none(self): | ||||
|         # This test is done because updates to none are split into deletes | ||||
|         # for old versions of cassandra. Can be removed when we drop that code | ||||
|         # https://github.com/datastax/python-driver/blob/3.1.1/cassandra/cqlengine/query.py#L1197-L1200 | ||||
|  | ||||
|         # DML path | ||||
|         t = TestConditionalModel.create(text='something', count=5) | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) | ||||
|         with self.assertRaises(LWTException): | ||||
|             t.iff(count=9999).update(text=None) | ||||
|         self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) | ||||
|         t.iff(count=5).update(text=None) | ||||
|         self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) | ||||
|  | ||||
|         # QuerySet path | ||||
|         t = TestConditionalModel.create(text='something', count=5) | ||||
|         self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) | ||||
|         with self.assertRaises(LWTException): | ||||
|             TestConditionalModel.objects(id=t.id).iff(count=9999).update(text=None) | ||||
|         self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) | ||||
|         TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None) | ||||
|         self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) | ||||
| @@ -1,136 +0,0 @@ | ||||
| # Copyright 2013-2016 DataStax, Inc. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| try: | ||||
|     import unittest2 as unittest | ||||
| except ImportError: | ||||
|     import unittest  # noqa | ||||
|  | ||||
| import mock | ||||
| import six | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from cassandra.cqlengine import columns | ||||
| from cassandra.cqlengine.management import sync_table, drop_table | ||||
| from cassandra.cqlengine.models import Model | ||||
| from cassandra.cqlengine.query import BatchQuery, LWTException | ||||
| from cassandra.cqlengine.statements import TransactionClause | ||||
|  | ||||
| from tests.integration.cqlengine.base import BaseCassEngTestCase | ||||
| from tests.integration import CASSANDRA_VERSION | ||||
|  | ||||
|  | ||||
| class TestTransactionModel(Model): | ||||
|     id = columns.UUID(primary_key=True, default=uuid4) | ||||
|     count = columns.Integer() | ||||
|     text = columns.Text(required=False) | ||||
|  | ||||
|  | ||||
| @unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "transactions only supported on cassandra 2.0 or higher") | ||||
| class TestTransaction(BaseCassEngTestCase): | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         super(TestTransaction, cls).setUpClass() | ||||
|         sync_table(TestTransactionModel) | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         super(TestTransaction, cls).tearDownClass() | ||||
|         drop_table(TestTransactionModel) | ||||
|  | ||||
|     def test_update_using_transaction(self): | ||||
|         t = TestTransactionModel.create(text='blah blah') | ||||
|         t.text = 'new blah' | ||||
|         with mock.patch.object(self.session, 'execute') as m: | ||||
|             t.iff(text='blah blah').save() | ||||
|  | ||||
|         args = m.call_args | ||||
|         self.assertIn('IF "text" = %(0)s', args[0][0].query_string) | ||||
|  | ||||
|     def test_update_transaction_success(self): | ||||
|         t = TestTransactionModel.create(text='blah blah', count=5) | ||||
|         id = t.id | ||||
|         t.text = 'new blah' | ||||
|         t.iff(text='blah blah').save() | ||||
|  | ||||
|         updated = TestTransactionModel.objects(id=id).first() | ||||
|         self.assertEqual(updated.count, 5) | ||||
|         self.assertEqual(updated.text, 'new blah') | ||||
|  | ||||
|     def test_update_failure(self): | ||||
|         t = TestTransactionModel.create(text='blah blah') | ||||
|         t.text = 'new blah' | ||||
|         t = t.iff(text='something wrong') | ||||
|  | ||||
|         with self.assertRaises(LWTException) as assertion: | ||||
|             t.save() | ||||
|  | ||||
|         self.assertEqual(assertion.exception.existing, { | ||||
|             'text': 'blah blah', | ||||
|             '[applied]': False, | ||||
|         }) | ||||
|  | ||||
|     def test_blind_update(self): | ||||
|         t = TestTransactionModel.create(text='blah blah') | ||||
|         t.text = 'something else' | ||||
|         uid = t.id | ||||
|  | ||||
|         with mock.patch.object(self.session, 'execute') as m: | ||||
|             TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der') | ||||
|  | ||||
|         args = m.call_args | ||||
|         self.assertIn('IF "text" = %(1)s', args[0][0].query_string) | ||||
|  | ||||
|     def test_blind_update_fail(self): | ||||
|         t = TestTransactionModel.create(text='blah blah') | ||||
|         t.text = 'something else' | ||||
|         uid = t.id | ||||
|         qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!') | ||||
|         with self.assertRaises(LWTException) as assertion: | ||||
|             qs.update(text='this will never work') | ||||
|  | ||||
|         self.assertEqual(assertion.exception.existing, { | ||||
|             'text': 'blah blah', | ||||
|             '[applied]': False, | ||||
|         }) | ||||
|  | ||||
|     def test_transaction_clause(self): | ||||
|         tc = TransactionClause('some_value', 23) | ||||
|         tc.set_context_id(3) | ||||
|  | ||||
|         self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) | ||||
|         self.assertEqual('"some_value" = %(3)s', str(tc)) | ||||
|  | ||||
|     def test_batch_update_transaction(self): | ||||
|         t = TestTransactionModel.create(text='something', count=5) | ||||
|         id = t.id | ||||
|         with BatchQuery() as b: | ||||
|             t.batch(b).iff(count=5).update(text='something else') | ||||
|  | ||||
|         updated = TestTransactionModel.objects(id=id).first() | ||||
|         self.assertEqual(updated.text, 'something else') | ||||
|  | ||||
|         b = BatchQuery() | ||||
|         updated.batch(b).iff(count=6).update(text='and another thing') | ||||
|         with self.assertRaises(LWTException) as assertion: | ||||
|             b.execute() | ||||
|  | ||||
|         self.assertEqual(assertion.exception.existing, { | ||||
|             'id': id, | ||||
|             'count': 5, | ||||
|             '[applied]': False, | ||||
|         }) | ||||
|  | ||||
|         updated = TestTransactionModel.objects(id=id).first() | ||||
|         self.assertEqual(updated.text, 'something else') | ||||
		Reference in New Issue
	
	Block a user
	 Adam Holmberg
					Adam Holmberg