Merge pull request #520 from datastax/249
PYTHON-249 - Fixing conditional deletes in cqlengine
This commit is contained in:
@@ -101,28 +101,28 @@ class QuerySetDescriptor(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransactionDescriptor(object):
|
||||
class ConditionalDescriptor(object):
|
||||
"""
|
||||
returns a query set descriptor
|
||||
"""
|
||||
def __get__(self, instance, model):
|
||||
if instance:
|
||||
def transaction_setter(*prepared_transaction, **unprepared_transactions):
|
||||
if len(prepared_transaction) > 0:
|
||||
transactions = prepared_transaction[0]
|
||||
def conditional_setter(*prepared_conditional, **unprepared_conditionals):
|
||||
if len(prepared_conditional) > 0:
|
||||
conditionals = prepared_conditional[0]
|
||||
else:
|
||||
transactions = instance.objects.iff(**unprepared_transactions)._transaction
|
||||
instance._transaction = transactions
|
||||
conditionals = instance.objects.iff(**unprepared_conditionals)._conditional
|
||||
instance._conditional = conditionals
|
||||
return instance
|
||||
|
||||
return transaction_setter
|
||||
return conditional_setter
|
||||
qs = model.__queryset__(model)
|
||||
|
||||
def transaction_setter(**unprepared_transactions):
|
||||
transactions = model.objects.iff(**unprepared_transactions)._transaction
|
||||
qs._transaction = transactions
|
||||
def conditional_setter(**unprepared_conditionals):
|
||||
conditionals = model.objects.iff(**unprepared_conditionals)._conditional
|
||||
qs._conditional = conditionals
|
||||
return qs
|
||||
return transaction_setter
|
||||
return conditional_setter
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
@@ -314,7 +314,7 @@ class BaseModel(object):
|
||||
objects = QuerySetDescriptor()
|
||||
ttl = TTLDescriptor()
|
||||
consistency = ConsistencyDescriptor()
|
||||
iff = TransactionDescriptor()
|
||||
iff = ConditionalDescriptor()
|
||||
|
||||
# custom timestamps, see USING TIMESTAMP X
|
||||
timestamp = TimestampDescriptor()
|
||||
@@ -352,7 +352,7 @@ class BaseModel(object):
|
||||
def __init__(self, **values):
|
||||
self._ttl = self.__default_ttl__
|
||||
self._timestamp = None
|
||||
self._transaction = None
|
||||
self._conditional = None
|
||||
self._batch = None
|
||||
self._timeout = connection.NOT_SET
|
||||
self._is_persisted = False
|
||||
@@ -684,7 +684,7 @@ class BaseModel(object):
|
||||
timestamp=self._timestamp,
|
||||
consistency=self.__consistency__,
|
||||
if_not_exists=self._if_not_exists,
|
||||
transaction=self._transaction,
|
||||
conditional=self._conditional,
|
||||
timeout=self._timeout,
|
||||
if_exists=self._if_exists).save()
|
||||
|
||||
@@ -731,7 +731,7 @@ class BaseModel(object):
|
||||
ttl=self._ttl,
|
||||
timestamp=self._timestamp,
|
||||
consistency=self.__consistency__,
|
||||
transaction=self._transaction,
|
||||
conditional=self._conditional,
|
||||
timeout=self._timeout,
|
||||
if_exists=self._if_exists).update()
|
||||
|
||||
@@ -751,6 +751,7 @@ class BaseModel(object):
|
||||
timestamp=self._timestamp,
|
||||
consistency=self.__consistency__,
|
||||
timeout=self._timeout,
|
||||
conditional=self._conditional,
|
||||
if_exists=self._if_exists).delete()
|
||||
|
||||
def get_changed_columns(self):
|
||||
|
||||
@@ -28,7 +28,7 @@ from cassandra.cqlengine.statements import (WhereClause, SelectStatement, Delete
|
||||
UpdateStatement, AssignmentClause, InsertStatement,
|
||||
BaseCQLStatement, MapUpdateClause, MapDeleteClause,
|
||||
ListUpdateClause, SetUpdateClause, CounterUpdateClause,
|
||||
TransactionClause)
|
||||
ConditionalClause)
|
||||
|
||||
|
||||
class QueryException(CQLEngineException):
|
||||
@@ -43,7 +43,7 @@ class IfExistsWithCounterColumn(CQLEngineException):
|
||||
|
||||
|
||||
class LWTException(CQLEngineException):
|
||||
"""Lightweight transaction exception.
|
||||
"""Lightweight conditional exception.
|
||||
|
||||
This exception will be raised when a write using an `IF` clause could not be
|
||||
applied due to existing data violating the condition. The existing data is
|
||||
@@ -146,7 +146,7 @@ class BatchQuery(object):
|
||||
:param batch_type: (optional) One of batch type values available through BatchType enum
|
||||
:type batch_type: str or None
|
||||
:param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied
|
||||
to the batch transaction.
|
||||
to the batch conditional.
|
||||
:type timestamp: datetime or timedelta or None
|
||||
:param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc)
|
||||
:type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None.
|
||||
@@ -267,8 +267,8 @@ class AbstractQuerySet(object):
|
||||
# Where clause filters
|
||||
self._where = []
|
||||
|
||||
# Transaction clause filters
|
||||
self._transaction = []
|
||||
# Conditional clause filters
|
||||
self._conditional = []
|
||||
|
||||
# ordering arguments
|
||||
self._order = []
|
||||
@@ -314,7 +314,7 @@ class AbstractQuerySet(object):
|
||||
return self._batch.add_query(q)
|
||||
else:
|
||||
result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
|
||||
if self._if_not_exists or self._if_exists or self._transaction:
|
||||
if self._if_not_exists or self._if_exists or self._conditional:
|
||||
check_applied(result)
|
||||
return result
|
||||
|
||||
@@ -545,9 +545,9 @@ class AbstractQuerySet(object):
|
||||
|
||||
clone = copy.deepcopy(self)
|
||||
for operator in args:
|
||||
if not isinstance(operator, TransactionClause):
|
||||
if not isinstance(operator, ConditionalClause):
|
||||
raise QueryException('{0} is not a valid query operator'.format(operator))
|
||||
clone._transaction.append(operator)
|
||||
clone._conditional.append(operator)
|
||||
|
||||
for col_name, val in kwargs.items():
|
||||
exists = False
|
||||
@@ -576,7 +576,7 @@ class AbstractQuerySet(object):
|
||||
else:
|
||||
query_val = column.to_database(val)
|
||||
|
||||
clone._transaction.append(TransactionClause(col_name, query_val))
|
||||
clone._conditional.append(ConditionalClause(col_name, query_val))
|
||||
|
||||
return clone
|
||||
|
||||
@@ -898,6 +898,7 @@ class AbstractQuerySet(object):
|
||||
self.column_family_name,
|
||||
where=self._where,
|
||||
timestamp=self._timestamp,
|
||||
conditionals=self._conditional,
|
||||
if_exists=self._if_exists
|
||||
)
|
||||
self._execute(dq)
|
||||
@@ -1155,7 +1156,7 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
|
||||
nulled_columns = set()
|
||||
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
|
||||
timestamp=self._timestamp, transactions=self._transaction, if_exists=self._if_exists)
|
||||
timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
|
||||
for name, val in values.items():
|
||||
col_name, col_op = self._parse_filter_arg(name)
|
||||
col = self.model._columns.get(col_name)
|
||||
@@ -1196,7 +1197,7 @@ class ModelQuerySet(AbstractQuerySet):
|
||||
|
||||
if nulled_columns:
|
||||
ds = DeleteStatement(self.column_family_name, fields=nulled_columns,
|
||||
where=self._where, if_exists=self._if_exists)
|
||||
where=self._where, conditionals=self._conditional, if_exists=self._if_exists)
|
||||
self._execute(ds)
|
||||
|
||||
|
||||
@@ -1215,7 +1216,7 @@ class DMLQuery(object):
|
||||
_if_exists = False
|
||||
|
||||
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None,
|
||||
if_not_exists=False, transaction=None, timeout=connection.NOT_SET, if_exists=False):
|
||||
if_not_exists=False, conditional=None, timeout=connection.NOT_SET, if_exists=False):
|
||||
self.model = model
|
||||
self.column_family_name = self.model.column_family_name()
|
||||
self.instance = instance
|
||||
@@ -1225,7 +1226,7 @@ class DMLQuery(object):
|
||||
self._timestamp = timestamp
|
||||
self._if_not_exists = if_not_exists
|
||||
self._if_exists = if_exists
|
||||
self._transaction = transaction
|
||||
self._conditional = conditional
|
||||
self._timeout = timeout
|
||||
|
||||
def _execute(self, q):
|
||||
@@ -1233,7 +1234,7 @@ class DMLQuery(object):
|
||||
return self._batch.add_query(q)
|
||||
else:
|
||||
tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
|
||||
if self._if_not_exists or self._if_exists or self._transaction:
|
||||
if self._if_not_exists or self._if_exists or self._conditional:
|
||||
check_applied(tmp)
|
||||
return tmp
|
||||
|
||||
@@ -1247,7 +1248,7 @@ class DMLQuery(object):
|
||||
"""
|
||||
executes a delete query to remove columns that have changed to null
|
||||
"""
|
||||
ds = DeleteStatement(self.column_family_name, if_exists=self._if_exists)
|
||||
ds = DeleteStatement(self.column_family_name, conditionals=self._conditional, if_exists=self._if_exists)
|
||||
deleted_fields = False
|
||||
for _, v in self.instance._values.items():
|
||||
col = v.column
|
||||
@@ -1282,7 +1283,7 @@ class DMLQuery(object):
|
||||
null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True
|
||||
static_changed_only = True
|
||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp,
|
||||
transactions=self._transaction, if_exists=self._if_exists)
|
||||
conditionals=self._conditional, if_exists=self._if_exists)
|
||||
for name, col in self.instance._clustering_keys.items():
|
||||
null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))
|
||||
# get defined fields and their column names
|
||||
@@ -1324,7 +1325,7 @@ class DMLQuery(object):
|
||||
col.to_database(val)
|
||||
))
|
||||
|
||||
if statement.get_context_size() > 0 or self.instance._has_counter:
|
||||
if statement.assignments:
|
||||
for name, col in self.model._primary_keys.items():
|
||||
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
|
||||
if (null_clustering_key or static_changed_only) and (not col.partition_key):
|
||||
@@ -1386,7 +1387,7 @@ class DMLQuery(object):
|
||||
if self.instance is None:
|
||||
raise CQLEngineException("DML Query instance attribute is None")
|
||||
|
||||
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, if_exists=self._if_exists)
|
||||
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
|
||||
for name, col in self.model._primary_keys.items():
|
||||
if (not col.partition_key) and (getattr(self.instance, name) is None):
|
||||
continue
|
||||
|
||||
@@ -148,7 +148,7 @@ class AssignmentClause(BaseClause):
|
||||
return self.field, self.context_id
|
||||
|
||||
|
||||
class TransactionClause(BaseClause):
|
||||
class ConditionalClause(BaseClause):
|
||||
""" A single variable iff statement """
|
||||
|
||||
def __unicode__(self):
|
||||
@@ -471,7 +471,7 @@ class MapDeleteClause(BaseDeleteClause):
|
||||
class BaseCQLStatement(UnicodeMixin):
|
||||
""" The base cql statement class """
|
||||
|
||||
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None):
|
||||
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
|
||||
super(BaseCQLStatement, self).__init__()
|
||||
self.table = table
|
||||
self.consistency = consistency
|
||||
@@ -484,6 +484,10 @@ class BaseCQLStatement(UnicodeMixin):
|
||||
for clause in where or []:
|
||||
self.add_where_clause(clause)
|
||||
|
||||
self.conditionals = []
|
||||
for conditional in conditionals or []:
|
||||
self.add_conditional_clause(conditional)
|
||||
|
||||
def add_where_clause(self, clause):
|
||||
"""
|
||||
adds a where clause to this statement
|
||||
@@ -506,6 +510,22 @@ class BaseCQLStatement(UnicodeMixin):
|
||||
clause.update_context(ctx)
|
||||
return ctx
|
||||
|
||||
def add_conditional_clause(self, clause):
|
||||
"""
|
||||
Adds a iff clause to this statement
|
||||
|
||||
:param clause: The clause that will be added to the iff statement
|
||||
:type clause: ConditionalClause
|
||||
"""
|
||||
if not isinstance(clause, ConditionalClause):
|
||||
raise StatementException('only instances of AssignmentClause can be added to statements')
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.conditionals.append(clause)
|
||||
|
||||
def _get_conditionals(self):
|
||||
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals]))
|
||||
|
||||
def get_context_size(self):
|
||||
return len(self.get_context())
|
||||
|
||||
@@ -616,11 +636,13 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
consistency=None,
|
||||
where=None,
|
||||
ttl=None,
|
||||
timestamp=None):
|
||||
timestamp=None,
|
||||
conditionals=None):
|
||||
super(AssignmentStatement, self).__init__(
|
||||
table,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
conditionals=conditionals
|
||||
)
|
||||
self.ttl = ttl
|
||||
self.timestamp = timestamp
|
||||
@@ -715,19 +737,15 @@ class UpdateStatement(AssignmentStatement):
|
||||
where=None,
|
||||
ttl=None,
|
||||
timestamp=None,
|
||||
transactions=None,
|
||||
conditionals=None,
|
||||
if_exists=False):
|
||||
super(UpdateStatement, self). __init__(table,
|
||||
assignments=assignments,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
ttl=ttl,
|
||||
timestamp=timestamp)
|
||||
|
||||
# Add iff statements
|
||||
self.transactions = []
|
||||
for transaction in transactions or []:
|
||||
self.add_transaction_clause(transaction)
|
||||
timestamp=timestamp,
|
||||
conditionals=conditionals)
|
||||
|
||||
self.if_exists = if_exists
|
||||
|
||||
@@ -751,58 +769,44 @@ class UpdateStatement(AssignmentStatement):
|
||||
if self.where_clauses:
|
||||
qs += [self._where]
|
||||
|
||||
if len(self.transactions) > 0:
|
||||
qs += [self._get_transactions()]
|
||||
if len(self.conditionals) > 0:
|
||||
qs += [self._get_conditionals()]
|
||||
|
||||
if self.if_exists:
|
||||
qs += ["IF EXISTS"]
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
def add_transaction_clause(self, clause):
|
||||
"""
|
||||
Adds a iff clause to this statement
|
||||
|
||||
:param clause: The clause that will be added to the iff statement
|
||||
:type clause: TransactionClause
|
||||
"""
|
||||
if not isinstance(clause, TransactionClause):
|
||||
raise StatementException('only instances of AssignmentClause can be added to statements')
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.transactions.append(clause)
|
||||
|
||||
def get_context(self):
|
||||
ctx = super(UpdateStatement, self).get_context()
|
||||
for clause in self.transactions or []:
|
||||
for clause in self.conditionals:
|
||||
clause.update_context(ctx)
|
||||
return ctx
|
||||
|
||||
def _get_transactions(self):
|
||||
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
|
||||
|
||||
def update_context_id(self, i):
|
||||
super(UpdateStatement, self).update_context_id(i)
|
||||
for transaction in self.transactions:
|
||||
transaction.set_context_id(self.context_counter)
|
||||
self.context_counter += transaction.get_context_size()
|
||||
for conditional in self.conditionals:
|
||||
conditional.set_context_id(self.context_counter)
|
||||
self.context_counter += conditional.get_context_size()
|
||||
|
||||
|
||||
class DeleteStatement(BaseCQLStatement):
|
||||
""" a cql delete statement """
|
||||
|
||||
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, if_exists=False):
|
||||
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, conditionals=None, if_exists=False):
|
||||
super(DeleteStatement, self).__init__(
|
||||
table,
|
||||
consistency=consistency,
|
||||
where=where,
|
||||
timestamp=timestamp
|
||||
timestamp=timestamp,
|
||||
conditionals=conditionals
|
||||
)
|
||||
self.fields = []
|
||||
if isinstance(fields, six.string_types):
|
||||
fields = [fields]
|
||||
for field in fields or []:
|
||||
self.add_field(field)
|
||||
|
||||
self.if_exists = if_exists
|
||||
|
||||
def update_context_id(self, i):
|
||||
@@ -810,11 +814,16 @@ class DeleteStatement(BaseCQLStatement):
|
||||
for field in self.fields:
|
||||
field.set_context_id(self.context_counter)
|
||||
self.context_counter += field.get_context_size()
|
||||
for t in self.conditionals:
|
||||
t.set_context_id(self.context_counter)
|
||||
self.context_counter += t.get_context_size()
|
||||
|
||||
def get_context(self):
|
||||
ctx = super(DeleteStatement, self).get_context()
|
||||
for field in self.fields:
|
||||
field.update_context(ctx)
|
||||
for clause in self.conditionals:
|
||||
clause.update_context(ctx)
|
||||
return ctx
|
||||
|
||||
def add_field(self, field):
|
||||
@@ -843,6 +852,9 @@ class DeleteStatement(BaseCQLStatement):
|
||||
if self.where_clauses:
|
||||
qs += [self._where]
|
||||
|
||||
if self.conditionals:
|
||||
qs += [self._get_conditionals()]
|
||||
|
||||
if self.if_exists:
|
||||
qs += ["IF EXISTS"]
|
||||
|
||||
|
||||
@@ -13,10 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import TestCase
|
||||
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause
|
||||
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
|
||||
from cassandra.cqlengine.operators import *
|
||||
import six
|
||||
|
||||
|
||||
class DeleteStatementTests(TestCase):
|
||||
|
||||
def test_single_field_is_listified(self):
|
||||
@@ -49,7 +50,7 @@ class DeleteStatementTests(TestCase):
|
||||
|
||||
def test_context_update(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_field(MapDeleteClause('d', {1: 2}, {1:2, 3: 4}))
|
||||
ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4}))
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
|
||||
ds.update_context_id(7)
|
||||
@@ -72,3 +73,13 @@ class DeleteStatementTests(TestCase):
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20']))
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds))
|
||||
|
||||
def test_delete_conditional(self):
|
||||
where = [WhereClause('id', EqualsOperator(), 1)]
|
||||
conditionals = [ConditionalClause('f0', 'value0'), ConditionalClause('f1', 'value1')]
|
||||
ds = DeleteStatement('table', where=where, conditionals=conditionals)
|
||||
self.assertEqual(len(ds.conditionals), len(conditionals))
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds))
|
||||
fields = ['one', 'two']
|
||||
ds = DeleteStatement('table', fields=fields, where=where, conditionals=conditionals)
|
||||
self.assertEqual(six.text_type(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds))
|
||||
|
||||
178
tests/integration/cqlengine/test_lwt_conditional.py
Normal file
178
tests/integration/cqlengine/test_lwt_conditional.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Copyright 2013-2016 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
import mock
|
||||
import six
|
||||
from uuid import uuid4
|
||||
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine.management import sync_table, drop_table
|
||||
from cassandra.cqlengine.models import Model
|
||||
from cassandra.cqlengine.query import BatchQuery, LWTException
|
||||
from cassandra.cqlengine.statements import ConditionalClause
|
||||
|
||||
from tests.integration.cqlengine.base import BaseCassEngTestCase
|
||||
from tests.integration import CASSANDRA_VERSION
|
||||
|
||||
|
||||
class TestConditionalModel(Model):
|
||||
id = columns.UUID(primary_key=True, default=uuid4)
|
||||
count = columns.Integer()
|
||||
text = columns.Text(required=False)
|
||||
|
||||
|
||||
@unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "conditionals only supported on cassandra 2.0 or higher")
|
||||
class TestConditional(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestConditional, cls).setUpClass()
|
||||
sync_table(TestConditionalModel)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(TestConditional, cls).tearDownClass()
|
||||
drop_table(TestConditionalModel)
|
||||
|
||||
def test_update_using_conditional(self):
|
||||
t = TestConditionalModel.create(text='blah blah')
|
||||
t.text = 'new blah'
|
||||
with mock.patch.object(self.session, 'execute') as m:
|
||||
t.iff(text='blah blah').save()
|
||||
|
||||
args = m.call_args
|
||||
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
|
||||
|
||||
def test_update_conditional_success(self):
|
||||
t = TestConditionalModel.create(text='blah blah', count=5)
|
||||
id = t.id
|
||||
t.text = 'new blah'
|
||||
t.iff(text='blah blah').save()
|
||||
|
||||
updated = TestConditionalModel.objects(id=id).first()
|
||||
self.assertEqual(updated.count, 5)
|
||||
self.assertEqual(updated.text, 'new blah')
|
||||
|
||||
def test_update_failure(self):
|
||||
t = TestConditionalModel.create(text='blah blah')
|
||||
t.text = 'new blah'
|
||||
t = t.iff(text='something wrong')
|
||||
|
||||
with self.assertRaises(LWTException) as assertion:
|
||||
t.save()
|
||||
|
||||
self.assertEqual(assertion.exception.existing, {
|
||||
'text': 'blah blah',
|
||||
'[applied]': False,
|
||||
})
|
||||
|
||||
def test_blind_update(self):
|
||||
t = TestConditionalModel.create(text='blah blah')
|
||||
t.text = 'something else'
|
||||
uid = t.id
|
||||
|
||||
with mock.patch.object(self.session, 'execute') as m:
|
||||
TestConditionalModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
|
||||
|
||||
args = m.call_args
|
||||
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
|
||||
|
||||
def test_blind_update_fail(self):
|
||||
t = TestConditionalModel.create(text='blah blah')
|
||||
t.text = 'something else'
|
||||
uid = t.id
|
||||
qs = TestConditionalModel.objects(id=uid).iff(text='Not dis!')
|
||||
with self.assertRaises(LWTException) as assertion:
|
||||
qs.update(text='this will never work')
|
||||
|
||||
self.assertEqual(assertion.exception.existing, {
|
||||
'text': 'blah blah',
|
||||
'[applied]': False,
|
||||
})
|
||||
|
||||
def test_conditional_clause(self):
|
||||
tc = ConditionalClause('some_value', 23)
|
||||
tc.set_context_id(3)
|
||||
|
||||
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
|
||||
self.assertEqual('"some_value" = %(3)s', str(tc))
|
||||
|
||||
def test_batch_update_conditional(self):
|
||||
t = TestConditionalModel.create(text='something', count=5)
|
||||
id = t.id
|
||||
with BatchQuery() as b:
|
||||
t.batch(b).iff(count=5).update(text='something else')
|
||||
|
||||
updated = TestConditionalModel.objects(id=id).first()
|
||||
self.assertEqual(updated.text, 'something else')
|
||||
|
||||
b = BatchQuery()
|
||||
updated.batch(b).iff(count=6).update(text='and another thing')
|
||||
with self.assertRaises(LWTException) as assertion:
|
||||
b.execute()
|
||||
|
||||
self.assertEqual(assertion.exception.existing, {
|
||||
'id': id,
|
||||
'count': 5,
|
||||
'[applied]': False,
|
||||
})
|
||||
|
||||
updated = TestConditionalModel.objects(id=id).first()
|
||||
self.assertEqual(updated.text, 'something else')
|
||||
|
||||
def test_delete_conditional(self):
|
||||
# DML path
|
||||
t = TestConditionalModel.create(text='something', count=5)
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
|
||||
with self.assertRaises(LWTException):
|
||||
t.iff(count=9999).delete()
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
|
||||
t.iff(count=5).delete()
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0)
|
||||
|
||||
# QuerySet path
|
||||
t = TestConditionalModel.create(text='something', count=5)
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
|
||||
with self.assertRaises(LWTException):
|
||||
TestConditionalModel.objects(id=t.id).iff(count=9999).delete()
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
|
||||
TestConditionalModel.objects(id=t.id).iff(count=5).delete()
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0)
|
||||
|
||||
def test_update_to_none(self):
|
||||
# This test is done because updates to none are split into deletes
|
||||
# for old versions of cassandra. Can be removed when we drop that code
|
||||
# https://github.com/datastax/python-driver/blob/3.1.1/cassandra/cqlengine/query.py#L1197-L1200
|
||||
|
||||
# DML path
|
||||
t = TestConditionalModel.create(text='something', count=5)
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
|
||||
with self.assertRaises(LWTException):
|
||||
t.iff(count=9999).update(text=None)
|
||||
self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text)
|
||||
t.iff(count=5).update(text=None)
|
||||
self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text)
|
||||
|
||||
# QuerySet path
|
||||
t = TestConditionalModel.create(text='something', count=5)
|
||||
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
|
||||
with self.assertRaises(LWTException):
|
||||
TestConditionalModel.objects(id=t.id).iff(count=9999).update(text=None)
|
||||
self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text)
|
||||
TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None)
|
||||
self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text)
|
||||
@@ -1,136 +0,0 @@
|
||||
# Copyright 2013-2016 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
import mock
|
||||
import six
|
||||
from uuid import uuid4
|
||||
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine.management import sync_table, drop_table
|
||||
from cassandra.cqlengine.models import Model
|
||||
from cassandra.cqlengine.query import BatchQuery, LWTException
|
||||
from cassandra.cqlengine.statements import TransactionClause
|
||||
|
||||
from tests.integration.cqlengine.base import BaseCassEngTestCase
|
||||
from tests.integration import CASSANDRA_VERSION
|
||||
|
||||
|
||||
class TestTransactionModel(Model):
|
||||
id = columns.UUID(primary_key=True, default=uuid4)
|
||||
count = columns.Integer()
|
||||
text = columns.Text(required=False)
|
||||
|
||||
|
||||
@unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "transactions only supported on cassandra 2.0 or higher")
|
||||
class TestTransaction(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(TestTransaction, cls).setUpClass()
|
||||
sync_table(TestTransactionModel)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(TestTransaction, cls).tearDownClass()
|
||||
drop_table(TestTransactionModel)
|
||||
|
||||
def test_update_using_transaction(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'new blah'
|
||||
with mock.patch.object(self.session, 'execute') as m:
|
||||
t.iff(text='blah blah').save()
|
||||
|
||||
args = m.call_args
|
||||
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
|
||||
|
||||
def test_update_transaction_success(self):
|
||||
t = TestTransactionModel.create(text='blah blah', count=5)
|
||||
id = t.id
|
||||
t.text = 'new blah'
|
||||
t.iff(text='blah blah').save()
|
||||
|
||||
updated = TestTransactionModel.objects(id=id).first()
|
||||
self.assertEqual(updated.count, 5)
|
||||
self.assertEqual(updated.text, 'new blah')
|
||||
|
||||
def test_update_failure(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'new blah'
|
||||
t = t.iff(text='something wrong')
|
||||
|
||||
with self.assertRaises(LWTException) as assertion:
|
||||
t.save()
|
||||
|
||||
self.assertEqual(assertion.exception.existing, {
|
||||
'text': 'blah blah',
|
||||
'[applied]': False,
|
||||
})
|
||||
|
||||
def test_blind_update(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'something else'
|
||||
uid = t.id
|
||||
|
||||
with mock.patch.object(self.session, 'execute') as m:
|
||||
TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
|
||||
|
||||
args = m.call_args
|
||||
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
|
||||
|
||||
def test_blind_update_fail(self):
|
||||
t = TestTransactionModel.create(text='blah blah')
|
||||
t.text = 'something else'
|
||||
uid = t.id
|
||||
qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!')
|
||||
with self.assertRaises(LWTException) as assertion:
|
||||
qs.update(text='this will never work')
|
||||
|
||||
self.assertEqual(assertion.exception.existing, {
|
||||
'text': 'blah blah',
|
||||
'[applied]': False,
|
||||
})
|
||||
|
||||
def test_transaction_clause(self):
|
||||
tc = TransactionClause('some_value', 23)
|
||||
tc.set_context_id(3)
|
||||
|
||||
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
|
||||
self.assertEqual('"some_value" = %(3)s', str(tc))
|
||||
|
||||
def test_batch_update_transaction(self):
|
||||
t = TestTransactionModel.create(text='something', count=5)
|
||||
id = t.id
|
||||
with BatchQuery() as b:
|
||||
t.batch(b).iff(count=5).update(text='something else')
|
||||
|
||||
updated = TestTransactionModel.objects(id=id).first()
|
||||
self.assertEqual(updated.text, 'something else')
|
||||
|
||||
b = BatchQuery()
|
||||
updated.batch(b).iff(count=6).update(text='and another thing')
|
||||
with self.assertRaises(LWTException) as assertion:
|
||||
b.execute()
|
||||
|
||||
self.assertEqual(assertion.exception.existing, {
|
||||
'id': id,
|
||||
'count': 5,
|
||||
'[applied]': False,
|
||||
})
|
||||
|
||||
updated = TestTransactionModel.objects(id=id).first()
|
||||
self.assertEqual(updated.text, 'something else')
|
||||
Reference in New Issue
Block a user