From 874a63cdbcc658023c3d0f5c22c57b472d00cd68 Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Wed, 8 Oct 2014 14:48:35 -0400 Subject: [PATCH] Bug fixes --- cqlengine/models.py | 20 +++++++++++-------- cqlengine/query.py | 14 ++++++++----- cqlengine/statements.py | 15 ++++++++++---- .../model/test_transaction_statements.py | 17 ++++++++++++---- 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/cqlengine/models.py b/cqlengine/models.py index 880f6207..f3114d6c 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -80,9 +80,14 @@ class TransactionDescriptor(object): """ def __get__(self, instance, model): if instance: - def transaction_setter(transaction): - instance._transaction = transaction + def transaction_setter(*prepared_transaction, **unprepared_transactions): + if len(prepared_transaction) > 0: + transactions = prepared_transaction[0] + else: + transactions = instance.objects.transaction(**unprepared_transactions)._transaction + instance._transaction = transactions return instance + return transaction_setter qs = model.__queryset__(model) @@ -323,6 +328,7 @@ class BaseModel(object): self._values = {} self._ttl = self.__default_ttl__ self._timestamp = None + self._transaction = None for name, column in self._columns.items(): value = values.get(name, None) @@ -553,10 +559,6 @@ class BaseModel(object): return cls.objects.filter(*args, **kwargs) - @classmethod - def transaction(cls, *args, **kwargs): - return cls.objects.transaction(*args, **kwargs) - @classmethod def get(cls, *args, **kwargs): return cls.objects.get(*args, **kwargs) @@ -576,7 +578,8 @@ class BaseModel(object): ttl=self._ttl, timestamp=self._timestamp, consistency=self.__consistency__, - if_not_exists=self._if_not_exists).save() + if_not_exists=self._if_not_exists, + transaction=self._transaction).save() #reset the value managers for v in self._values.values(): @@ -614,7 +617,8 @@ class BaseModel(object): batch=self._batch, ttl=self._ttl, timestamp=self._timestamp, - consistency=self.__consistency__).update() + consistency=self.__consistency__, + transaction=self._transaction).update() #reset the value managers for v in self._values.values(): diff --git a/cqlengine/query.py b/cqlengine/query.py index 8ee7e621..9c715843 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -422,11 +422,12 @@ class AbstractQuerySet(object): clone._transaction.append(operator) for col_name, val in kwargs.items(): + exists = False try: column = self.model._get_column(col_name) except KeyError: if col_name in ['exists', 'not_exists']: - pass + exists = True elif col_name == 'pk__token': if not isinstance(val, Token): raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") @@ -445,7 +446,7 @@ class AbstractQuerySet(object): len(val.value), len(partition_columns))) val.set_columns(partition_columns) - if isinstance(val, BaseQueryFunction): + if isinstance(val, BaseQueryFunction) or exists is True: query_val = val else: query_val = column.to_database(val) @@ -834,7 +835,7 @@ class DMLQuery(object): _timestamp = None _if_not_exists = False - def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False): + def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None): self.model = model self.column_family_name = self.model.column_family_name() self.instance = instance @@ -843,6 +844,7 @@ class DMLQuery(object): self._consistency = consistency self._timestamp = timestamp self._if_not_exists = if_not_exists + self._transaction = transaction def _execute(self, q): if self._batch: @@ -897,7 +899,8 @@ class DMLQuery(object): raise CQLEngineException("DML Query intance attribute is None") assert type(self.instance) == self.model static_update_only = True - statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) + statement = UpdateStatement(self.column_family_name, ttl=self._ttl, + timestamp=self._timestamp, transactions=self._transaction) #get defined fields and their column names for name, col in self.model._columns.items(): if not col.is_primary_key: @@ -960,7 +963,8 @@ class DMLQuery(object): if self.instance._has_counter or self.instance._can_update(): return self.update() else: - insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) + insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists, transactions=self._transaction) + for name, col in self.instance._columns.items(): val = getattr(self.instance, name, None) if col._val_is_null(val): diff --git a/cqlengine/statements.py b/cqlengine/statements.py index 5cdef38a..4df43061 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -152,6 +152,11 @@ class TransactionClause(BaseClause): def insert_tuple(self): return self.field, self.context_id + def update_context(self, ctx): + if self.field not in ['exists', 'not_exists']: + return super(TransactionClause, self).update_context(ctx) + return ctx + class ContainerUpdateClause(AssignmentClause): @@ -646,10 +651,12 @@ class AssignmentStatement(BaseCQLStatement): ctx = super(AssignmentStatement, self).get_context() for clause in self.assignments: clause.update_context(ctx) + for clause in self.transactions or []: + clause.update_context(ctx) return ctx - def _transactions(self): - return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) + def _get_transactions(self): + return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) class InsertStatement(AssignmentStatement): @@ -697,7 +704,7 @@ class InsertStatement(AssignmentStatement): qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)] if len(self.transactions) > 0: - qs += [self._transactions] + qs += [self._get_transactions()] return ' '.join(qs) @@ -726,7 +733,7 @@ class UpdateStatement(AssignmentStatement): qs += [self._where] if len(self.transactions) > 0: - qs += [self._transactions] + qs += [self._get_transactions()] return ' '.join(qs) diff --git a/cqlengine/tests/model/test_transaction_statements.py b/cqlengine/tests/model/test_transaction_statements.py index 8afabb1e..a62ac122 100644 --- a/cqlengine/tests/model/test_transaction_statements.py +++ b/cqlengine/tests/model/test_transaction_statements.py @@ -4,10 +4,13 @@ from uuid import uuid4 from mock import patch from cqlengine.exceptions import ValidationError -from cqlengine.tests.base import BaseCassEngTestCase +from unittest import TestCase from cqlengine.models import Model from cqlengine import columns from cqlengine.management import sync_table, drop_table +from cqlengine import connection + +connection.setup(['192.168.56.103'], 'test') class TestUpdateModel(Model): @@ -18,7 +21,7 @@ class TestUpdateModel(Model): text = columns.Text(required=False, index=True) -class ModelUpdateTests(BaseCassEngTestCase): +class ModelUpdateTests(TestCase): @classmethod def setUpClass(cls): @@ -31,6 +34,12 @@ class ModelUpdateTests(BaseCassEngTestCase): drop_table(TestUpdateModel) def test_transaction_insertion(self): - m = TestUpdateModel(count=5, text='something').transaction(exists=True) - m.save() + m = TestUpdateModel.objects.create(count=5, text='something') + all_models = TestUpdateModel.objects.all() + for x in all_models: + pass + m = TestUpdateModel.objects + m = m.transaction(not_exists=True) + m = m.create(count=5, text='something') + m.transaction(count=6).update(text='something else') x = 10 \ No newline at end of file