Merge pull request #267 from mission-liao/if_not_exist_insert
If not exist insert
This commit is contained in:
		@@ -4,3 +4,4 @@ class ModelException(CQLEngineException): pass
 | 
			
		||||
class ValidationError(CQLEngineException): pass
 | 
			
		||||
 | 
			
		||||
class UndefinedKeyspaceException(CQLEngineException): pass
 | 
			
		||||
class LWTException(CQLEngineException): pass
 | 
			
		||||
 
 | 
			
		||||
@@ -116,6 +116,23 @@ class TimestampDescriptor(object):
 | 
			
		||||
    def __call__(self, *args, **kwargs):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
class IfNotExistsDescriptor(object):
 | 
			
		||||
    """
 | 
			
		||||
    return a query set descriptor with a if_not_exists flag specified
 | 
			
		||||
    """
 | 
			
		||||
    def __get__(self, instance, model):
 | 
			
		||||
        if instance:
 | 
			
		||||
            # instance method
 | 
			
		||||
            def ifnotexists_setter(ife):
 | 
			
		||||
                instance._if_not_exists = ife
 | 
			
		||||
                return instance
 | 
			
		||||
            return ifnotexists_setter
 | 
			
		||||
 | 
			
		||||
        return model.objects.if_not_exists
 | 
			
		||||
 | 
			
		||||
    def __call__(self, *args, **kwargs):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
class ConsistencyDescriptor(object):
 | 
			
		||||
    """
 | 
			
		||||
    returns a query set descriptor if called on Class, instance if it was an instance call
 | 
			
		||||
@@ -225,6 +242,8 @@ class BaseModel(object):
 | 
			
		||||
 | 
			
		||||
    # custom timestamps, see USING TIMESTAMP X
 | 
			
		||||
    timestamp = TimestampDescriptor()
 | 
			
		||||
    
 | 
			
		||||
    if_not_exists = IfNotExistsDescriptor()
 | 
			
		||||
 | 
			
		||||
    # _len is lazily created by __len__
 | 
			
		||||
 | 
			
		||||
@@ -276,6 +295,8 @@ class BaseModel(object):
 | 
			
		||||
 | 
			
		||||
    _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP)
 | 
			
		||||
 | 
			
		||||
    _if_not_exists = False # optional if_not_exists flag to check existence before insertion
 | 
			
		||||
 | 
			
		||||
    def __init__(self, **values):
 | 
			
		||||
        self._values = {}
 | 
			
		||||
        self._ttl = self.__default_ttl__
 | 
			
		||||
@@ -528,7 +549,8 @@ class BaseModel(object):
 | 
			
		||||
                          batch=self._batch,
 | 
			
		||||
                          ttl=self._ttl,
 | 
			
		||||
                          timestamp=self._timestamp,
 | 
			
		||||
                          consistency=self.__consistency__).save()
 | 
			
		||||
                          consistency=self.__consistency__,
 | 
			
		||||
                          if_not_exists=self._if_not_exists).save()
 | 
			
		||||
 | 
			
		||||
        #reset the value managers
 | 
			
		||||
        for v in self._values.values():
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,7 @@ from cqlengine.columns import Counter, List, Set
 | 
			
		||||
 | 
			
		||||
from cqlengine.connection import execute
 | 
			
		||||
 | 
			
		||||
from cqlengine.exceptions import CQLEngineException, ValidationError
 | 
			
		||||
from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException
 | 
			
		||||
from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin
 | 
			
		||||
 | 
			
		||||
#CQL 3 reference:
 | 
			
		||||
@@ -22,6 +22,17 @@ class MultipleObjectsReturned(QueryException): pass
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_applied(result):
 | 
			
		||||
    """
 | 
			
		||||
    check if result contains some column '[applied]' with false value,
 | 
			
		||||
    if that value is false, it means our light-weight transaction didn't
 | 
			
		||||
    applied to database.
 | 
			
		||||
    """
 | 
			
		||||
    if result and '[applied]' in result[0] and result[0]['[applied]'] == False:
 | 
			
		||||
        raise LWTException('')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AbstractQueryableColumn(UnicodeMixin):
 | 
			
		||||
    """
 | 
			
		||||
    exposes cql query operators through pythons
 | 
			
		||||
@@ -171,7 +182,8 @@ class BatchQuery(object):
 | 
			
		||||
 | 
			
		||||
        query_list.append('APPLY BATCH;')
 | 
			
		||||
 | 
			
		||||
        execute('\n'.join(query_list), parameters, self._consistency)
 | 
			
		||||
        tmp = execute('\n'.join(query_list), parameters, self._consistency)
 | 
			
		||||
        check_applied(tmp)
 | 
			
		||||
 | 
			
		||||
        self.queries = []
 | 
			
		||||
        self._execute_callbacks()
 | 
			
		||||
@@ -220,6 +232,7 @@ class AbstractQuerySet(object):
 | 
			
		||||
        self._ttl = getattr(model, '__default_ttl__', None)
 | 
			
		||||
        self._consistency = None
 | 
			
		||||
        self._timestamp = None
 | 
			
		||||
        self._if_not_exists = False
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def column_family_name(self):
 | 
			
		||||
@@ -569,7 +582,7 @@ class AbstractQuerySet(object):
 | 
			
		||||
 | 
			
		||||
    def create(self, **kwargs):
 | 
			
		||||
        return self.model(**kwargs).batch(self._batch).ttl(self._ttl).\
 | 
			
		||||
            consistency(self._consistency).\
 | 
			
		||||
            consistency(self._consistency).if_not_exists(self._if_not_exists).\
 | 
			
		||||
            timestamp(self._timestamp).save()
 | 
			
		||||
 | 
			
		||||
    def delete(self):
 | 
			
		||||
@@ -708,6 +721,11 @@ class ModelQuerySet(AbstractQuerySet):
 | 
			
		||||
        clone._timestamp = timestamp
 | 
			
		||||
        return clone
 | 
			
		||||
 | 
			
		||||
    def if_not_exists(self):
 | 
			
		||||
        clone = copy.deepcopy(self)
 | 
			
		||||
        clone._if_not_exists = True
 | 
			
		||||
        return clone
 | 
			
		||||
 | 
			
		||||
    def update(self, **values):
 | 
			
		||||
        """ Updates the rows in this queryset """
 | 
			
		||||
        if not values:
 | 
			
		||||
@@ -767,8 +785,9 @@ class DMLQuery(object):
 | 
			
		||||
    _ttl = None
 | 
			
		||||
    _consistency = None
 | 
			
		||||
    _timestamp = None
 | 
			
		||||
    _if_not_exists = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None):
 | 
			
		||||
    def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
 | 
			
		||||
        self.model = model
 | 
			
		||||
        self.column_family_name = self.model.column_family_name()
 | 
			
		||||
        self.instance = instance
 | 
			
		||||
@@ -776,12 +795,16 @@ class DMLQuery(object):
 | 
			
		||||
        self._ttl = ttl
 | 
			
		||||
        self._consistency = consistency
 | 
			
		||||
        self._timestamp = timestamp
 | 
			
		||||
        self._if_not_exists = if_not_exists
 | 
			
		||||
 | 
			
		||||
    def _execute(self, q):
 | 
			
		||||
        if self._batch:
 | 
			
		||||
            return self._batch.add_query(q)
 | 
			
		||||
        else:
 | 
			
		||||
            tmp = execute(q, consistency_level=self._consistency)
 | 
			
		||||
            if self._if_not_exists:
 | 
			
		||||
                check_applied(tmp)
 | 
			
		||||
 
 | 
			
		||||
            return tmp
 | 
			
		||||
 | 
			
		||||
    def batch(self, batch_obj):
 | 
			
		||||
@@ -890,7 +913,7 @@ class DMLQuery(object):
 | 
			
		||||
        if self.instance._has_counter or self.instance._can_update():
 | 
			
		||||
            return self.update()
 | 
			
		||||
        else:
 | 
			
		||||
            insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp)
 | 
			
		||||
            insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
 | 
			
		||||
            for name, col in self.instance._columns.items():
 | 
			
		||||
                val = getattr(self.instance, name, None)
 | 
			
		||||
                if col._val_is_null(val):
 | 
			
		||||
@@ -906,7 +929,6 @@ class DMLQuery(object):
 | 
			
		||||
        # caused by pointless update queries
 | 
			
		||||
        if not insert.is_empty:
 | 
			
		||||
            self._execute(insert)
 | 
			
		||||
 | 
			
		||||
        # delete any nulled columns
 | 
			
		||||
        self._delete_null_columns()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -619,6 +619,24 @@ class AssignmentStatement(BaseCQLStatement):
 | 
			
		||||
class InsertStatement(AssignmentStatement):
 | 
			
		||||
    """ an cql insert select statement """
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 table,
 | 
			
		||||
                 assignments=None,
 | 
			
		||||
                 consistency=None,
 | 
			
		||||
                 where=None,
 | 
			
		||||
                 ttl=None,
 | 
			
		||||
                 timestamp=None,
 | 
			
		||||
                 if_not_exists=False):
 | 
			
		||||
        super(InsertStatement, self).__init__(
 | 
			
		||||
                table,
 | 
			
		||||
                assignments=assignments,
 | 
			
		||||
                consistency=consistency,
 | 
			
		||||
                where=where,
 | 
			
		||||
                ttl=ttl,
 | 
			
		||||
                timestamp=timestamp)
 | 
			
		||||
 | 
			
		||||
        self.if_not_exists = if_not_exists
 | 
			
		||||
 | 
			
		||||
    def add_where_clause(self, clause):
 | 
			
		||||
        raise StatementException("Cannot add where clauses to insert statements")
 | 
			
		||||
 | 
			
		||||
@@ -633,6 +651,9 @@ class InsertStatement(AssignmentStatement):
 | 
			
		||||
        qs += ['VALUES']
 | 
			
		||||
        qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))]
 | 
			
		||||
 | 
			
		||||
        if self.if_not_exists:
 | 
			
		||||
            qs += ["IF NOT EXISTS"]
 | 
			
		||||
 | 
			
		||||
        if self.ttl:
 | 
			
		||||
            qs += ["USING TTL {}".format(self.ttl)]
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										178
									
								
								cqlengine/tests/test_ifnotexists.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								cqlengine/tests/test_ifnotexists.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,178 @@
 | 
			
		||||
from unittest import skipUnless
 | 
			
		||||
from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace
 | 
			
		||||
from cqlengine.tests.base import BaseCassEngTestCase
 | 
			
		||||
from cqlengine.models import Model
 | 
			
		||||
from cqlengine.exceptions import LWTException
 | 
			
		||||
from cqlengine import columns, BatchQuery
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
import mock
 | 
			
		||||
from cqlengine.connection import get_cluster
 | 
			
		||||
 | 
			
		||||
cluster = get_cluster()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestIfNotExistsModel(Model):
 | 
			
		||||
 | 
			
		||||
    __keyspace__ = 'cqlengine_test_lwt'
 | 
			
		||||
 | 
			
		||||
    id      = columns.UUID(primary_key=True, default=lambda:uuid4())
 | 
			
		||||
    count   = columns.Integer()
 | 
			
		||||
    text    = columns.Text(required=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseIfNotExistsTest(BaseCassEngTestCase):
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def setUpClass(cls):
 | 
			
		||||
        super(BaseIfNotExistsTest, cls).setUpClass()
 | 
			
		||||
        """
 | 
			
		||||
        when receiving an insert statement with 'if not exist', cassandra would
 | 
			
		||||
        perform a read with QUORUM level. Unittest would be failed if replica_factor
 | 
			
		||||
        is 3 and one node only. Therefore I have create a new keyspace with
 | 
			
		||||
        replica_factor:1.
 | 
			
		||||
        """
 | 
			
		||||
        create_keyspace(TestIfNotExistsModel.__keyspace__, replication_factor=1)
 | 
			
		||||
        sync_table(TestIfNotExistsModel)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def tearDownClass(cls):
 | 
			
		||||
        super(BaseCassEngTestCase, cls).tearDownClass()
 | 
			
		||||
        drop_table(TestIfNotExistsModel)
 | 
			
		||||
        delete_keyspace(TestIfNotExistsModel.__keyspace__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IfNotExistsInsertTests(BaseIfNotExistsTest):
 | 
			
		||||
 | 
			
		||||
    @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
 | 
			
		||||
    def test_insert_if_not_exists_success(self):
 | 
			
		||||
        """ tests that insertion with if_not_exists work as expected """
 | 
			
		||||
 | 
			
		||||
        id = uuid4()
 | 
			
		||||
 | 
			
		||||
        TestIfNotExistsModel.create(id=id, count=8, text='123456789')
 | 
			
		||||
        self.assertRaises(
 | 
			
		||||
            LWTException,
 | 
			
		||||
            TestIfNotExistsModel.if_not_exists().create, id=id, count=9, text='111111111111'
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        q = TestIfNotExistsModel.objects(id=id)
 | 
			
		||||
        self.assertEqual(len(q), 1)
 | 
			
		||||
 | 
			
		||||
        tm = q.first()
 | 
			
		||||
        self.assertEquals(tm.count, 8)
 | 
			
		||||
        self.assertEquals(tm.text, '123456789')
 | 
			
		||||
 | 
			
		||||
    def test_insert_if_not_exists_failure(self):
 | 
			
		||||
        """ tests that insertion with if_not_exists failure """
 | 
			
		||||
 | 
			
		||||
        id = uuid4()
 | 
			
		||||
 | 
			
		||||
        TestIfNotExistsModel.create(id=id, count=8, text='123456789')
 | 
			
		||||
        TestIfNotExistsModel.create(id=id, count=9, text='111111111111')
 | 
			
		||||
 | 
			
		||||
        q = TestIfNotExistsModel.objects(id=id)
 | 
			
		||||
        self.assertEquals(len(q), 1)
 | 
			
		||||
 | 
			
		||||
        tm = q.first()
 | 
			
		||||
        self.assertEquals(tm.count, 9)
 | 
			
		||||
        self.assertEquals(tm.text, '111111111111')
 | 
			
		||||
 | 
			
		||||
    @skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
 | 
			
		||||
    def test_batch_insert_if_not_exists_success(self):
 | 
			
		||||
        """ tests that batch insertion with if_not_exists work as expected """
 | 
			
		||||
 | 
			
		||||
        id = uuid4()
 | 
			
		||||
 | 
			
		||||
        with BatchQuery() as b:
 | 
			
		||||
            TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=8, text='123456789')
 | 
			
		||||
 | 
			
		||||
        b = BatchQuery()
 | 
			
		||||
        TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=9, text='111111111111')
 | 
			
		||||
        self.assertRaises(LWTException, b.execute)
 | 
			
		||||
 | 
			
		||||
        q = TestIfNotExistsModel.objects(id=id)
 | 
			
		||||
        self.assertEqual(len(q), 1)
 | 
			
		||||
 | 
			
		||||
        tm = q.first()
 | 
			
		||||
        self.assertEquals(tm.count, 8)
 | 
			
		||||
        self.assertEquals(tm.text, '123456789')
 | 
			
		||||
 | 
			
		||||
    def test_batch_insert_if_not_exists_failure(self):
 | 
			
		||||
        """ tests that batch insertion with if_not_exists failure """
 | 
			
		||||
        id = uuid4()
 | 
			
		||||
 | 
			
		||||
        with BatchQuery() as b:
 | 
			
		||||
            TestIfNotExistsModel.batch(b).create(id=id, count=8, text='123456789')
 | 
			
		||||
        with BatchQuery() as b:
 | 
			
		||||
            TestIfNotExistsModel.batch(b).create(id=id, count=9, text='111111111111')
 | 
			
		||||
 | 
			
		||||
        q = TestIfNotExistsModel.objects(id=id)
 | 
			
		||||
        self.assertEquals(len(q), 1)
 | 
			
		||||
 | 
			
		||||
        tm = q.first()
 | 
			
		||||
        self.assertEquals(tm.count, 9)
 | 
			
		||||
        self.assertEquals(tm.text, '111111111111')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IfNotExistsModelTest(BaseIfNotExistsTest):
 | 
			
		||||
 | 
			
		||||
    def test_if_not_exists_included_on_create(self):
 | 
			
		||||
        """ tests that if_not_exists on models works as expected """
 | 
			
		||||
 | 
			
		||||
        with mock.patch.object(self.session, 'execute') as m:
 | 
			
		||||
            TestIfNotExistsModel.if_not_exists().create(count=8)
 | 
			
		||||
 | 
			
		||||
        query = m.call_args[0][0].query_string
 | 
			
		||||
        self.assertIn("IF NOT EXISTS", query)
 | 
			
		||||
 | 
			
		||||
    def test_if_not_exists_included_on_save(self):
 | 
			
		||||
        """ tests if we correctly put 'IF NOT EXISTS' for insert statement """
 | 
			
		||||
 | 
			
		||||
        with mock.patch.object(self.session, 'execute') as m:
 | 
			
		||||
            tm = TestIfNotExistsModel(count=8)
 | 
			
		||||
            tm.if_not_exists(True).save()
 | 
			
		||||
 | 
			
		||||
        query = m.call_args[0][0].query_string
 | 
			
		||||
        self.assertIn("IF NOT EXISTS", query)
 | 
			
		||||
 | 
			
		||||
    def test_queryset_is_returned_on_class(self):
 | 
			
		||||
        """ ensure we get a queryset description back """
 | 
			
		||||
        qs = TestIfNotExistsModel.if_not_exists()
 | 
			
		||||
        self.assertTrue(isinstance(qs, TestIfNotExistsModel.__queryset__), type(qs))
 | 
			
		||||
 | 
			
		||||
    def test_batch_if_not_exists(self):
 | 
			
		||||
        """ ensure 'IF NOT EXISTS' exists in statement when in batch """ 
 | 
			
		||||
        with mock.patch.object(self.session, 'execute') as m:
 | 
			
		||||
            with BatchQuery() as b:
 | 
			
		||||
                TestIfNotExistsModel.batch(b).if_not_exists().create(count=8)
 | 
			
		||||
 | 
			
		||||
        self.assertIn("IF NOT EXISTS", m.call_args[0][0].query_string)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IfNotExistsInstanceTest(BaseIfNotExistsTest):
 | 
			
		||||
 | 
			
		||||
    def test_instance_is_returned(self):
 | 
			
		||||
        """
 | 
			
		||||
        ensures that we properly handle the instance.if_not_exists(True).save()
 | 
			
		||||
        scenario
 | 
			
		||||
        """
 | 
			
		||||
        o = TestIfNotExistsModel.create(text="whatever")
 | 
			
		||||
        o.text = "new stuff"
 | 
			
		||||
        o = o.if_not_exists(True)
 | 
			
		||||
        self.assertEqual(True, o._if_not_exists)
 | 
			
		||||
 | 
			
		||||
    def test_if_not_exists_is_not_include_with_query_on_update(self):
 | 
			
		||||
        """
 | 
			
		||||
        make sure we don't put 'IF NOT EXIST' in update statements
 | 
			
		||||
        """
 | 
			
		||||
        o = TestIfNotExistsModel.create(text="whatever")
 | 
			
		||||
        o.text = "new stuff"
 | 
			
		||||
        o = o.if_not_exists(True)
 | 
			
		||||
 | 
			
		||||
        with mock.patch.object(self.session, 'execute') as m:
 | 
			
		||||
            o.save()
 | 
			
		||||
 | 
			
		||||
        query = m.call_args[0][0].query_string
 | 
			
		||||
        self.assertNotIn("IF NOT EXIST", query)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -188,6 +188,25 @@ Model Methods
 | 
			
		||||
 | 
			
		||||
        Sets the ttl values to run instance updates and inserts queries with.
 | 
			
		||||
 | 
			
		||||
    .. method:: if_not_exists()
 | 
			
		||||
 | 
			
		||||
        Check the existence of an object before insertion. The existence of an
 | 
			
		||||
        object is determined by its primary key(s). And please note using this flag
 | 
			
		||||
        would incur performance cost.
 | 
			
		||||
 | 
			
		||||
        if the insertion didn't applied, a LWTException exception would be raised.
 | 
			
		||||
 | 
			
		||||
        *Example*
 | 
			
		||||
 | 
			
		||||
        .. code-block:: python
 | 
			
		||||
            try:
 | 
			
		||||
                TestIfNotExistsModel.if_not_exists().create(id=id, count=9, text='111111111111')
 | 
			
		||||
            except LWTException as e:
 | 
			
		||||
                # handle failure case
 | 
			
		||||
                print e.existing # existing object
 | 
			
		||||
 | 
			
		||||
        This method is supported on Cassandra 2.0 or later.
 | 
			
		||||
 | 
			
		||||
    .. method:: update(**values)
 | 
			
		||||
 | 
			
		||||
        Performs an update on the model instance. You can pass in values to set on the model
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user