Merge pull request #520 from datastax/249

PYTHON-249 - Fixing conditional deletes in cqlengine
This commit is contained in:
Adam Holmberg
2016-03-22 09:17:54 -05:00
6 changed files with 272 additions and 205 deletions

View File

@@ -101,28 +101,28 @@ class QuerySetDescriptor(object):
raise NotImplementedError
class TransactionDescriptor(object):
class ConditionalDescriptor(object):
"""
returns a query set descriptor
"""
def __get__(self, instance, model):
if instance:
def transaction_setter(*prepared_transaction, **unprepared_transactions):
if len(prepared_transaction) > 0:
transactions = prepared_transaction[0]
def conditional_setter(*prepared_conditional, **unprepared_conditionals):
if len(prepared_conditional) > 0:
conditionals = prepared_conditional[0]
else:
transactions = instance.objects.iff(**unprepared_transactions)._transaction
instance._transaction = transactions
conditionals = instance.objects.iff(**unprepared_conditionals)._conditional
instance._conditional = conditionals
return instance
return transaction_setter
return conditional_setter
qs = model.__queryset__(model)
def transaction_setter(**unprepared_transactions):
transactions = model.objects.iff(**unprepared_transactions)._transaction
qs._transaction = transactions
def conditional_setter(**unprepared_conditionals):
conditionals = model.objects.iff(**unprepared_conditionals)._conditional
qs._conditional = conditionals
return qs
return transaction_setter
return conditional_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
@@ -314,7 +314,7 @@ class BaseModel(object):
objects = QuerySetDescriptor()
ttl = TTLDescriptor()
consistency = ConsistencyDescriptor()
iff = TransactionDescriptor()
iff = ConditionalDescriptor()
# custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor()
@@ -352,7 +352,7 @@ class BaseModel(object):
def __init__(self, **values):
self._ttl = self.__default_ttl__
self._timestamp = None
self._transaction = None
self._conditional = None
self._batch = None
self._timeout = connection.NOT_SET
self._is_persisted = False
@@ -684,7 +684,7 @@ class BaseModel(object):
timestamp=self._timestamp,
consistency=self.__consistency__,
if_not_exists=self._if_not_exists,
transaction=self._transaction,
conditional=self._conditional,
timeout=self._timeout,
if_exists=self._if_exists).save()
@@ -731,7 +731,7 @@ class BaseModel(object):
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__,
transaction=self._transaction,
conditional=self._conditional,
timeout=self._timeout,
if_exists=self._if_exists).update()
@@ -751,6 +751,7 @@ class BaseModel(object):
timestamp=self._timestamp,
consistency=self.__consistency__,
timeout=self._timeout,
conditional=self._conditional,
if_exists=self._if_exists).delete()
def get_changed_columns(self):

View File

@@ -28,7 +28,7 @@ from cassandra.cqlengine.statements import (WhereClause, SelectStatement, Delete
UpdateStatement, AssignmentClause, InsertStatement,
BaseCQLStatement, MapUpdateClause, MapDeleteClause,
ListUpdateClause, SetUpdateClause, CounterUpdateClause,
TransactionClause)
ConditionalClause)
class QueryException(CQLEngineException):
@@ -43,7 +43,7 @@ class IfExistsWithCounterColumn(CQLEngineException):
class LWTException(CQLEngineException):
"""Lightweight transaction exception.
"""Lightweight conditional exception.
This exception will be raised when a write using an `IF` clause could not be
applied due to existing data violating the condition. The existing data is
@@ -146,7 +146,7 @@ class BatchQuery(object):
:param batch_type: (optional) One of batch type values available through BatchType enum
:type batch_type: str or None
:param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied
to the batch transaction.
to the batch conditional.
:type timestamp: datetime or timedelta or None
:param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc)
:type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None.
@@ -267,8 +267,8 @@ class AbstractQuerySet(object):
# Where clause filters
self._where = []
# Transaction clause filters
self._transaction = []
# Conditional clause filters
self._conditional = []
# ordering arguments
self._order = []
@@ -314,7 +314,7 @@ class AbstractQuerySet(object):
return self._batch.add_query(q)
else:
result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
if self._if_not_exists or self._if_exists or self._transaction:
if self._if_not_exists or self._if_exists or self._conditional:
check_applied(result)
return result
@@ -545,9 +545,9 @@ class AbstractQuerySet(object):
clone = copy.deepcopy(self)
for operator in args:
if not isinstance(operator, TransactionClause):
if not isinstance(operator, ConditionalClause):
raise QueryException('{0} is not a valid query operator'.format(operator))
clone._transaction.append(operator)
clone._conditional.append(operator)
for col_name, val in kwargs.items():
exists = False
@@ -576,7 +576,7 @@ class AbstractQuerySet(object):
else:
query_val = column.to_database(val)
clone._transaction.append(TransactionClause(col_name, query_val))
clone._conditional.append(ConditionalClause(col_name, query_val))
return clone
@@ -898,6 +898,7 @@ class AbstractQuerySet(object):
self.column_family_name,
where=self._where,
timestamp=self._timestamp,
conditionals=self._conditional,
if_exists=self._if_exists
)
self._execute(dq)
@@ -1155,7 +1156,7 @@ class ModelQuerySet(AbstractQuerySet):
nulled_columns = set()
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction, if_exists=self._if_exists)
timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, val in values.items():
col_name, col_op = self._parse_filter_arg(name)
col = self.model._columns.get(col_name)
@@ -1196,7 +1197,7 @@ class ModelQuerySet(AbstractQuerySet):
if nulled_columns:
ds = DeleteStatement(self.column_family_name, fields=nulled_columns,
where=self._where, if_exists=self._if_exists)
where=self._where, conditionals=self._conditional, if_exists=self._if_exists)
self._execute(ds)
@@ -1215,7 +1216,7 @@ class DMLQuery(object):
_if_exists = False
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None,
if_not_exists=False, transaction=None, timeout=connection.NOT_SET, if_exists=False):
if_not_exists=False, conditional=None, timeout=connection.NOT_SET, if_exists=False):
self.model = model
self.column_family_name = self.model.column_family_name()
self.instance = instance
@@ -1225,7 +1226,7 @@ class DMLQuery(object):
self._timestamp = timestamp
self._if_not_exists = if_not_exists
self._if_exists = if_exists
self._transaction = transaction
self._conditional = conditional
self._timeout = timeout
def _execute(self, q):
@@ -1233,7 +1234,7 @@ class DMLQuery(object):
return self._batch.add_query(q)
else:
tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
if self._if_not_exists or self._if_exists or self._transaction:
if self._if_not_exists or self._if_exists or self._conditional:
check_applied(tmp)
return tmp
@@ -1247,7 +1248,7 @@ class DMLQuery(object):
"""
executes a delete query to remove columns that have changed to null
"""
ds = DeleteStatement(self.column_family_name, if_exists=self._if_exists)
ds = DeleteStatement(self.column_family_name, conditionals=self._conditional, if_exists=self._if_exists)
deleted_fields = False
for _, v in self.instance._values.items():
col = v.column
@@ -1282,7 +1283,7 @@ class DMLQuery(object):
null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True
static_changed_only = True
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp,
transactions=self._transaction, if_exists=self._if_exists)
conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.instance._clustering_keys.items():
null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))
# get defined fields and their column names
@@ -1324,7 +1325,7 @@ class DMLQuery(object):
col.to_database(val)
))
if statement.get_context_size() > 0 or self.instance._has_counter:
if statement.assignments:
for name, col in self.model._primary_keys.items():
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
if (null_clustering_key or static_changed_only) and (not col.partition_key):
@@ -1386,7 +1387,7 @@ class DMLQuery(object):
if self.instance is None:
raise CQLEngineException("DML Query instance attribute is None")
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, if_exists=self._if_exists)
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.model._primary_keys.items():
if (not col.partition_key) and (getattr(self.instance, name) is None):
continue

View File

@@ -148,7 +148,7 @@ class AssignmentClause(BaseClause):
return self.field, self.context_id
class TransactionClause(BaseClause):
class ConditionalClause(BaseClause):
""" A single variable iff statement """
def __unicode__(self):
@@ -471,7 +471,7 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None):
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
super(BaseCQLStatement, self).__init__()
self.table = table
self.consistency = consistency
@@ -484,6 +484,10 @@ class BaseCQLStatement(UnicodeMixin):
for clause in where or []:
self.add_where_clause(clause)
self.conditionals = []
for conditional in conditionals or []:
self.add_conditional_clause(conditional)
def add_where_clause(self, clause):
"""
adds a where clause to this statement
@@ -506,6 +510,22 @@ class BaseCQLStatement(UnicodeMixin):
clause.update_context(ctx)
return ctx
def add_conditional_clause(self, clause):
"""
Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement
:type clause: ConditionalClause
"""
if not isinstance(clause, ConditionalClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.conditionals.append(clause)
def _get_conditionals(self):
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals]))
def get_context_size(self):
return len(self.get_context())
@@ -616,11 +636,13 @@ class AssignmentStatement(BaseCQLStatement):
consistency=None,
where=None,
ttl=None,
timestamp=None):
timestamp=None,
conditionals=None):
super(AssignmentStatement, self).__init__(
table,
consistency=consistency,
where=where,
conditionals=conditionals
)
self.ttl = ttl
self.timestamp = timestamp
@@ -715,19 +737,15 @@ class UpdateStatement(AssignmentStatement):
where=None,
ttl=None,
timestamp=None,
transactions=None,
conditionals=None,
if_exists=False):
super(UpdateStatement, self). __init__(table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
# Add iff statements
self.transactions = []
for transaction in transactions or []:
self.add_transaction_clause(transaction)
timestamp=timestamp,
conditionals=conditionals)
self.if_exists = if_exists
@@ -751,58 +769,44 @@ class UpdateStatement(AssignmentStatement):
if self.where_clauses:
qs += [self._where]
if len(self.transactions) > 0:
qs += [self._get_transactions()]
if len(self.conditionals) > 0:
qs += [self._get_conditionals()]
if self.if_exists:
qs += ["IF EXISTS"]
return ' '.join(qs)
def add_transaction_clause(self, clause):
"""
Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement
:type clause: TransactionClause
"""
if not isinstance(clause, TransactionClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.transactions.append(clause)
def get_context(self):
ctx = super(UpdateStatement, self).get_context()
for clause in self.transactions or []:
for clause in self.conditionals:
clause.update_context(ctx)
return ctx
def _get_transactions(self):
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
def update_context_id(self, i):
super(UpdateStatement, self).update_context_id(i)
for transaction in self.transactions:
transaction.set_context_id(self.context_counter)
self.context_counter += transaction.get_context_size()
for conditional in self.conditionals:
conditional.set_context_id(self.context_counter)
self.context_counter += conditional.get_context_size()
class DeleteStatement(BaseCQLStatement):
""" a cql delete statement """
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, if_exists=False):
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, conditionals=None, if_exists=False):
super(DeleteStatement, self).__init__(
table,
consistency=consistency,
where=where,
timestamp=timestamp
timestamp=timestamp,
conditionals=conditionals
)
self.fields = []
if isinstance(fields, six.string_types):
fields = [fields]
for field in fields or []:
self.add_field(field)
self.if_exists = if_exists
def update_context_id(self, i):
@@ -810,11 +814,16 @@ class DeleteStatement(BaseCQLStatement):
for field in self.fields:
field.set_context_id(self.context_counter)
self.context_counter += field.get_context_size()
for t in self.conditionals:
t.set_context_id(self.context_counter)
self.context_counter += t.get_context_size()
def get_context(self):
ctx = super(DeleteStatement, self).get_context()
for field in self.fields:
field.update_context(ctx)
for clause in self.conditionals:
clause.update_context(ctx)
return ctx
def add_field(self, field):
@@ -843,6 +852,9 @@ class DeleteStatement(BaseCQLStatement):
if self.where_clauses:
qs += [self._where]
if self.conditionals:
qs += [self._get_conditionals()]
if self.if_exists:
qs += ["IF EXISTS"]

View File

@@ -13,10 +13,11 @@
# limitations under the License.
from unittest import TestCase
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
from cassandra.cqlengine.operators import *
import six
class DeleteStatementTests(TestCase):
def test_single_field_is_listified(self):
@@ -49,7 +50,7 @@ class DeleteStatementTests(TestCase):
def test_context_update(self):
ds = DeleteStatement('table', None)
ds.add_field(MapDeleteClause('d', {1: 2}, {1:2, 3: 4}))
ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4}))
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.update_context_id(7)
@@ -72,3 +73,13 @@ class DeleteStatementTests(TestCase):
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20']))
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds))
def test_delete_conditional(self):
where = [WhereClause('id', EqualsOperator(), 1)]
conditionals = [ConditionalClause('f0', 'value0'), ConditionalClause('f1', 'value1')]
ds = DeleteStatement('table', where=where, conditionals=conditionals)
self.assertEqual(len(ds.conditionals), len(conditionals))
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds))
fields = ['one', 'two']
ds = DeleteStatement('table', fields=fields, where=where, conditionals=conditionals)
self.assertEqual(six.text_type(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds))

View File

@@ -0,0 +1,178 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import mock
import six
from uuid import uuid4
from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table, drop_table
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import BatchQuery, LWTException
from cassandra.cqlengine.statements import ConditionalClause
from tests.integration.cqlengine.base import BaseCassEngTestCase
from tests.integration import CASSANDRA_VERSION
class TestConditionalModel(Model):
id = columns.UUID(primary_key=True, default=uuid4)
count = columns.Integer()
text = columns.Text(required=False)
@unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "conditionals only supported on cassandra 2.0 or higher")
class TestConditional(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestConditional, cls).setUpClass()
sync_table(TestConditionalModel)
@classmethod
def tearDownClass(cls):
super(TestConditional, cls).tearDownClass()
drop_table(TestConditionalModel)
def test_update_using_conditional(self):
t = TestConditionalModel.create(text='blah blah')
t.text = 'new blah'
with mock.patch.object(self.session, 'execute') as m:
t.iff(text='blah blah').save()
args = m.call_args
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
def test_update_conditional_success(self):
t = TestConditionalModel.create(text='blah blah', count=5)
id = t.id
t.text = 'new blah'
t.iff(text='blah blah').save()
updated = TestConditionalModel.objects(id=id).first()
self.assertEqual(updated.count, 5)
self.assertEqual(updated.text, 'new blah')
def test_update_failure(self):
t = TestConditionalModel.create(text='blah blah')
t.text = 'new blah'
t = t.iff(text='something wrong')
with self.assertRaises(LWTException) as assertion:
t.save()
self.assertEqual(assertion.exception.existing, {
'text': 'blah blah',
'[applied]': False,
})
def test_blind_update(self):
t = TestConditionalModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
with mock.patch.object(self.session, 'execute') as m:
TestConditionalModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
args = m.call_args
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
def test_blind_update_fail(self):
t = TestConditionalModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
qs = TestConditionalModel.objects(id=uid).iff(text='Not dis!')
with self.assertRaises(LWTException) as assertion:
qs.update(text='this will never work')
self.assertEqual(assertion.exception.existing, {
'text': 'blah blah',
'[applied]': False,
})
def test_conditional_clause(self):
tc = ConditionalClause('some_value', 23)
tc.set_context_id(3)
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
self.assertEqual('"some_value" = %(3)s', str(tc))
def test_batch_update_conditional(self):
t = TestConditionalModel.create(text='something', count=5)
id = t.id
with BatchQuery() as b:
t.batch(b).iff(count=5).update(text='something else')
updated = TestConditionalModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')
b = BatchQuery()
updated.batch(b).iff(count=6).update(text='and another thing')
with self.assertRaises(LWTException) as assertion:
b.execute()
self.assertEqual(assertion.exception.existing, {
'id': id,
'count': 5,
'[applied]': False,
})
updated = TestConditionalModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')
def test_delete_conditional(self):
# DML path
t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException):
t.iff(count=9999).delete()
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
t.iff(count=5).delete()
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0)
# QuerySet path
t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException):
TestConditionalModel.objects(id=t.id).iff(count=9999).delete()
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
TestConditionalModel.objects(id=t.id).iff(count=5).delete()
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0)
def test_update_to_none(self):
# This test is done because updates to none are split into deletes
# for old versions of cassandra. Can be removed when we drop that code
# https://github.com/datastax/python-driver/blob/3.1.1/cassandra/cqlengine/query.py#L1197-L1200
# DML path
t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException):
t.iff(count=9999).update(text=None)
self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text)
t.iff(count=5).update(text=None)
self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text)
# QuerySet path
t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException):
TestConditionalModel.objects(id=t.id).iff(count=9999).update(text=None)
self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text)
TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None)
self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text)

View File

@@ -1,136 +0,0 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import mock
import six
from uuid import uuid4
from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table, drop_table
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import BatchQuery, LWTException
from cassandra.cqlengine.statements import TransactionClause
from tests.integration.cqlengine.base import BaseCassEngTestCase
from tests.integration import CASSANDRA_VERSION
class TestTransactionModel(Model):
id = columns.UUID(primary_key=True, default=uuid4)
count = columns.Integer()
text = columns.Text(required=False)
@unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "transactions only supported on cassandra 2.0 or higher")
class TestTransaction(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestTransaction, cls).setUpClass()
sync_table(TestTransactionModel)
@classmethod
def tearDownClass(cls):
super(TestTransaction, cls).tearDownClass()
drop_table(TestTransactionModel)
def test_update_using_transaction(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'new blah'
with mock.patch.object(self.session, 'execute') as m:
t.iff(text='blah blah').save()
args = m.call_args
self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
def test_update_transaction_success(self):
t = TestTransactionModel.create(text='blah blah', count=5)
id = t.id
t.text = 'new blah'
t.iff(text='blah blah').save()
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.count, 5)
self.assertEqual(updated.text, 'new blah')
def test_update_failure(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'new blah'
t = t.iff(text='something wrong')
with self.assertRaises(LWTException) as assertion:
t.save()
self.assertEqual(assertion.exception.existing, {
'text': 'blah blah',
'[applied]': False,
})
def test_blind_update(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
with mock.patch.object(self.session, 'execute') as m:
TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
args = m.call_args
self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
def test_blind_update_fail(self):
t = TestTransactionModel.create(text='blah blah')
t.text = 'something else'
uid = t.id
qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!')
with self.assertRaises(LWTException) as assertion:
qs.update(text='this will never work')
self.assertEqual(assertion.exception.existing, {
'text': 'blah blah',
'[applied]': False,
})
def test_transaction_clause(self):
tc = TransactionClause('some_value', 23)
tc.set_context_id(3)
self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
self.assertEqual('"some_value" = %(3)s', str(tc))
def test_batch_update_transaction(self):
t = TestTransactionModel.create(text='something', count=5)
id = t.id
with BatchQuery() as b:
t.batch(b).iff(count=5).update(text='something else')
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')
b = BatchQuery()
updated.batch(b).iff(count=6).update(text='and another thing')
with self.assertRaises(LWTException) as assertion:
b.execute()
self.assertEqual(assertion.exception.existing, {
'id': id,
'count': 5,
'[applied]': False,
})
updated = TestTransactionModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else')