diff --git a/cqlengine/exceptions.py b/cqlengine/exceptions.py index 94a51b92..f40f8af7 100644 --- a/cqlengine/exceptions.py +++ b/cqlengine/exceptions.py @@ -4,3 +4,4 @@ class ModelException(CQLEngineException): pass class ValidationError(CQLEngineException): pass class UndefinedKeyspaceException(CQLEngineException): pass +class LWTException(CQLEngineException): pass diff --git a/cqlengine/models.py b/cqlengine/models.py index d725e42c..77812426 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -116,6 +116,23 @@ class TimestampDescriptor(object): def __call__(self, *args, **kwargs): raise NotImplementedError +class IfNotExistsDescriptor(object): + """ + return a query set descriptor with a if_not_exists flag specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def ifnotexists_setter(ife): + instance._if_not_exists = ife + return instance + return ifnotexists_setter + + return model.objects.if_not_exists + + def __call__(self, *args, **kwargs): + raise NotImplementedError + class ConsistencyDescriptor(object): """ returns a query set descriptor if called on Class, instance if it was an instance call @@ -225,6 +242,8 @@ class BaseModel(object): # custom timestamps, see USING TIMESTAMP X timestamp = TimestampDescriptor() + + if_not_exists = IfNotExistsDescriptor() # _len is lazily created by __len__ @@ -276,6 +295,8 @@ class BaseModel(object): _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) + _if_not_exists = False # optional if_not_exists flag to check existence before insertion + def __init__(self, **values): self._values = {} self._ttl = self.__default_ttl__ @@ -528,7 +549,8 @@ class BaseModel(object): batch=self._batch, ttl=self._ttl, timestamp=self._timestamp, - consistency=self.__consistency__).save() + consistency=self.__consistency__, + if_not_exists=self._if_not_exists).save() #reset the value managers for v in self._values.values(): diff --git a/cqlengine/query.py b/cqlengine/query.py index de929c76..fd2f0072 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -6,7 +6,7 @@ from cqlengine.columns import Counter, List, Set from cqlengine.connection import execute -from cqlengine.exceptions import CQLEngineException, ValidationError +from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin #CQL 3 reference: @@ -22,6 +22,17 @@ class MultipleObjectsReturned(QueryException): pass import six + +def check_applied(result): + """ + check if result contains some column '[applied]' with false value, + if that value is false, it means our light-weight transaction didn't + applied to database. + """ + if result and '[applied]' in result[0] and result[0]['[applied]'] == False: + raise LWTException('') + + class AbstractQueryableColumn(UnicodeMixin): """ exposes cql query operators through pythons @@ -171,7 +182,8 @@ class BatchQuery(object): query_list.append('APPLY BATCH;') - execute('\n'.join(query_list), parameters, self._consistency) + tmp = execute('\n'.join(query_list), parameters, self._consistency) + check_applied(tmp) self.queries = [] self._execute_callbacks() @@ -220,6 +232,7 @@ class AbstractQuerySet(object): self._ttl = getattr(model, '__default_ttl__', None) self._consistency = None self._timestamp = None + self._if_not_exists = False @property def column_family_name(self): @@ -569,7 +582,7 @@ class AbstractQuerySet(object): def create(self, **kwargs): return self.model(**kwargs).batch(self._batch).ttl(self._ttl).\ - consistency(self._consistency).\ + consistency(self._consistency).if_not_exists(self._if_not_exists).\ timestamp(self._timestamp).save() def delete(self): @@ -708,6 +721,11 @@ class ModelQuerySet(AbstractQuerySet): clone._timestamp = timestamp return clone + def if_not_exists(self): + clone = copy.deepcopy(self) + clone._if_not_exists = True + return clone + def update(self, **values): """ Updates the rows in this queryset """ if not values: @@ -767,8 +785,9 @@ class DMLQuery(object): _ttl = None _consistency = None _timestamp = None + _if_not_exists = False - def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None): + def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False): self.model = model self.column_family_name = self.model.column_family_name() self.instance = instance @@ -776,12 +795,16 @@ class DMLQuery(object): self._ttl = ttl self._consistency = consistency self._timestamp = timestamp + self._if_not_exists = if_not_exists def _execute(self, q): if self._batch: return self._batch.add_query(q) else: tmp = execute(q, consistency_level=self._consistency) + if self._if_not_exists: + check_applied(tmp) + return tmp def batch(self, batch_obj): @@ -890,7 +913,7 @@ class DMLQuery(object): if self.instance._has_counter or self.instance._can_update(): return self.update() else: - insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) + insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) for name, col in self.instance._columns.items(): val = getattr(self.instance, name, None) if col._val_is_null(val): @@ -906,7 +929,6 @@ class DMLQuery(object): # caused by pointless update queries if not insert.is_empty: self._execute(insert) - # delete any nulled columns self._delete_null_columns() diff --git a/cqlengine/statements.py b/cqlengine/statements.py index f1681a16..a6cdf4b3 100644 --- a/cqlengine/statements.py +++ b/cqlengine/statements.py @@ -619,6 +619,24 @@ class AssignmentStatement(BaseCQLStatement): class InsertStatement(AssignmentStatement): """ an cql insert select statement """ + def __init__(self, + table, + assignments=None, + consistency=None, + where=None, + ttl=None, + timestamp=None, + if_not_exists=False): + super(InsertStatement, self).__init__( + table, + assignments=assignments, + consistency=consistency, + where=where, + ttl=ttl, + timestamp=timestamp) + + self.if_not_exists = if_not_exists + def add_where_clause(self, clause): raise StatementException("Cannot add where clauses to insert statements") @@ -633,6 +651,9 @@ class InsertStatement(AssignmentStatement): qs += ['VALUES'] qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))] + if self.if_not_exists: + qs += ["IF NOT EXISTS"] + if self.ttl: qs += ["USING TTL {}".format(self.ttl)] diff --git a/cqlengine/tests/test_ifnotexists.py b/cqlengine/tests/test_ifnotexists.py new file mode 100644 index 00000000..202a1c03 --- /dev/null +++ b/cqlengine/tests/test_ifnotexists.py @@ -0,0 +1,178 @@ +from unittest import skipUnless +from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from cqlengine.exceptions import LWTException +from cqlengine import columns, BatchQuery +from uuid import uuid4 +import mock +from cqlengine.connection import get_cluster + +cluster = get_cluster() + + +class TestIfNotExistsModel(Model): + + __keyspace__ = 'cqlengine_test_lwt' + + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class BaseIfNotExistsTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseIfNotExistsTest, cls).setUpClass() + """ + when receiving an insert statement with 'if not exist', cassandra would + perform a read with QUORUM level. Unittest would be failed if replica_factor + is 3 and one node only. Therefore I have create a new keyspace with + replica_factor:1. + """ + create_keyspace(TestIfNotExistsModel.__keyspace__, replication_factor=1) + sync_table(TestIfNotExistsModel) + + @classmethod + def tearDownClass(cls): + super(BaseCassEngTestCase, cls).tearDownClass() + drop_table(TestIfNotExistsModel) + delete_keyspace(TestIfNotExistsModel.__keyspace__) + + +class IfNotExistsInsertTests(BaseIfNotExistsTest): + + @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + def test_insert_if_not_exists_success(self): + """ tests that insertion with if_not_exists work as expected """ + + id = uuid4() + + TestIfNotExistsModel.create(id=id, count=8, text='123456789') + self.assertRaises( + LWTException, + TestIfNotExistsModel.if_not_exists().create, id=id, count=9, text='111111111111' + ) + + q = TestIfNotExistsModel.objects(id=id) + self.assertEqual(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 8) + self.assertEquals(tm.text, '123456789') + + def test_insert_if_not_exists_failure(self): + """ tests that insertion with if_not_exists failure """ + + id = uuid4() + + TestIfNotExistsModel.create(id=id, count=8, text='123456789') + TestIfNotExistsModel.create(id=id, count=9, text='111111111111') + + q = TestIfNotExistsModel.objects(id=id) + self.assertEquals(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 9) + self.assertEquals(tm.text, '111111111111') + + @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0") + def test_batch_insert_if_not_exists_success(self): + """ tests that batch insertion with if_not_exists work as expected """ + + id = uuid4() + + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=8, text='123456789') + + b = BatchQuery() + TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=9, text='111111111111') + self.assertRaises(LWTException, b.execute) + + q = TestIfNotExistsModel.objects(id=id) + self.assertEqual(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 8) + self.assertEquals(tm.text, '123456789') + + def test_batch_insert_if_not_exists_failure(self): + """ tests that batch insertion with if_not_exists failure """ + id = uuid4() + + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).create(id=id, count=8, text='123456789') + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).create(id=id, count=9, text='111111111111') + + q = TestIfNotExistsModel.objects(id=id) + self.assertEquals(len(q), 1) + + tm = q.first() + self.assertEquals(tm.count, 9) + self.assertEquals(tm.text, '111111111111') + + +class IfNotExistsModelTest(BaseIfNotExistsTest): + + def test_if_not_exists_included_on_create(self): + """ tests that if_not_exists on models works as expected """ + + with mock.patch.object(self.session, 'execute') as m: + TestIfNotExistsModel.if_not_exists().create(count=8) + + query = m.call_args[0][0].query_string + self.assertIn("IF NOT EXISTS", query) + + def test_if_not_exists_included_on_save(self): + """ tests if we correctly put 'IF NOT EXISTS' for insert statement """ + + with mock.patch.object(self.session, 'execute') as m: + tm = TestIfNotExistsModel(count=8) + tm.if_not_exists(True).save() + + query = m.call_args[0][0].query_string + self.assertIn("IF NOT EXISTS", query) + + def test_queryset_is_returned_on_class(self): + """ ensure we get a queryset description back """ + qs = TestIfNotExistsModel.if_not_exists() + self.assertTrue(isinstance(qs, TestIfNotExistsModel.__queryset__), type(qs)) + + def test_batch_if_not_exists(self): + """ ensure 'IF NOT EXISTS' exists in statement when in batch """ + with mock.patch.object(self.session, 'execute') as m: + with BatchQuery() as b: + TestIfNotExistsModel.batch(b).if_not_exists().create(count=8) + + self.assertIn("IF NOT EXISTS", m.call_args[0][0].query_string) + + +class IfNotExistsInstanceTest(BaseIfNotExistsTest): + + def test_instance_is_returned(self): + """ + ensures that we properly handle the instance.if_not_exists(True).save() + scenario + """ + o = TestIfNotExistsModel.create(text="whatever") + o.text = "new stuff" + o = o.if_not_exists(True) + self.assertEqual(True, o._if_not_exists) + + def test_if_not_exists_is_not_include_with_query_on_update(self): + """ + make sure we don't put 'IF NOT EXIST' in update statements + """ + o = TestIfNotExistsModel.create(text="whatever") + o.text = "new stuff" + o = o.if_not_exists(True) + + with mock.patch.object(self.session, 'execute') as m: + o.save() + + query = m.call_args[0][0].query_string + self.assertNotIn("IF NOT EXIST", query) + + diff --git a/docs/topics/models.rst b/docs/topics/models.rst index f2e328d2..618838ec 100644 --- a/docs/topics/models.rst +++ b/docs/topics/models.rst @@ -188,6 +188,25 @@ Model Methods Sets the ttl values to run instance updates and inserts queries with. + .. method:: if_not_exists() + + Check the existence of an object before insertion. The existence of an + object is determined by its primary key(s). And please note using this flag + would incur performance cost. + + if the insertion didn't applied, a LWTException exception would be raised. + + *Example* + + .. code-block:: python + try: + TestIfNotExistsModel.if_not_exists().create(id=id, count=9, text='111111111111') + except LWTException as e: + # handle failure case + print e.existing # existing object + + This method is supported on Cassandra 2.0 or later. + .. method:: update(**values) Performs an update on the model instance. You can pass in values to set on the model