cqle: transaction --> conditional

To remove overload with other transaction types (ife, ine).
This commit is contained in:
Adam Holmberg
2016-03-18 09:20:16 -05:00
parent c4db355311
commit 67e51b4766
5 changed files with 113 additions and 110 deletions

View File

@@ -101,28 +101,28 @@ class QuerySetDescriptor(object):
raise NotImplementedError raise NotImplementedError
class TransactionDescriptor(object): class ConditionalDescriptor(object):
""" """
returns a query set descriptor returns a query set descriptor
""" """
def __get__(self, instance, model): def __get__(self, instance, model):
if instance: if instance:
def transaction_setter(*prepared_transaction, **unprepared_transactions): def conditional_setter(*prepared_conditional, **unprepared_conditionals):
if len(prepared_transaction) > 0: if len(prepared_conditional) > 0:
transactions = prepared_transaction[0] conditionals = prepared_conditional[0]
else: else:
transactions = instance.objects.iff(**unprepared_transactions)._transaction conditionals = instance.objects.iff(**unprepared_conditionals)._conditional
instance._transaction = transactions instance._conditional = conditionals
return instance return instance
return transaction_setter return conditional_setter
qs = model.__queryset__(model) qs = model.__queryset__(model)
def transaction_setter(**unprepared_transactions): def conditional_setter(**unprepared_conditionals):
transactions = model.objects.iff(**unprepared_transactions)._transaction conditionals = model.objects.iff(**unprepared_conditionals)._conditional
qs._transaction = transactions qs._conditional = conditionals
return qs return qs
return transaction_setter return conditional_setter
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
@@ -314,7 +314,7 @@ class BaseModel(object):
objects = QuerySetDescriptor() objects = QuerySetDescriptor()
ttl = TTLDescriptor() ttl = TTLDescriptor()
consistency = ConsistencyDescriptor() consistency = ConsistencyDescriptor()
iff = TransactionDescriptor() iff = ConditionalDescriptor()
# custom timestamps, see USING TIMESTAMP X # custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor() timestamp = TimestampDescriptor()
@@ -352,7 +352,7 @@ class BaseModel(object):
def __init__(self, **values): def __init__(self, **values):
self._ttl = self.__default_ttl__ self._ttl = self.__default_ttl__
self._timestamp = None self._timestamp = None
self._transaction = None self._conditional = None
self._batch = None self._batch = None
self._timeout = connection.NOT_SET self._timeout = connection.NOT_SET
self._is_persisted = False self._is_persisted = False
@@ -684,7 +684,7 @@ class BaseModel(object):
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__, consistency=self.__consistency__,
if_not_exists=self._if_not_exists, if_not_exists=self._if_not_exists,
transaction=self._transaction, conditional=self._conditional,
timeout=self._timeout, timeout=self._timeout,
if_exists=self._if_exists).save() if_exists=self._if_exists).save()
@@ -731,7 +731,7 @@ class BaseModel(object):
ttl=self._ttl, ttl=self._ttl,
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__, consistency=self.__consistency__,
transaction=self._transaction, conditional=self._conditional,
timeout=self._timeout, timeout=self._timeout,
if_exists=self._if_exists).update() if_exists=self._if_exists).update()
@@ -751,7 +751,7 @@ class BaseModel(object):
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__, consistency=self.__consistency__,
timeout=self._timeout, timeout=self._timeout,
transaction=self._transaction, conditional=self._conditional,
if_exists=self._if_exists).delete() if_exists=self._if_exists).delete()
def get_changed_columns(self): def get_changed_columns(self):

View File

@@ -28,7 +28,7 @@ from cassandra.cqlengine.statements import (WhereClause, SelectStatement, Delete
UpdateStatement, AssignmentClause, InsertStatement, UpdateStatement, AssignmentClause, InsertStatement,
BaseCQLStatement, MapUpdateClause, MapDeleteClause, BaseCQLStatement, MapUpdateClause, MapDeleteClause,
ListUpdateClause, SetUpdateClause, CounterUpdateClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause,
TransactionClause) ConditionalClause)
class QueryException(CQLEngineException): class QueryException(CQLEngineException):
@@ -43,7 +43,7 @@ class IfExistsWithCounterColumn(CQLEngineException):
class LWTException(CQLEngineException): class LWTException(CQLEngineException):
"""Lightweight transaction exception. """Lightweight conditional exception.
This exception will be raised when a write using an `IF` clause could not be This exception will be raised when a write using an `IF` clause could not be
applied due to existing data violating the condition. The existing data is applied due to existing data violating the condition. The existing data is
@@ -146,7 +146,7 @@ class BatchQuery(object):
:param batch_type: (optional) One of batch type values available through BatchType enum :param batch_type: (optional) One of batch type values available through BatchType enum
:type batch_type: str or None :type batch_type: str or None
:param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied :param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied
to the batch transaction. to the batch conditional.
:type timestamp: datetime or timedelta or None :type timestamp: datetime or timedelta or None
:param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc) :param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc)
:type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None. :type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None.
@@ -267,8 +267,8 @@ class AbstractQuerySet(object):
# Where clause filters # Where clause filters
self._where = [] self._where = []
# Transaction clause filters # Conditional clause filters
self._transaction = [] self._conditional = []
# ordering arguments # ordering arguments
self._order = [] self._order = []
@@ -314,7 +314,7 @@ class AbstractQuerySet(object):
return self._batch.add_query(q) return self._batch.add_query(q)
else: else:
result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
if self._if_not_exists or self._if_exists or self._transaction: if self._if_not_exists or self._if_exists or self._conditional:
check_applied(result) check_applied(result)
return result return result
@@ -545,9 +545,9 @@ class AbstractQuerySet(object):
clone = copy.deepcopy(self) clone = copy.deepcopy(self)
for operator in args: for operator in args:
if not isinstance(operator, TransactionClause): if not isinstance(operator, ConditionalClause):
raise QueryException('{0} is not a valid query operator'.format(operator)) raise QueryException('{0} is not a valid query operator'.format(operator))
clone._transaction.append(operator) clone._conditional.append(operator)
for col_name, val in kwargs.items(): for col_name, val in kwargs.items():
exists = False exists = False
@@ -576,7 +576,7 @@ class AbstractQuerySet(object):
else: else:
query_val = column.to_database(val) query_val = column.to_database(val)
clone._transaction.append(TransactionClause(col_name, query_val)) clone._conditional.append(ConditionalClause(col_name, query_val))
return clone return clone
@@ -898,7 +898,7 @@ class AbstractQuerySet(object):
self.column_family_name, self.column_family_name,
where=self._where, where=self._where,
timestamp=self._timestamp, timestamp=self._timestamp,
transactions=self._transaction, conditionals=self._conditional,
if_exists=self._if_exists if_exists=self._if_exists
) )
self._execute(dq) self._execute(dq)
@@ -1156,7 +1156,7 @@ class ModelQuerySet(AbstractQuerySet):
nulled_columns = set() nulled_columns = set()
us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction, if_exists=self._if_exists) timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, val in values.items(): for name, val in values.items():
col_name, col_op = self._parse_filter_arg(name) col_name, col_op = self._parse_filter_arg(name)
col = self.model._columns.get(col_name) col = self.model._columns.get(col_name)
@@ -1197,7 +1197,7 @@ class ModelQuerySet(AbstractQuerySet):
if nulled_columns: if nulled_columns:
ds = DeleteStatement(self.column_family_name, fields=nulled_columns, ds = DeleteStatement(self.column_family_name, fields=nulled_columns,
where=self._where, transactions=self._transaction, if_exists=self._if_exists) where=self._where, conditionals=self._conditional, if_exists=self._if_exists)
self._execute(ds) self._execute(ds)
@@ -1216,7 +1216,7 @@ class DMLQuery(object):
_if_exists = False _if_exists = False
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None,
if_not_exists=False, transaction=None, timeout=connection.NOT_SET, if_exists=False): if_not_exists=False, conditional=None, timeout=connection.NOT_SET, if_exists=False):
self.model = model self.model = model
self.column_family_name = self.model.column_family_name() self.column_family_name = self.model.column_family_name()
self.instance = instance self.instance = instance
@@ -1226,7 +1226,7 @@ class DMLQuery(object):
self._timestamp = timestamp self._timestamp = timestamp
self._if_not_exists = if_not_exists self._if_not_exists = if_not_exists
self._if_exists = if_exists self._if_exists = if_exists
self._transaction = transaction self._conditional = conditional
self._timeout = timeout self._timeout = timeout
def _execute(self, q): def _execute(self, q):
@@ -1234,7 +1234,7 @@ class DMLQuery(object):
return self._batch.add_query(q) return self._batch.add_query(q)
else: else:
tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout)
if self._if_not_exists or self._if_exists or self._transaction: if self._if_not_exists or self._if_exists or self._conditional:
check_applied(tmp) check_applied(tmp)
return tmp return tmp
@@ -1248,7 +1248,7 @@ class DMLQuery(object):
""" """
executes a delete query to remove columns that have changed to null executes a delete query to remove columns that have changed to null
""" """
ds = DeleteStatement(self.column_family_name, transactions=self._transaction, if_exists=self._if_exists) ds = DeleteStatement(self.column_family_name, conditionals=self._conditional, if_exists=self._if_exists)
deleted_fields = False deleted_fields = False
for _, v in self.instance._values.items(): for _, v in self.instance._values.items():
col = v.column col = v.column
@@ -1283,7 +1283,7 @@ class DMLQuery(object):
null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True
static_changed_only = True static_changed_only = True
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp,
transactions=self._transaction, if_exists=self._if_exists) conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.instance._clustering_keys.items(): for name, col in self.instance._clustering_keys.items():
null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))
# get defined fields and their column names # get defined fields and their column names
@@ -1387,7 +1387,7 @@ class DMLQuery(object):
if self.instance is None: if self.instance is None:
raise CQLEngineException("DML Query instance attribute is None") raise CQLEngineException("DML Query instance attribute is None")
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, transactions=self._transaction, if_exists=self._if_exists) ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.model._primary_keys.items(): for name, col in self.model._primary_keys.items():
if (not col.partition_key) and (getattr(self.instance, name) is None): if (not col.partition_key) and (getattr(self.instance, name) is None):
continue continue

View File

@@ -148,7 +148,7 @@ class AssignmentClause(BaseClause):
return self.field, self.context_id return self.field, self.context_id
class TransactionClause(BaseClause): class ConditionalClause(BaseClause):
""" A single variable iff statement """ """ A single variable iff statement """
def __unicode__(self): def __unicode__(self):
@@ -471,7 +471,7 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin): class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """ """ The base cql statement class """
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, transactions=None): def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
super(BaseCQLStatement, self).__init__() super(BaseCQLStatement, self).__init__()
self.table = table self.table = table
self.consistency = consistency self.consistency = consistency
@@ -484,9 +484,9 @@ class BaseCQLStatement(UnicodeMixin):
for clause in where or []: for clause in where or []:
self.add_where_clause(clause) self.add_where_clause(clause)
self.transactions = [] self.conditionals = []
for transaction in transactions or []: for conditional in conditionals or []:
self.add_transaction_clause(transaction) self.add_conditional_clause(conditional)
def add_where_clause(self, clause): def add_where_clause(self, clause):
""" """
@@ -510,21 +510,21 @@ class BaseCQLStatement(UnicodeMixin):
clause.update_context(ctx) clause.update_context(ctx)
return ctx return ctx
def add_transaction_clause(self, clause): def add_conditional_clause(self, clause):
""" """
Adds a iff clause to this statement Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement :param clause: The clause that will be added to the iff statement
:type clause: TransactionClause :type clause: ConditionalClause
""" """
if not isinstance(clause, TransactionClause): if not isinstance(clause, ConditionalClause):
raise StatementException('only instances of AssignmentClause can be added to statements') raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter) clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size() self.context_counter += clause.get_context_size()
self.transactions.append(clause) self.conditionals.append(clause)
def _get_transactions(self): def _get_conditionals(self):
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals]))
def get_context_size(self): def get_context_size(self):
return len(self.get_context()) return len(self.get_context())
@@ -637,12 +637,12 @@ class AssignmentStatement(BaseCQLStatement):
where=None, where=None,
ttl=None, ttl=None,
timestamp=None, timestamp=None,
transactions=None): conditionals=None):
super(AssignmentStatement, self).__init__( super(AssignmentStatement, self).__init__(
table, table,
consistency=consistency, consistency=consistency,
where=where, where=where,
transactions=transactions conditionals=conditionals
) )
self.ttl = ttl self.ttl = ttl
self.timestamp = timestamp self.timestamp = timestamp
@@ -737,7 +737,7 @@ class UpdateStatement(AssignmentStatement):
where=None, where=None,
ttl=None, ttl=None,
timestamp=None, timestamp=None,
transactions=None, conditionals=None,
if_exists=False): if_exists=False):
super(UpdateStatement, self). __init__(table, super(UpdateStatement, self). __init__(table,
assignments=assignments, assignments=assignments,
@@ -745,7 +745,7 @@ class UpdateStatement(AssignmentStatement):
where=where, where=where,
ttl=ttl, ttl=ttl,
timestamp=timestamp, timestamp=timestamp,
transactions=transactions) conditionals=conditionals)
self.if_exists = if_exists self.if_exists = if_exists
@@ -769,8 +769,8 @@ class UpdateStatement(AssignmentStatement):
if self.where_clauses: if self.where_clauses:
qs += [self._where] qs += [self._where]
if len(self.transactions) > 0: if len(self.conditionals) > 0:
qs += [self._get_transactions()] qs += [self._get_conditionals()]
if self.if_exists: if self.if_exists:
qs += ["IF EXISTS"] qs += ["IF EXISTS"]
@@ -779,27 +779,27 @@ class UpdateStatement(AssignmentStatement):
def get_context(self): def get_context(self):
ctx = super(UpdateStatement, self).get_context() ctx = super(UpdateStatement, self).get_context()
for clause in self.transactions or []: for clause in self.conditionals:
clause.update_context(ctx) clause.update_context(ctx)
return ctx return ctx
def update_context_id(self, i): def update_context_id(self, i):
super(UpdateStatement, self).update_context_id(i) super(UpdateStatement, self).update_context_id(i)
for transaction in self.transactions: for conditional in self.conditionals:
transaction.set_context_id(self.context_counter) conditional.set_context_id(self.context_counter)
self.context_counter += transaction.get_context_size() self.context_counter += conditional.get_context_size()
class DeleteStatement(BaseCQLStatement): class DeleteStatement(BaseCQLStatement):
""" a cql delete statement """ """ a cql delete statement """
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, transactions=None, if_exists=False): def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, conditionals=None, if_exists=False):
super(DeleteStatement, self).__init__( super(DeleteStatement, self).__init__(
table, table,
consistency=consistency, consistency=consistency,
where=where, where=where,
timestamp=timestamp, timestamp=timestamp,
transactions=transactions conditionals=conditionals
) )
self.fields = [] self.fields = []
if isinstance(fields, six.string_types): if isinstance(fields, six.string_types):
@@ -814,12 +814,15 @@ class DeleteStatement(BaseCQLStatement):
for field in self.fields: for field in self.fields:
field.set_context_id(self.context_counter) field.set_context_id(self.context_counter)
self.context_counter += field.get_context_size() self.context_counter += field.get_context_size()
for t in self.conditionals:
t.set_context_id(self.context_counter)
self.context_counter += t.get_context_size()
def get_context(self): def get_context(self):
ctx = super(DeleteStatement, self).get_context() ctx = super(DeleteStatement, self).get_context()
for field in self.fields: for field in self.fields:
field.update_context(ctx) field.update_context(ctx)
for clause in self.transactions or []: for clause in self.conditionals:
clause.update_context(ctx) clause.update_context(ctx)
return ctx return ctx
@@ -849,8 +852,8 @@ class DeleteStatement(BaseCQLStatement):
if self.where_clauses: if self.where_clauses:
qs += [self._where] qs += [self._where]
if self.transactions: if self.conditionals:
qs += [self._get_transactions()] qs += [self._get_conditionals()]
if self.if_exists: if self.if_exists:
qs += ["IF EXISTS"] qs += ["IF EXISTS"]

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from unittest import TestCase from unittest import TestCase
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, TransactionClause from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
from cassandra.cqlengine.operators import * from cassandra.cqlengine.operators import *
import six import six
@@ -76,10 +76,10 @@ class DeleteStatementTests(TestCase):
def test_delete_conditional(self): def test_delete_conditional(self):
where = [WhereClause('id', EqualsOperator(), 1)] where = [WhereClause('id', EqualsOperator(), 1)]
transactions = [TransactionClause('f0', 'value0'), TransactionClause('f1', 'value1')] conditionals = [ConditionalClause('f0', 'value0'), ConditionalClause('f1', 'value1')]
ds = DeleteStatement('table', where=where, transactions=transactions) ds = DeleteStatement('table', where=where, conditionals=conditionals)
self.assertEqual(len(ds.transactions), len(transactions)) self.assertEqual(len(ds.conditionals), len(conditionals))
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds))
fields = ['one', 'two'] fields = ['one', 'two']
ds = DeleteStatement('table', fields=fields, where=where, transactions=transactions) ds = DeleteStatement('table', fields=fields, where=where, conditionals=conditionals)
self.assertEqual(six.text_type(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) self.assertEqual(six.text_type(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds))

View File

@@ -24,33 +24,33 @@ from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine.management import sync_table, drop_table
from cassandra.cqlengine.models import Model from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import BatchQuery, LWTException from cassandra.cqlengine.query import BatchQuery, LWTException
from cassandra.cqlengine.statements import TransactionClause from cassandra.cqlengine.statements import ConditionalClause
from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine.base import BaseCassEngTestCase
from tests.integration import CASSANDRA_VERSION from tests.integration import CASSANDRA_VERSION
class TestTransactionModel(Model): class TestConditionalModel(Model):
id = columns.UUID(primary_key=True, default=uuid4) id = columns.UUID(primary_key=True, default=uuid4)
count = columns.Integer() count = columns.Integer()
text = columns.Text(required=False) text = columns.Text(required=False)
@unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "transactions only supported on cassandra 2.0 or higher") @unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "conditionals only supported on cassandra 2.0 or higher")
class TestTransaction(BaseCassEngTestCase): class TestConditional(BaseCassEngTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(TestTransaction, cls).setUpClass() super(TestConditional, cls).setUpClass()
sync_table(TestTransactionModel) sync_table(TestConditionalModel)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
super(TestTransaction, cls).tearDownClass() super(TestConditional, cls).tearDownClass()
drop_table(TestTransactionModel) drop_table(TestConditionalModel)
def test_update_using_transaction(self): def test_update_using_conditional(self):
t = TestTransactionModel.create(text='blah blah') t = TestConditionalModel.create(text='blah blah')
t.text = 'new blah' t.text = 'new blah'
with mock.patch.object(self.session, 'execute') as m: with mock.patch.object(self.session, 'execute') as m:
t.iff(text='blah blah').save() t.iff(text='blah blah').save()
@@ -58,18 +58,18 @@ class TestTransaction(BaseCassEngTestCase):
args = m.call_args args = m.call_args
self.assertIn('IF "text" = %(0)s', args[0][0].query_string) self.assertIn('IF "text" = %(0)s', args[0][0].query_string)
def test_update_transaction_success(self): def test_update_conditional_success(self):
t = TestTransactionModel.create(text='blah blah', count=5) t = TestConditionalModel.create(text='blah blah', count=5)
id = t.id id = t.id
t.text = 'new blah' t.text = 'new blah'
t.iff(text='blah blah').save() t.iff(text='blah blah').save()
updated = TestTransactionModel.objects(id=id).first() updated = TestConditionalModel.objects(id=id).first()
self.assertEqual(updated.count, 5) self.assertEqual(updated.count, 5)
self.assertEqual(updated.text, 'new blah') self.assertEqual(updated.text, 'new blah')
def test_update_failure(self): def test_update_failure(self):
t = TestTransactionModel.create(text='blah blah') t = TestConditionalModel.create(text='blah blah')
t.text = 'new blah' t.text = 'new blah'
t = t.iff(text='something wrong') t = t.iff(text='something wrong')
@@ -82,21 +82,21 @@ class TestTransaction(BaseCassEngTestCase):
}) })
def test_blind_update(self): def test_blind_update(self):
t = TestTransactionModel.create(text='blah blah') t = TestConditionalModel.create(text='blah blah')
t.text = 'something else' t.text = 'something else'
uid = t.id uid = t.id
with mock.patch.object(self.session, 'execute') as m: with mock.patch.object(self.session, 'execute') as m:
TestTransactionModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der') TestConditionalModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der')
args = m.call_args args = m.call_args
self.assertIn('IF "text" = %(1)s', args[0][0].query_string) self.assertIn('IF "text" = %(1)s', args[0][0].query_string)
def test_blind_update_fail(self): def test_blind_update_fail(self):
t = TestTransactionModel.create(text='blah blah') t = TestConditionalModel.create(text='blah blah')
t.text = 'something else' t.text = 'something else'
uid = t.id uid = t.id
qs = TestTransactionModel.objects(id=uid).iff(text='Not dis!') qs = TestConditionalModel.objects(id=uid).iff(text='Not dis!')
with self.assertRaises(LWTException) as assertion: with self.assertRaises(LWTException) as assertion:
qs.update(text='this will never work') qs.update(text='this will never work')
@@ -105,20 +105,20 @@ class TestTransaction(BaseCassEngTestCase):
'[applied]': False, '[applied]': False,
}) })
def test_transaction_clause(self): def test_conditional_clause(self):
tc = TransactionClause('some_value', 23) tc = ConditionalClause('some_value', 23)
tc.set_context_id(3) tc.set_context_id(3)
self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) self.assertEqual('"some_value" = %(3)s', six.text_type(tc))
self.assertEqual('"some_value" = %(3)s', str(tc)) self.assertEqual('"some_value" = %(3)s', str(tc))
def test_batch_update_transaction(self): def test_batch_update_conditional(self):
t = TestTransactionModel.create(text='something', count=5) t = TestConditionalModel.create(text='something', count=5)
id = t.id id = t.id
with BatchQuery() as b: with BatchQuery() as b:
t.batch(b).iff(count=5).update(text='something else') t.batch(b).iff(count=5).update(text='something else')
updated = TestTransactionModel.objects(id=id).first() updated = TestConditionalModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else') self.assertEqual(updated.text, 'something else')
b = BatchQuery() b = BatchQuery()
@@ -132,27 +132,27 @@ class TestTransaction(BaseCassEngTestCase):
'[applied]': False, '[applied]': False,
}) })
updated = TestTransactionModel.objects(id=id).first() updated = TestConditionalModel.objects(id=id).first()
self.assertEqual(updated.text, 'something else') self.assertEqual(updated.text, 'something else')
def test_delete_transaction(self): def test_delete_conditional(self):
# DML path # DML path
t = TestTransactionModel.create(text='something', count=5) t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 1) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException): with self.assertRaises(LWTException):
t.iff(count=9999).delete() t.iff(count=9999).delete()
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 1) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
t.iff(count=5).delete() t.iff(count=5).delete()
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 0) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0)
# QuerySet path # QuerySet path
t = TestTransactionModel.create(text='something', count=5) t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 1) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException): with self.assertRaises(LWTException):
TestTransactionModel.objects(id=t.id).iff(count=9999).delete() TestConditionalModel.objects(id=t.id).iff(count=9999).delete()
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 1) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
TestTransactionModel.objects(id=t.id).iff(count=5).delete() TestConditionalModel.objects(id=t.id).iff(count=5).delete()
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 0) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0)
def test_update_to_none(self): def test_update_to_none(self):
# This test is done because updates to none are split into deletes # This test is done because updates to none are split into deletes
@@ -160,19 +160,19 @@ class TestTransaction(BaseCassEngTestCase):
# https://github.com/datastax/python-driver/blob/3.1.1/cassandra/cqlengine/query.py#L1197-L1200 # https://github.com/datastax/python-driver/blob/3.1.1/cassandra/cqlengine/query.py#L1197-L1200
# DML path # DML path
t = TestTransactionModel.create(text='something', count=5) t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 1) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException): with self.assertRaises(LWTException):
t.iff(count=9999).update(text=None) t.iff(count=9999).update(text=None)
self.assertIsNotNone(TestTransactionModel.objects(id=t.id).first().text) self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text)
t.iff(count=5).update(text=None) t.iff(count=5).update(text=None)
self.assertIsNone(TestTransactionModel.objects(id=t.id).first().text) self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text)
# QuerySet path # QuerySet path
t = TestTransactionModel.create(text='something', count=5) t = TestConditionalModel.create(text='something', count=5)
self.assertEqual(TestTransactionModel.objects(id=t.id).count(), 1) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1)
with self.assertRaises(LWTException): with self.assertRaises(LWTException):
TestTransactionModel.objects(id=t.id).iff(count=9999).update(text=None) TestConditionalModel.objects(id=t.id).iff(count=9999).update(text=None)
self.assertIsNotNone(TestTransactionModel.objects(id=t.id).first().text) self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text)
TestTransactionModel.objects(id=t.id).iff(count=5).update(text=None) TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None)
self.assertIsNone(TestTransactionModel.objects(id=t.id).first().text) self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text)