From 7abba76e74b14fd1a461737321523d5cf6db5e9b Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Wed, 8 Oct 2014 12:24:47 -0400 Subject: [PATCH 1/7] =?UTF-8?q?Initial=20attempt=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cqlengine/models.py | 28 ++++++++++- cqlengine/query.py | 49 ++++++++++++++++++- cqlengine/statements.py | 44 ++++++++++++++++- .../model/test_transaction_statements.py | 36 ++++++++++++++ 4 files changed, 154 insertions(+), 3 deletions(-) create mode 100644 cqlengine/tests/model/test_transaction_statements.py diff --git a/cqlengine/models.py b/cqlengine/models.py index d725e42c..92ce9bac 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -74,6 +74,27 @@ class QuerySetDescriptor(object): raise NotImplementedError +class TransactionDescriptor(object): + """ + returns a query set descriptor + """ + def __get__(self, instance, model): + if instance: + def transaction_setter(transaction): + instance._transaction = transaction + return instance + return transaction_setter + qs = model.__queryset__(model) + + def transaction_setter(transaction): + qs._transaction = transaction + return instance + return transaction_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + class TTLDescriptor(object): """ returns a query set descriptor @@ -222,6 +243,7 @@ class BaseModel(object): objects = QuerySetDescriptor() ttl = TTLDescriptor() consistency = ConsistencyDescriptor() + transaction = TransactionDescriptor() # custom timestamps, see USING TIMESTAMP X timestamp = TimestampDescriptor() @@ -282,7 +304,7 @@ class BaseModel(object): self._timestamp = None for name, column in self._columns.items(): - value = values.get(name, None) + value = values.get(name, None) if value is not None or isinstance(column, columns.BaseContainerColumn): value = column.to_python(value) value_mngr = column.value_manager(self, column, value) @@ -510,6 +532,10 @@ class BaseModel(object): return cls.objects.filter(*args, **kwargs) + @classmethod + def transaction(cls, *args, **kwargs): + return cls.objects.transaction(*args, **kwargs) + @classmethod def get(cls, *args, **kwargs): return cls.objects.get(*args, **kwargs) diff --git a/cqlengine/query.py b/cqlengine/query.py index de929c76..0f042dbf 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -13,7 +13,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix #http://www.datastax.com/docs/1.1/references/cql/index from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator -from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause +from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause class QueryException(CQLEngineException): pass @@ -194,6 +194,9 @@ class AbstractQuerySet(object): #Where clause filters self._where = [] + # Transaction clause filters + self._transaction = [] + #ordering arguments self._order = [] @@ -394,6 +397,50 @@ class AbstractQuerySet(object): else: raise QueryException("Can't parse '{}'".format(arg)) + def transaction(self, *args, **kwargs): + """Adds IF statements to queryset""" + if len([x for x in kwargs.values() if x is None]): + raise CQLEngineException("None values on transaction are not allowed") + + clone = copy.deepcopy(self) + for operator in args: + if not isinstance(operator, TransactionClause): + raise QueryException('{} is not a valid query operator'.format(operator)) + clone._transaction.append(operator) + + for col_name, val in kwargs.items(): + try: + column = self.model._get_column(col_name) + except KeyError: + if col_name in ['exists', 'not_exists']: + pass + elif col_name == 'pk__token': + if not isinstance(val, Token): + raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") + column = columns._PartitionKeysToken(self.model) + quote_field = False + else: + raise QueryException("Can't resolve column name: '{}'".format(col_name)) + + if isinstance(val, Token): + if col_name != 'pk__token': + raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") + partition_columns = column.partition_columns + if len(partition_columns) != len(val.value): + raise QueryException( + 'Token() received {} arguments but model has {} partition keys'.format( + len(val.value), len(partition_columns))) + val.set_columns(partition_columns) + + if isinstance(val, BaseQueryFunction): + query_val = val + else: + query_val = column.to_database(val) + + clone._transaction.append(TransactionClause(col_name, query_val)) + + return clone + def filter(self, *args, **kwargs): """ Adds WHERE arguments to the queryset, returning a new queryset diff --git a/cqlengine/statements.py b/cqlengine/statements.py index f1681a16..834db391 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -139,6 +139,20 @@ class AssignmentClause(BaseClause): return self.field, self.context_id +class TransactionClause(BaseClause): + """ A single variable transaction statement """ + + def __unicode__(self): + if self.field == 'exists': + return u'EXISTS' + if self.field == 'not_exists': + return u'NOT EXISTS' + return u'"{}" = %({})s'.format(self.field, self.context_id) + + def insert_tuple(self): + return self.field, self.context_id + + class ContainerUpdateClause(AssignmentClause): def __init__(self, field, value, operation=None, previous=None, column=None): @@ -573,7 +587,8 @@ class AssignmentStatement(BaseCQLStatement): consistency=None, where=None, ttl=None, - timestamp=None): + timestamp=None, + transactions=None): super(AssignmentStatement, self).__init__( table, consistency=consistency, @@ -587,6 +602,11 @@ class AssignmentStatement(BaseCQLStatement): for assignment in assignments or []: self.add_assignment_clause(assignment) + # Add transaction statements + self.transactions = [] + for transaction in transactions or []: + self.add_transaction_clause(transaction) + def update_context_id(self, i): super(AssignmentStatement, self).update_context_id(i) for assignment in self.assignments: @@ -605,6 +625,19 @@ class AssignmentStatement(BaseCQLStatement): self.context_counter += clause.get_context_size() self.assignments.append(clause) + def add_transaction_clause(self, clause): + """ + Adds a transaction clause to this statement + + :param clause: The clause that will be added to the transaction statement + :type clause: TransactionClause + """ + if not isinstance(clause, TransactionClause): + raise StatementException('only instances of AssignmentClause can be added to statements') + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.transactions.append(clause) + @property def is_empty(self): return len(self.assignments) == 0 @@ -615,6 +648,9 @@ class AssignmentStatement(BaseCQLStatement): clause.update_context(ctx) return ctx + def _transactions(self): + return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) + class InsertStatement(AssignmentStatement): """ an cql insert select statement """ @@ -639,6 +675,9 @@ class InsertStatement(AssignmentStatement): if self.timestamp: qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)] + if len(self.transactions) > 0: + qs += [self._transactions] + return ' '.join(qs) @@ -665,6 +704,9 @@ class UpdateStatement(AssignmentStatement): if self.where_clauses: qs += [self._where] + if len(self.transactions) > 0: + qs += [self._transactions] + return ' '.join(qs) diff --git a/cqlengine/tests/model/test_transaction_statements.py b/cqlengine/tests/model/test_transaction_statements.py new file mode 100644 index 00000000..8afabb1e --- /dev/null +++ b/cqlengine/tests/model/test_transaction_statements.py @@ -0,0 +1,36 @@ +__author__ = 'Tim Martin' +from uuid import uuid4 + +from mock import patch +from cqlengine.exceptions import ValidationError + +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from cqlengine import columns +from cqlengine.management import sync_table, drop_table + + +class TestUpdateModel(Model): + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.UUID(primary_key=True, default=uuid4) + count = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + + +class ModelUpdateTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(ModelUpdateTests, cls).setUpClass() + sync_table(TestUpdateModel) + + @classmethod + def tearDownClass(cls): + super(ModelUpdateTests, cls).tearDownClass() + drop_table(TestUpdateModel) + + def test_transaction_insertion(self): + m = TestUpdateModel(count=5, text='something').transaction(exists=True) + m.save() + x = 10 \ No newline at end of file From 44a518100947702eca15d258a6258d2ca17b7c62 Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Wed, 8 Oct 2014 14:48:35 -0400 Subject: [PATCH 2/7] Bug fixes --- cqlengine/models.py | 20 +++++++++------- cqlengine/query.py | 23 +++++++++++++------ cqlengine/statements.py | 15 ++++++++---- .../model/test_transaction_statements.py | 17 ++++++++++---- 4 files changed, 52 insertions(+), 23 deletions(-) diff --git a/cqlengine/models.py b/cqlengine/models.py index 92ce9bac..3598bcad 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -80,9 +80,14 @@ class TransactionDescriptor(object): """ def __get__(self, instance, model): if instance: - def transaction_setter(transaction): - instance._transaction = transaction + def transaction_setter(*prepared_transaction, **unprepared_transactions): + if len(prepared_transaction) > 0: + transactions = prepared_transaction[0] + else: + transactions = instance.objects.transaction(**unprepared_transactions)._transaction + instance._transaction = transactions return instance + return transaction_setter qs = model.__queryset__(model) @@ -302,6 +307,7 @@ class BaseModel(object): self._values = {} self._ttl = self.__default_ttl__ self._timestamp = None + self._transaction = None for name, column in self._columns.items(): value = values.get(name, None) @@ -532,10 +538,6 @@ class BaseModel(object): return cls.objects.filter(*args, **kwargs) - @classmethod - def transaction(cls, *args, **kwargs): - return cls.objects.transaction(*args, **kwargs) - @classmethod def get(cls, *args, **kwargs): return cls.objects.get(*args, **kwargs) @@ -554,7 +556,8 @@ class BaseModel(object): batch=self._batch, ttl=self._ttl, timestamp=self._timestamp, - consistency=self.__consistency__).save() + consistency=self.__consistency__, + transaction=self._transaction).save() #reset the value managers for v in self._values.values(): @@ -592,7 +595,8 @@ class BaseModel(object): batch=self._batch, ttl=self._ttl, timestamp=self._timestamp, - consistency=self.__consistency__).update() + consistency=self.__consistency__, + transaction=self._transaction).update() #reset the value managers for v in self._values.values(): diff --git a/cqlengine/query.py b/cqlengine/query.py index 0f042dbf..e85771b5 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -409,11 +409,12 @@ class AbstractQuerySet(object): clone._transaction.append(operator) for col_name, val in kwargs.items(): + exists = False try: column = self.model._get_column(col_name) except KeyError: if col_name in ['exists', 'not_exists']: - pass + exists = True elif col_name == 'pk__token': if not isinstance(val, Token): raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") @@ -432,7 +433,7 @@ class AbstractQuerySet(object): len(val.value), len(partition_columns))) val.set_columns(partition_columns) - if isinstance(val, BaseQueryFunction): + if isinstance(val, BaseQueryFunction) or exists is True: query_val = val else: query_val = column.to_database(val) @@ -615,9 +616,14 @@ class AbstractQuerySet(object): return self._only_or_defer('defer', fields) def create(self, **kwargs): - return self.model(**kwargs).batch(self._batch).ttl(self._ttl).\ + + x = self.model(**kwargs).batch(self._batch) + x = x.ttl(self._ttl).\ consistency(self._consistency).\ - timestamp(self._timestamp).save() + timestamp(self._timestamp) + x = x.transaction(self._transaction).save() + + return x def delete(self): """ @@ -815,7 +821,7 @@ class DMLQuery(object): _consistency = None _timestamp = None - def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None): + def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, transaction=None): self.model = model self.column_family_name = self.model.column_family_name() self.instance = instance @@ -823,6 +829,7 @@ class DMLQuery(object): self._ttl = ttl self._consistency = consistency self._timestamp = timestamp + self._transaction = transaction def _execute(self, q): if self._batch: @@ -874,7 +881,8 @@ class DMLQuery(object): raise CQLEngineException("DML Query intance attribute is None") assert type(self.instance) == self.model static_update_only = True - statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) + statement = UpdateStatement(self.column_family_name, ttl=self._ttl, + timestamp=self._timestamp, transactions=self._transaction) #get defined fields and their column names for name, col in self.model._columns.items(): if not col.is_primary_key: @@ -937,7 +945,8 @@ class DMLQuery(object): if self.instance._has_counter or self.instance._can_update(): return self.update() else: - insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) + insert = InsertStatement(self.column_family_name, ttl=self._ttl, + timestamp=self._timestamp, transactions=self._transaction) for name, col in self.instance._columns.items(): val = getattr(self.instance, name, None) if col._val_is_null(val): diff --git a/cqlengine/statements.py b/cqlengine/statements.py index 834db391..e4a2aa5b 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -152,6 +152,11 @@ class TransactionClause(BaseClause): def insert_tuple(self): return self.field, self.context_id + def update_context(self, ctx): + if self.field not in ['exists', 'not_exists']: + return super(TransactionClause, self).update_context(ctx) + return ctx + class ContainerUpdateClause(AssignmentClause): @@ -646,10 +651,12 @@ class AssignmentStatement(BaseCQLStatement): ctx = super(AssignmentStatement, self).get_context() for clause in self.assignments: clause.update_context(ctx) + for clause in self.transactions or []: + clause.update_context(ctx) return ctx - def _transactions(self): - return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) + def _get_transactions(self): + return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) class InsertStatement(AssignmentStatement): @@ -676,7 +683,7 @@ class InsertStatement(AssignmentStatement): qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)] if len(self.transactions) > 0: - qs += [self._transactions] + qs += [self._get_transactions()] return ' '.join(qs) @@ -705,7 +712,7 @@ class UpdateStatement(AssignmentStatement): qs += [self._where] if len(self.transactions) > 0: - qs += [self._transactions] + qs += [self._get_transactions()] return ' '.join(qs) diff --git a/cqlengine/tests/model/test_transaction_statements.py b/cqlengine/tests/model/test_transaction_statements.py index 8afabb1e..a62ac122 100644 --- a/cqlengine/tests/model/test_transaction_statements.py +++ b/cqlengine/tests/model/test_transaction_statements.py @@ -4,10 +4,13 @@ from uuid import uuid4 from mock import patch from cqlengine.exceptions import ValidationError -from cqlengine.tests.base import BaseCassEngTestCase +from unittest import TestCase from cqlengine.models import Model from cqlengine import columns from cqlengine.management import sync_table, drop_table +from cqlengine import connection + +connection.setup(['192.168.56.103'], 'test') class TestUpdateModel(Model): @@ -18,7 +21,7 @@ class TestUpdateModel(Model): text = columns.Text(required=False, index=True) -class ModelUpdateTests(BaseCassEngTestCase): +class ModelUpdateTests(TestCase): @classmethod def setUpClass(cls): @@ -31,6 +34,12 @@ class ModelUpdateTests(BaseCassEngTestCase): drop_table(TestUpdateModel) def test_transaction_insertion(self): - m = TestUpdateModel(count=5, text='something').transaction(exists=True) - m.save() + m = TestUpdateModel.objects.create(count=5, text='something') + all_models = TestUpdateModel.objects.all() + for x in all_models: + pass + m = TestUpdateModel.objects + m = m.transaction(not_exists=True) + m = m.create(count=5, text='something') + m.transaction(count=6).update(text='something else') x = 10 \ No newline at end of file From f9874c7787393dac617c00321ac58246610df4a4 Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Wed, 8 Oct 2014 17:11:19 -0400 Subject: [PATCH 3/7] Added a thrown exception when a transaction fails to be applied --- .gitignore | 3 +++ cqlengine/exceptions.py | 1 + cqlengine/query.py | 8 +++++++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b43bb3d1..f5766784 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ docs/_build #iPython *.ipynb + +cqlengine/tests/* +cqlengine/tests* \ No newline at end of file diff --git a/cqlengine/exceptions.py b/cqlengine/exceptions.py index 94a51b92..40663052 100644 --- a/cqlengine/exceptions.py +++ b/cqlengine/exceptions.py @@ -2,5 +2,6 @@ class CQLEngineException(Exception): pass class ModelException(CQLEngineException): pass class ValidationError(CQLEngineException): pass +class TransactionException(CQLEngineException): pass class UndefinedKeyspaceException(CQLEngineException): pass diff --git a/cqlengine/query.py b/cqlengine/query.py index e85771b5..406d4161 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -6,7 +6,7 @@ from cqlengine.columns import Counter, List, Set from cqlengine.connection import execute -from cqlengine.exceptions import CQLEngineException, ValidationError +from cqlengine.exceptions import CQLEngineException, ValidationError, TransactionException from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin #CQL 3 reference: @@ -836,6 +836,12 @@ class DMLQuery(object): return self._batch.add_query(q) else: tmp = execute(q, consistency_level=self._consistency) + if tmp and tmp[0].get('[applied]', True) is False: + tmp[0].pop('[applied]') + expected = ', '.join('{0}={1}'.format(t.field, t.value) for t in q.transactions) + actual = ', '.join('{0}={1}'.format(f, v) for f, v in tmp[0].items()) + message = 'Transaction statement failed: Expected: {0} Actual: {1}'.format(expected, actual) + raise TransactionException(message) return tmp def batch(self, batch_obj): From 423dccd6fa544a38c52982997b1c7e572331c48d Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Wed, 15 Oct 2014 09:31:05 -0400 Subject: [PATCH 4/7] Reverted tests to original --- cqlengine/VERSION | 2 +- cqlengine/tests/management/test_management.py | 48 +++++++++++++------ .../model/test_transaction_statements.py | 45 ----------------- setup.py | 4 +- 4 files changed, 36 insertions(+), 63 deletions(-) delete mode 100644 cqlengine/tests/model/test_transaction_statements.py diff --git a/cqlengine/VERSION b/cqlengine/VERSION index 249afd51..503a21de 100644 --- a/cqlengine/VERSION +++ b/cqlengine/VERSION @@ -1 +1 @@ -0.18.1 +0.18.2 diff --git a/cqlengine/tests/management/test_management.py b/cqlengine/tests/management/test_management.py index c2fb19af..3bda88ed 100644 --- a/cqlengine/tests/management/test_management.py +++ b/cqlengine/tests/management/test_management.py @@ -8,8 +8,9 @@ from cqlengine import management from cqlengine.tests.query.test_queryset import TestModel from cqlengine.models import Model from cqlengine import columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy - - +from unittest import skipUnless +from cqlengine.connection import get_cluster +cluster = get_cluster() class CreateKeyspaceTest(BaseCassEngTestCase): def test_create_succeeeds(self): @@ -145,22 +146,25 @@ class TablePropertiesTests(BaseCassEngTestCase): sync_table(ModelWithTableProperties) expected = {'bloom_filter_fp_chance': 0.76328, - 'caching': CACHING_ALL, 'comment': 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR', 'gc_grace_seconds': 2063, - 'populate_io_cache_on_flush': True, 'read_repair_chance': 0.17985, - 'replicate_on_write': False # For some reason 'dclocal_read_repair_chance' in CQL is called # just 'local_read_repair_chance' in the schema table. # Source: https://issues.apache.org/jira/browse/CASSANDRA-6717 # TODO: due to a bug in the native driver i'm not seeing the local read repair chance show up # 'local_read_repair_chance': 0.50811, } + if CASSANDRA_VERSION <= 20: + expected['caching'] = CACHING_ALL + expected['replicate_on_write'] = False + + if CASSANDRA_VERSION == 20: + expected['populate_io_cache_on_flush'] = True + expected['index_interval'] = 98706 if CASSANDRA_VERSION >= 20: expected['default_time_to_live'] = 4756 - expected['index_interval'] = 98706 expected['memtable_flush_period_in_ms'] = 43681 self.assertDictContainsSubset(expected, management.get_table_settings(ModelWithTableProperties).options) @@ -186,19 +190,24 @@ class TablePropertiesTests(BaseCassEngTestCase): table_settings = management.get_table_settings(ModelWithTableProperties).options expected = {'bloom_filter_fp_chance': 0.66778, - 'caching': CACHING_NONE, 'comment': 'xirAkRWZVVvsmzRvXamiEcQkshkUIDINVJZgLYSdnGHweiBrAiJdLJkVohdRy', 'gc_grace_seconds': 96362, - 'populate_io_cache_on_flush': False, 'read_repair_chance': 0.2989, - 'replicate_on_write': True # TODO see above comment re: native driver missing local read repair chance - # 'local_read_repair_chance': 0.12732, + #'local_read_repair_chance': 0.12732, } if CASSANDRA_VERSION >= 20: expected['memtable_flush_period_in_ms'] = 60210 expected['default_time_to_live'] = 65178 + + if CASSANDRA_VERSION == 20: expected['index_interval'] = 94207 + # these featuers removed in cassandra 2.1 + if CASSANDRA_VERSION <= 20: + expected['caching'] = CACHING_NONE + expected['replicate_on_write'] = True + expected['populate_io_cache_on_flush'] = False + self.assertDictContainsSubset(expected, table_settings) @@ -247,6 +256,7 @@ class NonModelFailureTest(BaseCassEngTestCase): sync_table(self.FakeModel) +@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") def test_static_columns(): class StaticModel(Model): id = columns.Integer(primary_key=True) @@ -260,15 +270,23 @@ def test_static_columns(): from cqlengine.connection import get_session session = get_session() - with patch.object(session, "execute", side_effect=Exception) as m: - try: - sync_table(StaticModel) - except: - pass + with patch.object(session, "execute", wraps=session.execute) as m: + sync_table(StaticModel) assert m.call_count > 0 statement = m.call_args[0][0].query_string assert '"name" text static' in statement, statement + # if we sync again, we should not apply an alter w/ a static + sync_table(StaticModel) + + with patch.object(session, "execute", wraps=session.execute) as m2: + sync_table(StaticModel) + + assert len(m2.call_args_list) == 1 + assert "ALTER" not in m2.call_args[0][0].query_string + + + diff --git a/cqlengine/tests/model/test_transaction_statements.py b/cqlengine/tests/model/test_transaction_statements.py deleted file mode 100644 index a62ac122..00000000 --- a/cqlengine/tests/model/test_transaction_statements.py +++ /dev/null @@ -1,45 +0,0 @@ -__author__ = 'Tim Martin' -from uuid import uuid4 - -from mock import patch -from cqlengine.exceptions import ValidationError - -from unittest import TestCase -from cqlengine.models import Model -from cqlengine import columns -from cqlengine.management import sync_table, drop_table -from cqlengine import connection - -connection.setup(['192.168.56.103'], 'test') - - -class TestUpdateModel(Model): - __keyspace__ = 'test' - partition = columns.UUID(primary_key=True, default=uuid4) - cluster = columns.UUID(primary_key=True, default=uuid4) - count = columns.Integer(required=False) - text = columns.Text(required=False, index=True) - - -class ModelUpdateTests(TestCase): - - @classmethod - def setUpClass(cls): - super(ModelUpdateTests, cls).setUpClass() - sync_table(TestUpdateModel) - - @classmethod - def tearDownClass(cls): - super(ModelUpdateTests, cls).tearDownClass() - drop_table(TestUpdateModel) - - def test_transaction_insertion(self): - m = TestUpdateModel.objects.create(count=5, text='something') - all_models = TestUpdateModel.objects.all() - for x in all_models: - pass - m = TestUpdateModel.objects - m = m.transaction(not_exists=True) - m = m.create(count=5, text='something') - m.transaction(count=6).update(text='something else') - x = 10 \ No newline at end of file diff --git a/setup.py b/setup.py index cf35583c..0b9a9af3 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ Cassandra CQL 3 Object Mapper for Python """ setup( - name='cqlengine', + name='vk-cqlengine', version=version, description='Cassandra CQL 3 Object Mapper for Python', long_description=long_desc, @@ -35,7 +35,7 @@ setup( author='Blake Eggleston, Jon Haddad', author_email='bdeggleston@gmail.com, jon@jonhaddad.com', url='https://github.com/cqlengine/cqlengine', - license='BSD', + license='Private', packages=find_packages(), include_package_data=True, ) From 6e4dc3afdfbba5253218c7bc9565c6666bebe410 Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Thu, 16 Oct 2014 12:18:56 -0400 Subject: [PATCH 5/7] Small fixes and tests added --- cqlengine/connection.py | 9 ++++++++- cqlengine/models.py | 7 ++++--- cqlengine/query.py | 11 +++-------- cqlengine/statements.py | 6 +----- cqlengine/tests/base.py | 12 +++++++----- 5 files changed, 23 insertions(+), 22 deletions(-) diff --git a/cqlengine/connection.py b/cqlengine/connection.py index 444400d5..8e465e8d 100644 --- a/cqlengine/connection.py +++ b/cqlengine/connection.py @@ -5,6 +5,7 @@ from collections import namedtuple from cassandra.cluster import Cluster, NoHostAvailable from cassandra.query import SimpleStatement, Statement +from cqlengine.exceptions import TransactionException import six try: @@ -88,7 +89,7 @@ def setup( def execute(query, params=None, consistency_level=None): handle_lazy_connect() - + statement = query if not session: raise CQLEngineException("It is required to setup() cqlengine before executing queries") @@ -111,6 +112,12 @@ def execute(query, params=None, consistency_level=None): params = params or {} result = session.execute(query, params) + if result and result[0].get('[applied]', True) is False: + result[0].pop('[applied]') + expected = ', '.join('{0}={1}'.format(t.field, t.value) for t in statement.transactions) + actual = ', '.join('{0}={1}'.format(f, v) for f, v in result[0].items()) + message = 'Transaction statement failed: Expected: {0} Actual: {1}'.format(expected, actual) + raise TransactionException(message) return result def get_session(): diff --git a/cqlengine/models.py b/cqlengine/models.py index 3598bcad..b60521a7 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -91,9 +91,10 @@ class TransactionDescriptor(object): return transaction_setter qs = model.__queryset__(model) - def transaction_setter(transaction): - qs._transaction = transaction - return instance + def transaction_setter(**unprepared_transactions): + transactions = model.objects.transaction(**unprepared_transactions)._transaction + qs._transaction = transactions + return qs return transaction_setter def __call__(self, *args, **kwargs): diff --git a/cqlengine/query.py b/cqlengine/query.py index 406d4161..a4f701a6 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -413,7 +413,7 @@ class AbstractQuerySet(object): try: column = self.model._get_column(col_name) except KeyError: - if col_name in ['exists', 'not_exists']: + if col_name == 'not_exists': exists = True elif col_name == 'pk__token': if not isinstance(val, Token): @@ -767,7 +767,8 @@ class ModelQuerySet(AbstractQuerySet): return nulled_columns = set() - us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp) + us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, + timestamp=self._timestamp, transactions=self._transaction) for name, val in values.items(): col_name, col_op = self._parse_filter_arg(name) col = self.model._columns.get(col_name) @@ -836,12 +837,6 @@ class DMLQuery(object): return self._batch.add_query(q) else: tmp = execute(q, consistency_level=self._consistency) - if tmp and tmp[0].get('[applied]', True) is False: - tmp[0].pop('[applied]') - expected = ', '.join('{0}={1}'.format(t.field, t.value) for t in q.transactions) - actual = ', '.join('{0}={1}'.format(f, v) for f, v in tmp[0].items()) - message = 'Transaction statement failed: Expected: {0} Actual: {1}'.format(expected, actual) - raise TransactionException(message) return tmp def batch(self, batch_obj): diff --git a/cqlengine/statements.py b/cqlengine/statements.py index e4a2aa5b..5618fa49 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -143,8 +143,6 @@ class TransactionClause(BaseClause): """ A single variable transaction statement """ def __unicode__(self): - if self.field == 'exists': - return u'EXISTS' if self.field == 'not_exists': return u'NOT EXISTS' return u'"{}" = %({})s'.format(self.field, self.context_id) @@ -153,9 +151,7 @@ class TransactionClause(BaseClause): return self.field, self.context_id def update_context(self, ctx): - if self.field not in ['exists', 'not_exists']: - return super(TransactionClause, self).update_context(ctx) - return ctx + return super(TransactionClause, self).update_context(ctx) class ContainerUpdateClause(AssignmentClause): diff --git a/cqlengine/tests/base.py b/cqlengine/tests/base.py index dea0cc19..90581e42 100644 --- a/cqlengine/tests/base.py +++ b/cqlengine/tests/base.py @@ -2,9 +2,9 @@ from unittest import TestCase import os import sys import six -from cqlengine.connection import get_session +from cqlengine import connection -CASSANDRA_VERSION = int(os.environ['CASSANDRA_VERSION']) +CASSANDRA_VERSION = 20 #int(os.environ['CASSANDRA_VERSION']) class BaseCassEngTestCase(TestCase): @@ -13,9 +13,11 @@ class BaseCassEngTestCase(TestCase): # super(BaseCassEngTestCase, cls).setUpClass() session = None - def setUp(self): - self.session = get_session() - super(BaseCassEngTestCase, self).setUp() + @classmethod + def setUpClass(cls): + connection.setup(['192.168.56.103'], 'test') + cls.session = connection.get_session() + super(BaseCassEngTestCase, cls).setUpClass() def assertHasAttr(self, obj, attr): self.assertTrue(hasattr(obj, attr), From de18d562962741818fa0591e8c08c0342406ba87 Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Thu, 16 Oct 2014 12:19:54 -0400 Subject: [PATCH 6/7] Reverted test base to original --- cqlengine/tests/base.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/cqlengine/tests/base.py b/cqlengine/tests/base.py index 90581e42..dea0cc19 100644 --- a/cqlengine/tests/base.py +++ b/cqlengine/tests/base.py @@ -2,9 +2,9 @@ from unittest import TestCase import os import sys import six -from cqlengine import connection +from cqlengine.connection import get_session -CASSANDRA_VERSION = 20 #int(os.environ['CASSANDRA_VERSION']) +CASSANDRA_VERSION = int(os.environ['CASSANDRA_VERSION']) class BaseCassEngTestCase(TestCase): @@ -13,11 +13,9 @@ class BaseCassEngTestCase(TestCase): # super(BaseCassEngTestCase, cls).setUpClass() session = None - @classmethod - def setUpClass(cls): - connection.setup(['192.168.56.103'], 'test') - cls.session = connection.get_session() - super(BaseCassEngTestCase, cls).setUpClass() + def setUp(self): + self.session = get_session() + super(BaseCassEngTestCase, self).setUp() def assertHasAttr(self, obj, attr): self.assertTrue(hasattr(obj, attr), From 281fb84a1551763ad68c7b11778e7bcdc67a6862 Mon Sep 17 00:00:00 2001 From: timmartin19 Date: Thu, 16 Oct 2014 13:32:02 -0400 Subject: [PATCH 7/7] Added tests for real. --- .gitignore | 3 - .../statements/test_transaction_statement.py | 30 +++ cqlengine/tests/test_ifnotexists.py | 178 ++++++++++++++++++ cqlengine/tests/test_transaction.py | 78 ++++++++ 4 files changed, 286 insertions(+), 3 deletions(-) create mode 100644 cqlengine/tests/statements/test_transaction_statement.py create mode 100644 cqlengine/tests/test_ifnotexists.py create mode 100644 cqlengine/tests/test_transaction.py diff --git a/.gitignore b/.gitignore index f5766784..b43bb3d1 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,3 @@ docs/_build #iPython *.ipynb - -cqlengine/tests/* -cqlengine/tests* \ No newline at end of file diff --git a/cqlengine/tests/statements/test_transaction_statement.py b/cqlengine/tests/statements/test_transaction_statement.py new file mode 100644 index 00000000..4ead43a3 --- /dev/null +++ b/cqlengine/tests/statements/test_transaction_statement.py @@ -0,0 +1,30 @@ +__author__ = 'Tim Martin' +from unittest import TestCase +from cqlengine.statements import TransactionClause +import six + +class TestTransactionClause(TestCase): + + def test_not_exists_clause(self): + tc = TransactionClause('not_exists', True) + + self.assertEqual('NOT EXISTS', six.text_type(tc)) + self.assertEqual('NOT EXISTS', str(tc)) + + def test_normal_transaction(self): + tc = TransactionClause('some_value', 23) + tc.set_context_id(3) + + self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) + self.assertEqual('"some_value" = %(3)s', str(tc)) + + def test_equality(self): + tc1 = TransactionClause('some_value', 5) + tc2 = TransactionClause('some_value', 5) + + assert tc1 == tc2 + + tc3 = TransactionClause('not_exists', True) + tc4 = TransactionClause('not_exists', True) + + assert tc3 == tc4 \ No newline at end of file diff --git a/cqlengine/tests/test_ifnotexists.py b/cqlengine/tests/test_ifnotexists.py new file mode 100644 index 00000000..202a1c03 --- /dev/null +++ b/cqlengine/tests/test_ifnotexists.py @@ -0,0 +1,178 @@ +from unittest import skipUnless +from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from cqlengine.exceptions import LWTException +from cqlengine import columns, BatchQuery +from uuid import uuid4 +import mock +from cqlengine.connection import get_cluster + +cluster = get_cluster() + + +class TestIfNotExistsModel(Model): + + __keyspace__ = 'cqlengine_test_lwt' + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class BaseIfNotExistsTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseIfNotExistsTest, cls).setUpClass() + """ + when receiving an insert statement with 'if not exist', cassandra would + perform a read with QUORUM level. Unittest would be failed if replica_factor + is 3 and one node only. Therefore I have create a new keyspace with + replica_factor:1. + """ + create_keyspace(TestIfNotExistsModel.__keyspace__, replication_factor=1) + sync_table(TestIfNotExistsModel) + + @classmethod + def tearDownClass(cls): + super(BaseCassEngTestCase, cls).tearDownClass() + drop_table(TestIfNotExistsModel) + delete_keyspace(TestIfNotExistsModel.__keyspace__) + + +class IfNotExistsInsertTests(BaseIfNotExistsTest): + + @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + def test_insert_if_not_exists_success(self): + """ tests that insertion with if_not_exists work as expected """ + + id = uuid4() + + TestIfNotExistsModel.create(id=id, count=8, text='123456789') + self.assertRaises( + LWTException, + TestIfNotExistsModel.if_not_exists().create, id=id, count=9, text='111111111111' + ) + + q = TestIfNotExistsModel.objects(id=id) + self.assertEqual(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 8) + self.assertEquals(tm.text, '123456789') + + def test_insert_if_not_exists_failure(self): + """ tests that insertion with if_not_exists failure """ + + id = uuid4() + + TestIfNotExistsModel.create(id=id, count=8, text='123456789') + TestIfNotExistsModel.create(id=id, count=9, text='111111111111') + + q = TestIfNotExistsModel.objects(id=id) + self.assertEquals(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 9) + self.assertEquals(tm.text, '111111111111') + + @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + def test_batch_insert_if_not_exists_success(self): + """ tests that batch insertion with if_not_exists work as expected """ + + id = uuid4() + + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=8, text='123456789') + + b = BatchQuery() + TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=9, text='111111111111') + self.assertRaises(LWTException, b.execute) + + q = TestIfNotExistsModel.objects(id=id) + self.assertEqual(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 8) + self.assertEquals(tm.text, '123456789') + + def test_batch_insert_if_not_exists_failure(self): + """ tests that batch insertion with if_not_exists failure """ + id = uuid4() + + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).create(id=id, count=8, text='123456789') + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).create(id=id, count=9, text='111111111111') + + q = TestIfNotExistsModel.objects(id=id) + self.assertEquals(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 9) + self.assertEquals(tm.text, '111111111111') + + +class IfNotExistsModelTest(BaseIfNotExistsTest): + + def test_if_not_exists_included_on_create(self): + """ tests that if_not_exists on models works as expected """ + + with mock.patch.object(self.session, 'execute') as m: + TestIfNotExistsModel.if_not_exists().create(count=8) + + query = m.call_args[0][0].query_string + self.assertIn("IF NOT EXISTS", query) + + def test_if_not_exists_included_on_save(self): + """ tests if we correctly put 'IF NOT EXISTS' for insert statement """ + + with mock.patch.object(self.session, 'execute') as m: + tm = TestIfNotExistsModel(count=8) + tm.if_not_exists(True).save() + + query = m.call_args[0][0].query_string + self.assertIn("IF NOT EXISTS", query) + + def test_queryset_is_returned_on_class(self): + """ ensure we get a queryset description back """ + qs = TestIfNotExistsModel.if_not_exists() + self.assertTrue(isinstance(qs, TestIfNotExistsModel.__queryset__), type(qs)) + + def test_batch_if_not_exists(self): + """ ensure 'IF NOT EXISTS' exists in statement when in batch """ + with mock.patch.object(self.session, 'execute') as m: + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).if_not_exists().create(count=8) + + self.assertIn("IF NOT EXISTS", m.call_args[0][0].query_string) + + +class IfNotExistsInstanceTest(BaseIfNotExistsTest): + + def test_instance_is_returned(self): + """ + ensures that we properly handle the instance.if_not_exists(True).save() + scenario + """ + o = TestIfNotExistsModel.create(text="whatever") + o.text = "new stuff" + o = o.if_not_exists(True) + self.assertEqual(True, o._if_not_exists) + + def test_if_not_exists_is_not_include_with_query_on_update(self): + """ + make sure we don't put 'IF NOT EXIST' in update statements + """ + o = TestIfNotExistsModel.create(text="whatever") + o.text = "new stuff" + o = o.if_not_exists(True) + + with mock.patch.object(self.session, 'execute') as m: + o.save() + + query = m.call_args[0][0].query_string + self.assertNotIn("IF NOT EXIST", query) + + diff --git a/cqlengine/tests/test_transaction.py b/cqlengine/tests/test_transaction.py new file mode 100644 index 00000000..9dbe9632 --- /dev/null +++ b/cqlengine/tests/test_transaction.py @@ -0,0 +1,78 @@ +__author__ = 'Tim Martin' +from cqlengine.management import sync_table, drop_table +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from cqlengine.exceptions import TransactionException +from uuid import uuid4 +from cqlengine import columns +import mock +from cqlengine import ALL, BatchQuery + + +class TestTransactionModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class TestTransaction(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestTransaction, cls).setUpClass() + sync_table(TestTransactionModel) + + @classmethod + def tearDownClass(cls): + super(TestTransaction, cls).tearDownClass() + drop_table(TestTransactionModel) + + def test_create_uses_transaction(self): + qs = TestTransactionModel.transaction(not_exists=True) + with mock.patch.object(self.session, 'execute') as m: + qs.create(text='blah blah', count=2) + args = m.call_args + self.assertIn('IF NOT EXISTS', args[0][0].query_string) + + def test_queryset_returned_on_create(self): + qs = TestTransactionModel.transaction(not_exists=True) + self.assertTrue(isinstance(qs, TestTransactionModel.__queryset__), type(qs)) + + def test_update_using_transaction(self): + t = TestTransactionModel.create(text='blah blah') + t.text = 'new blah' + with mock.patch.object(self.session, 'execute') as m: + t.transaction(text='blah blah').save() + + args = m.call_args + self.assertIn('IF "text" = %(0)s', args[0][0].query_string) + + def test_update_failure(self): + t = TestTransactionModel.create(text='blah blah') + t.text = 'new blah' + t = t.transaction(text='something wrong') + self.assertRaises(TransactionException, t.save) + + def test_creation_failure(self): + t = TestTransactionModel.create(text='blah blah') + t_clone = TestTransactionModel.transaction(not_exists=True) + self.assertRaises(TransactionException, t_clone.create, id=t.id, count=t.count, text=t.text) + + def test_blind_update(self): + t = TestTransactionModel.create(text='blah blah') + t.text = 'something else' + uid = t.id + + with mock.patch.object(self.session, 'execute') as m: + TestTransactionModel.objects(id=uid).transaction(text='blah blah').update(text='oh hey der') + + args = m.call_args + self.assertIn('IF "text" = %(1)s', args[0][0].query_string) + + def test_blind_update_fail(self): + t = TestTransactionModel.create(text='blah blah') + t.text = 'something else' + uid = t.id + qs = TestTransactionModel.objects(id=uid).transaction(text='Not dis!') + self.assertRaises(TransactionException, qs.update, text='this will never work') \ No newline at end of file