Merge pull request #267 from mission-liao/if_not_exist_insert

If not exist insert
This commit is contained in:
Jon Haddad
2014-10-07 14:19:46 -07:00
6 changed files with 270 additions and 7 deletions

View File

@@ -4,3 +4,4 @@ class ModelException(CQLEngineException): pass
class ValidationError(CQLEngineException): pass class ValidationError(CQLEngineException): pass
class UndefinedKeyspaceException(CQLEngineException): pass class UndefinedKeyspaceException(CQLEngineException): pass
class LWTException(CQLEngineException): pass

View File

@@ -116,6 +116,23 @@ class TimestampDescriptor(object):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
class IfNotExistsDescriptor(object):
"""
return a query set descriptor with a if_not_exists flag specified
"""
def __get__(self, instance, model):
if instance:
# instance method
def ifnotexists_setter(ife):
instance._if_not_exists = ife
return instance
return ifnotexists_setter
return model.objects.if_not_exists
def __call__(self, *args, **kwargs):
raise NotImplementedError
class ConsistencyDescriptor(object): class ConsistencyDescriptor(object):
""" """
returns a query set descriptor if called on Class, instance if it was an instance call returns a query set descriptor if called on Class, instance if it was an instance call
@@ -226,6 +243,8 @@ class BaseModel(object):
# custom timestamps, see USING TIMESTAMP X # custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor() timestamp = TimestampDescriptor()
if_not_exists = IfNotExistsDescriptor()
# _len is lazily created by __len__ # _len is lazily created by __len__
# table names will be generated automatically from it's model # table names will be generated automatically from it's model
@@ -276,6 +295,8 @@ class BaseModel(object):
_timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP)
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
def __init__(self, **values): def __init__(self, **values):
self._values = {} self._values = {}
self._ttl = self.__default_ttl__ self._ttl = self.__default_ttl__
@@ -528,7 +549,8 @@ class BaseModel(object):
batch=self._batch, batch=self._batch,
ttl=self._ttl, ttl=self._ttl,
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__).save() consistency=self.__consistency__,
if_not_exists=self._if_not_exists).save()
#reset the value managers #reset the value managers
for v in self._values.values(): for v in self._values.values():

View File

@@ -6,7 +6,7 @@ from cqlengine.columns import Counter, List, Set
from cqlengine.connection import execute from cqlengine.connection import execute
from cqlengine.exceptions import CQLEngineException, ValidationError from cqlengine.exceptions import CQLEngineException, ValidationError, LWTException
from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMixin
#CQL 3 reference: #CQL 3 reference:
@@ -22,6 +22,17 @@ class MultipleObjectsReturned(QueryException): pass
import six import six
def check_applied(result):
"""
check if result contains some column '[applied]' with false value,
if that value is false, it means our light-weight transaction didn't
applied to database.
"""
if result and '[applied]' in result[0] and result[0]['[applied]'] == False:
raise LWTException('')
class AbstractQueryableColumn(UnicodeMixin): class AbstractQueryableColumn(UnicodeMixin):
""" """
exposes cql query operators through pythons exposes cql query operators through pythons
@@ -171,7 +182,8 @@ class BatchQuery(object):
query_list.append('APPLY BATCH;') query_list.append('APPLY BATCH;')
execute('\n'.join(query_list), parameters, self._consistency) tmp = execute('\n'.join(query_list), parameters, self._consistency)
check_applied(tmp)
self.queries = [] self.queries = []
self._execute_callbacks() self._execute_callbacks()
@@ -220,6 +232,7 @@ class AbstractQuerySet(object):
self._ttl = getattr(model, '__default_ttl__', None) self._ttl = getattr(model, '__default_ttl__', None)
self._consistency = None self._consistency = None
self._timestamp = None self._timestamp = None
self._if_not_exists = False
@property @property
def column_family_name(self): def column_family_name(self):
@@ -569,7 +582,7 @@ class AbstractQuerySet(object):
def create(self, **kwargs): def create(self, **kwargs):
return self.model(**kwargs).batch(self._batch).ttl(self._ttl).\ return self.model(**kwargs).batch(self._batch).ttl(self._ttl).\
consistency(self._consistency).\ consistency(self._consistency).if_not_exists(self._if_not_exists).\
timestamp(self._timestamp).save() timestamp(self._timestamp).save()
def delete(self): def delete(self):
@@ -708,6 +721,11 @@ class ModelQuerySet(AbstractQuerySet):
clone._timestamp = timestamp clone._timestamp = timestamp
return clone return clone
def if_not_exists(self):
clone = copy.deepcopy(self)
clone._if_not_exists = True
return clone
def update(self, **values): def update(self, **values):
""" Updates the rows in this queryset """ """ Updates the rows in this queryset """
if not values: if not values:
@@ -767,8 +785,9 @@ class DMLQuery(object):
_ttl = None _ttl = None
_consistency = None _consistency = None
_timestamp = None _timestamp = None
_if_not_exists = False
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None): def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
self.model = model self.model = model
self.column_family_name = self.model.column_family_name() self.column_family_name = self.model.column_family_name()
self.instance = instance self.instance = instance
@@ -776,12 +795,16 @@ class DMLQuery(object):
self._ttl = ttl self._ttl = ttl
self._consistency = consistency self._consistency = consistency
self._timestamp = timestamp self._timestamp = timestamp
self._if_not_exists = if_not_exists
def _execute(self, q): def _execute(self, q):
if self._batch: if self._batch:
return self._batch.add_query(q) return self._batch.add_query(q)
else: else:
tmp = execute(q, consistency_level=self._consistency) tmp = execute(q, consistency_level=self._consistency)
if self._if_not_exists:
check_applied(tmp)
return tmp return tmp
def batch(self, batch_obj): def batch(self, batch_obj):
@@ -890,7 +913,7 @@ class DMLQuery(object):
if self.instance._has_counter or self.instance._can_update(): if self.instance._has_counter or self.instance._can_update():
return self.update() return self.update()
else: else:
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
for name, col in self.instance._columns.items(): for name, col in self.instance._columns.items():
val = getattr(self.instance, name, None) val = getattr(self.instance, name, None)
if col._val_is_null(val): if col._val_is_null(val):
@@ -906,7 +929,6 @@ class DMLQuery(object):
# caused by pointless update queries # caused by pointless update queries
if not insert.is_empty: if not insert.is_empty:
self._execute(insert) self._execute(insert)
# delete any nulled columns # delete any nulled columns
self._delete_null_columns() self._delete_null_columns()

View File

@@ -619,6 +619,24 @@ class AssignmentStatement(BaseCQLStatement):
class InsertStatement(AssignmentStatement): class InsertStatement(AssignmentStatement):
""" an cql insert select statement """ """ an cql insert select statement """
def __init__(self,
table,
assignments=None,
consistency=None,
where=None,
ttl=None,
timestamp=None,
if_not_exists=False):
super(InsertStatement, self).__init__(
table,
assignments=assignments,
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
self.if_not_exists = if_not_exists
def add_where_clause(self, clause): def add_where_clause(self, clause):
raise StatementException("Cannot add where clauses to insert statements") raise StatementException("Cannot add where clauses to insert statements")
@@ -633,6 +651,9 @@ class InsertStatement(AssignmentStatement):
qs += ['VALUES'] qs += ['VALUES']
qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))] qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))]
if self.if_not_exists:
qs += ["IF NOT EXISTS"]
if self.ttl: if self.ttl:
qs += ["USING TTL {}".format(self.ttl)] qs += ["USING TTL {}".format(self.ttl)]

View File

@@ -0,0 +1,178 @@
from unittest import skipUnless
from cqlengine.management import sync_table, drop_table, create_keyspace, delete_keyspace
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from cqlengine.exceptions import LWTException
from cqlengine import columns, BatchQuery
from uuid import uuid4
import mock
from cqlengine.connection import get_cluster
cluster = get_cluster()
class TestIfNotExistsModel(Model):
__keyspace__ = 'cqlengine_test_lwt'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
class BaseIfNotExistsTest(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(BaseIfNotExistsTest, cls).setUpClass()
"""
when receiving an insert statement with 'if not exist', cassandra would
perform a read with QUORUM level. Unittest would be failed if replica_factor
is 3 and one node only. Therefore I have create a new keyspace with
replica_factor:1.
"""
create_keyspace(TestIfNotExistsModel.__keyspace__, replication_factor=1)
sync_table(TestIfNotExistsModel)
@classmethod
def tearDownClass(cls):
super(BaseCassEngTestCase, cls).tearDownClass()
drop_table(TestIfNotExistsModel)
delete_keyspace(TestIfNotExistsModel.__keyspace__)
class IfNotExistsInsertTests(BaseIfNotExistsTest):
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
def test_insert_if_not_exists_success(self):
""" tests that insertion with if_not_exists work as expected """
id = uuid4()
TestIfNotExistsModel.create(id=id, count=8, text='123456789')
self.assertRaises(
LWTException,
TestIfNotExistsModel.if_not_exists().create, id=id, count=9, text='111111111111'
)
q = TestIfNotExistsModel.objects(id=id)
self.assertEqual(len(q), 1)
tm = q.first()
self.assertEquals(tm.count, 8)
self.assertEquals(tm.text, '123456789')
def test_insert_if_not_exists_failure(self):
""" tests that insertion with if_not_exists failure """
id = uuid4()
TestIfNotExistsModel.create(id=id, count=8, text='123456789')
TestIfNotExistsModel.create(id=id, count=9, text='111111111111')
q = TestIfNotExistsModel.objects(id=id)
self.assertEquals(len(q), 1)
tm = q.first()
self.assertEquals(tm.count, 9)
self.assertEquals(tm.text, '111111111111')
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
def test_batch_insert_if_not_exists_success(self):
""" tests that batch insertion with if_not_exists work as expected """
id = uuid4()
with BatchQuery() as b:
TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=8, text='123456789')
b = BatchQuery()
TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=9, text='111111111111')
self.assertRaises(LWTException, b.execute)
q = TestIfNotExistsModel.objects(id=id)
self.assertEqual(len(q), 1)
tm = q.first()
self.assertEquals(tm.count, 8)
self.assertEquals(tm.text, '123456789')
def test_batch_insert_if_not_exists_failure(self):
""" tests that batch insertion with if_not_exists failure """
id = uuid4()
with BatchQuery() as b:
TestIfNotExistsModel.batch(b).create(id=id, count=8, text='123456789')
with BatchQuery() as b:
TestIfNotExistsModel.batch(b).create(id=id, count=9, text='111111111111')
q = TestIfNotExistsModel.objects(id=id)
self.assertEquals(len(q), 1)
tm = q.first()
self.assertEquals(tm.count, 9)
self.assertEquals(tm.text, '111111111111')
class IfNotExistsModelTest(BaseIfNotExistsTest):
def test_if_not_exists_included_on_create(self):
""" tests that if_not_exists on models works as expected """
with mock.patch.object(self.session, 'execute') as m:
TestIfNotExistsModel.if_not_exists().create(count=8)
query = m.call_args[0][0].query_string
self.assertIn("IF NOT EXISTS", query)
def test_if_not_exists_included_on_save(self):
""" tests if we correctly put 'IF NOT EXISTS' for insert statement """
with mock.patch.object(self.session, 'execute') as m:
tm = TestIfNotExistsModel(count=8)
tm.if_not_exists(True).save()
query = m.call_args[0][0].query_string
self.assertIn("IF NOT EXISTS", query)
def test_queryset_is_returned_on_class(self):
""" ensure we get a queryset description back """
qs = TestIfNotExistsModel.if_not_exists()
self.assertTrue(isinstance(qs, TestIfNotExistsModel.__queryset__), type(qs))
def test_batch_if_not_exists(self):
""" ensure 'IF NOT EXISTS' exists in statement when in batch """
with mock.patch.object(self.session, 'execute') as m:
with BatchQuery() as b:
TestIfNotExistsModel.batch(b).if_not_exists().create(count=8)
self.assertIn("IF NOT EXISTS", m.call_args[0][0].query_string)
class IfNotExistsInstanceTest(BaseIfNotExistsTest):
def test_instance_is_returned(self):
"""
ensures that we properly handle the instance.if_not_exists(True).save()
scenario
"""
o = TestIfNotExistsModel.create(text="whatever")
o.text = "new stuff"
o = o.if_not_exists(True)
self.assertEqual(True, o._if_not_exists)
def test_if_not_exists_is_not_include_with_query_on_update(self):
"""
make sure we don't put 'IF NOT EXIST' in update statements
"""
o = TestIfNotExistsModel.create(text="whatever")
o.text = "new stuff"
o = o.if_not_exists(True)
with mock.patch.object(self.session, 'execute') as m:
o.save()
query = m.call_args[0][0].query_string
self.assertNotIn("IF NOT EXIST", query)

View File

@@ -188,6 +188,25 @@ Model Methods
Sets the ttl values to run instance updates and inserts queries with. Sets the ttl values to run instance updates and inserts queries with.
.. method:: if_not_exists()
Check the existence of an object before insertion. The existence of an
object is determined by its primary key(s). And please note using this flag
would incur performance cost.
if the insertion didn't applied, a LWTException exception would be raised.
*Example*
.. code-block:: python
try:
TestIfNotExistsModel.if_not_exists().create(id=id, count=9, text='111111111111')
except LWTException as e:
# handle failure case
print e.existing # existing object
This method is supported on Cassandra 2.0 or later.
.. method:: update(**values) .. method:: update(**values)
Performs an update on the model instance. You can pass in values to set on the model Performs an update on the model instance. You can pass in values to set on the model