Bug fixes
This commit is contained in:
@@ -80,9 +80,14 @@ class TransactionDescriptor(object):
|
||||
"""
|
||||
def __get__(self, instance, model):
|
||||
if instance:
|
||||
def transaction_setter(transaction):
|
||||
instance._transaction = transaction
|
||||
def transaction_setter(*prepared_transaction, **unprepared_transactions):
|
||||
if len(prepared_transaction) > 0:
|
||||
transactions = prepared_transaction[0]
|
||||
else:
|
||||
transactions = instance.objects.transaction(**unprepared_transactions)._transaction
|
||||
instance._transaction = transactions
|
||||
return instance
|
||||
|
||||
return transaction_setter
|
||||
qs = model.__queryset__(model)
|
||||
|
||||
@@ -323,6 +328,7 @@ class BaseModel(object):
|
||||
self._values = {}
|
||||
self._ttl = self.__default_ttl__
|
||||
self._timestamp = None
|
||||
self._transaction = None
|
||||
|
||||
for name, column in self._columns.items():
|
||||
value = values.get(name, None)
|
||||
@@ -553,10 +559,6 @@ class BaseModel(object):
|
||||
|
||||
return cls.objects.filter(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def transaction(cls, *args, **kwargs):
|
||||
return cls.objects.transaction(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get(cls, *args, **kwargs):
|
||||
return cls.objects.get(*args, **kwargs)
|
||||
@@ -576,7 +578,8 @@ class BaseModel(object):
|
||||
ttl=self._ttl,
|
||||
timestamp=self._timestamp,
|
||||
consistency=self.__consistency__,
|
||||
if_not_exists=self._if_not_exists).save()
|
||||
if_not_exists=self._if_not_exists,
|
||||
transaction=self._transaction).save()
|
||||
|
||||
#reset the value managers
|
||||
for v in self._values.values():
|
||||
@@ -614,7 +617,8 @@ class BaseModel(object):
|
||||
batch=self._batch,
|
||||
ttl=self._ttl,
|
||||
timestamp=self._timestamp,
|
||||
consistency=self.__consistency__).update()
|
||||
consistency=self.__consistency__,
|
||||
transaction=self._transaction).update()
|
||||
|
||||
#reset the value managers
|
||||
for v in self._values.values():
|
||||
|
||||
@@ -422,11 +422,12 @@ class AbstractQuerySet(object):
|
||||
clone._transaction.append(operator)
|
||||
|
||||
for col_name, val in kwargs.items():
|
||||
exists = False
|
||||
try:
|
||||
column = self.model._get_column(col_name)
|
||||
except KeyError:
|
||||
if col_name in ['exists', 'not_exists']:
|
||||
pass
|
||||
exists = True
|
||||
elif col_name == 'pk__token':
|
||||
if not isinstance(val, Token):
|
||||
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
||||
@@ -445,7 +446,7 @@ class AbstractQuerySet(object):
|
||||
len(val.value), len(partition_columns)))
|
||||
val.set_columns(partition_columns)
|
||||
|
||||
if isinstance(val, BaseQueryFunction):
|
||||
if isinstance(val, BaseQueryFunction) or exists is True:
|
||||
query_val = val
|
||||
else:
|
||||
query_val = column.to_database(val)
|
||||
@@ -834,7 +835,7 @@ class DMLQuery(object):
|
||||
_timestamp = None
|
||||
_if_not_exists = False
|
||||
|
||||
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
|
||||
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
|
||||
self.model = model
|
||||
self.column_family_name = self.model.column_family_name()
|
||||
self.instance = instance
|
||||
@@ -843,6 +844,7 @@ class DMLQuery(object):
|
||||
self._consistency = consistency
|
||||
self._timestamp = timestamp
|
||||
self._if_not_exists = if_not_exists
|
||||
self._transaction = transaction
|
||||
|
||||
def _execute(self, q):
|
||||
if self._batch:
|
||||
@@ -897,7 +899,8 @@ class DMLQuery(object):
|
||||
raise CQLEngineException("DML Query intance attribute is None")
|
||||
assert type(self.instance) == self.model
|
||||
static_update_only = True
|
||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp)
|
||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
|
||||
timestamp=self._timestamp, transactions=self._transaction)
|
||||
#get defined fields and their column names
|
||||
for name, col in self.model._columns.items():
|
||||
if not col.is_primary_key:
|
||||
@@ -960,7 +963,8 @@ class DMLQuery(object):
|
||||
if self.instance._has_counter or self.instance._can_update():
|
||||
return self.update()
|
||||
else:
|
||||
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
|
||||
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists, transactions=self._transaction)
|
||||
|
||||
for name, col in self.instance._columns.items():
|
||||
val = getattr(self.instance, name, None)
|
||||
if col._val_is_null(val):
|
||||
|
||||
@@ -152,6 +152,11 @@ class TransactionClause(BaseClause):
|
||||
def insert_tuple(self):
|
||||
return self.field, self.context_id
|
||||
|
||||
def update_context(self, ctx):
|
||||
if self.field not in ['exists', 'not_exists']:
|
||||
return super(TransactionClause, self).update_context(ctx)
|
||||
return ctx
|
||||
|
||||
|
||||
class ContainerUpdateClause(AssignmentClause):
|
||||
|
||||
@@ -646,10 +651,12 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
ctx = super(AssignmentStatement, self).get_context()
|
||||
for clause in self.assignments:
|
||||
clause.update_context(ctx)
|
||||
for clause in self.transactions or []:
|
||||
clause.update_context(ctx)
|
||||
return ctx
|
||||
|
||||
def _transactions(self):
|
||||
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))
|
||||
def _get_transactions(self):
|
||||
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
|
||||
|
||||
|
||||
class InsertStatement(AssignmentStatement):
|
||||
@@ -697,7 +704,7 @@ class InsertStatement(AssignmentStatement):
|
||||
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
|
||||
|
||||
if len(self.transactions) > 0:
|
||||
qs += [self._transactions]
|
||||
qs += [self._get_transactions()]
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
@@ -726,7 +733,7 @@ class UpdateStatement(AssignmentStatement):
|
||||
qs += [self._where]
|
||||
|
||||
if len(self.transactions) > 0:
|
||||
qs += [self._transactions]
|
||||
qs += [self._get_transactions()]
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
|
||||
@@ -4,10 +4,13 @@ from uuid import uuid4
|
||||
from mock import patch
|
||||
from cqlengine.exceptions import ValidationError
|
||||
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
from unittest import TestCase
|
||||
from cqlengine.models import Model
|
||||
from cqlengine import columns
|
||||
from cqlengine.management import sync_table, drop_table
|
||||
from cqlengine import connection
|
||||
|
||||
connection.setup(['192.168.56.103'], 'test')
|
||||
|
||||
|
||||
class TestUpdateModel(Model):
|
||||
@@ -18,7 +21,7 @@ class TestUpdateModel(Model):
|
||||
text = columns.Text(required=False, index=True)
|
||||
|
||||
|
||||
class ModelUpdateTests(BaseCassEngTestCase):
|
||||
class ModelUpdateTests(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -31,6 +34,12 @@ class ModelUpdateTests(BaseCassEngTestCase):
|
||||
drop_table(TestUpdateModel)
|
||||
|
||||
def test_transaction_insertion(self):
|
||||
m = TestUpdateModel(count=5, text='something').transaction(exists=True)
|
||||
m.save()
|
||||
m = TestUpdateModel.objects.create(count=5, text='something')
|
||||
all_models = TestUpdateModel.objects.all()
|
||||
for x in all_models:
|
||||
pass
|
||||
m = TestUpdateModel.objects
|
||||
m = m.transaction(not_exists=True)
|
||||
m = m.create(count=5, text='something')
|
||||
m.transaction(count=6).update(text='something else')
|
||||
x = 10
|
||||
Reference in New Issue
Block a user