Bug fixes
This commit is contained in:
@@ -80,9 +80,14 @@ class TransactionDescriptor(object):
|
|||||||
"""
|
"""
|
||||||
def __get__(self, instance, model):
|
def __get__(self, instance, model):
|
||||||
if instance:
|
if instance:
|
||||||
def transaction_setter(transaction):
|
def transaction_setter(*prepared_transaction, **unprepared_transactions):
|
||||||
instance._transaction = transaction
|
if len(prepared_transaction) > 0:
|
||||||
|
transactions = prepared_transaction[0]
|
||||||
|
else:
|
||||||
|
transactions = instance.objects.transaction(**unprepared_transactions)._transaction
|
||||||
|
instance._transaction = transactions
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
return transaction_setter
|
return transaction_setter
|
||||||
qs = model.__queryset__(model)
|
qs = model.__queryset__(model)
|
||||||
|
|
||||||
@@ -323,6 +328,7 @@ class BaseModel(object):
|
|||||||
self._values = {}
|
self._values = {}
|
||||||
self._ttl = self.__default_ttl__
|
self._ttl = self.__default_ttl__
|
||||||
self._timestamp = None
|
self._timestamp = None
|
||||||
|
self._transaction = None
|
||||||
|
|
||||||
for name, column in self._columns.items():
|
for name, column in self._columns.items():
|
||||||
value = values.get(name, None)
|
value = values.get(name, None)
|
||||||
@@ -553,10 +559,6 @@ class BaseModel(object):
|
|||||||
|
|
||||||
return cls.objects.filter(*args, **kwargs)
|
return cls.objects.filter(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def transaction(cls, *args, **kwargs):
|
|
||||||
return cls.objects.transaction(*args, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, *args, **kwargs):
|
def get(cls, *args, **kwargs):
|
||||||
return cls.objects.get(*args, **kwargs)
|
return cls.objects.get(*args, **kwargs)
|
||||||
@@ -576,7 +578,8 @@ class BaseModel(object):
|
|||||||
ttl=self._ttl,
|
ttl=self._ttl,
|
||||||
timestamp=self._timestamp,
|
timestamp=self._timestamp,
|
||||||
consistency=self.__consistency__,
|
consistency=self.__consistency__,
|
||||||
if_not_exists=self._if_not_exists).save()
|
if_not_exists=self._if_not_exists,
|
||||||
|
transaction=self._transaction).save()
|
||||||
|
|
||||||
#reset the value managers
|
#reset the value managers
|
||||||
for v in self._values.values():
|
for v in self._values.values():
|
||||||
@@ -614,7 +617,8 @@ class BaseModel(object):
|
|||||||
batch=self._batch,
|
batch=self._batch,
|
||||||
ttl=self._ttl,
|
ttl=self._ttl,
|
||||||
timestamp=self._timestamp,
|
timestamp=self._timestamp,
|
||||||
consistency=self.__consistency__).update()
|
consistency=self.__consistency__,
|
||||||
|
transaction=self._transaction).update()
|
||||||
|
|
||||||
#reset the value managers
|
#reset the value managers
|
||||||
for v in self._values.values():
|
for v in self._values.values():
|
||||||
|
|||||||
@@ -422,11 +422,12 @@ class AbstractQuerySet(object):
|
|||||||
clone._transaction.append(operator)
|
clone._transaction.append(operator)
|
||||||
|
|
||||||
for col_name, val in kwargs.items():
|
for col_name, val in kwargs.items():
|
||||||
|
exists = False
|
||||||
try:
|
try:
|
||||||
column = self.model._get_column(col_name)
|
column = self.model._get_column(col_name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if col_name in ['exists', 'not_exists']:
|
if col_name in ['exists', 'not_exists']:
|
||||||
pass
|
exists = True
|
||||||
elif col_name == 'pk__token':
|
elif col_name == 'pk__token':
|
||||||
if not isinstance(val, Token):
|
if not isinstance(val, Token):
|
||||||
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
||||||
@@ -445,7 +446,7 @@ class AbstractQuerySet(object):
|
|||||||
len(val.value), len(partition_columns)))
|
len(val.value), len(partition_columns)))
|
||||||
val.set_columns(partition_columns)
|
val.set_columns(partition_columns)
|
||||||
|
|
||||||
if isinstance(val, BaseQueryFunction):
|
if isinstance(val, BaseQueryFunction) or exists is True:
|
||||||
query_val = val
|
query_val = val
|
||||||
else:
|
else:
|
||||||
query_val = column.to_database(val)
|
query_val = column.to_database(val)
|
||||||
@@ -834,7 +835,7 @@ class DMLQuery(object):
|
|||||||
_timestamp = None
|
_timestamp = None
|
||||||
_if_not_exists = False
|
_if_not_exists = False
|
||||||
|
|
||||||
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
|
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.column_family_name = self.model.column_family_name()
|
self.column_family_name = self.model.column_family_name()
|
||||||
self.instance = instance
|
self.instance = instance
|
||||||
@@ -843,6 +844,7 @@ class DMLQuery(object):
|
|||||||
self._consistency = consistency
|
self._consistency = consistency
|
||||||
self._timestamp = timestamp
|
self._timestamp = timestamp
|
||||||
self._if_not_exists = if_not_exists
|
self._if_not_exists = if_not_exists
|
||||||
|
self._transaction = transaction
|
||||||
|
|
||||||
def _execute(self, q):
|
def _execute(self, q):
|
||||||
if self._batch:
|
if self._batch:
|
||||||
@@ -897,7 +899,8 @@ class DMLQuery(object):
|
|||||||
raise CQLEngineException("DML Query intance attribute is None")
|
raise CQLEngineException("DML Query intance attribute is None")
|
||||||
assert type(self.instance) == self.model
|
assert type(self.instance) == self.model
|
||||||
static_update_only = True
|
static_update_only = True
|
||||||
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp)
|
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
|
||||||
|
timestamp=self._timestamp, transactions=self._transaction)
|
||||||
#get defined fields and their column names
|
#get defined fields and their column names
|
||||||
for name, col in self.model._columns.items():
|
for name, col in self.model._columns.items():
|
||||||
if not col.is_primary_key:
|
if not col.is_primary_key:
|
||||||
@@ -960,7 +963,8 @@ class DMLQuery(object):
|
|||||||
if self.instance._has_counter or self.instance._can_update():
|
if self.instance._has_counter or self.instance._can_update():
|
||||||
return self.update()
|
return self.update()
|
||||||
else:
|
else:
|
||||||
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
|
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists, transactions=self._transaction)
|
||||||
|
|
||||||
for name, col in self.instance._columns.items():
|
for name, col in self.instance._columns.items():
|
||||||
val = getattr(self.instance, name, None)
|
val = getattr(self.instance, name, None)
|
||||||
if col._val_is_null(val):
|
if col._val_is_null(val):
|
||||||
|
|||||||
@@ -152,6 +152,11 @@ class TransactionClause(BaseClause):
|
|||||||
def insert_tuple(self):
|
def insert_tuple(self):
|
||||||
return self.field, self.context_id
|
return self.field, self.context_id
|
||||||
|
|
||||||
|
def update_context(self, ctx):
|
||||||
|
if self.field not in ['exists', 'not_exists']:
|
||||||
|
return super(TransactionClause, self).update_context(ctx)
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
class ContainerUpdateClause(AssignmentClause):
|
class ContainerUpdateClause(AssignmentClause):
|
||||||
|
|
||||||
@@ -646,10 +651,12 @@ class AssignmentStatement(BaseCQLStatement):
|
|||||||
ctx = super(AssignmentStatement, self).get_context()
|
ctx = super(AssignmentStatement, self).get_context()
|
||||||
for clause in self.assignments:
|
for clause in self.assignments:
|
||||||
clause.update_context(ctx)
|
clause.update_context(ctx)
|
||||||
|
for clause in self.transactions or []:
|
||||||
|
clause.update_context(ctx)
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
def _transactions(self):
|
def _get_transactions(self):
|
||||||
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))
|
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
|
||||||
|
|
||||||
|
|
||||||
class InsertStatement(AssignmentStatement):
|
class InsertStatement(AssignmentStatement):
|
||||||
@@ -697,7 +704,7 @@ class InsertStatement(AssignmentStatement):
|
|||||||
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
|
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
|
||||||
|
|
||||||
if len(self.transactions) > 0:
|
if len(self.transactions) > 0:
|
||||||
qs += [self._transactions]
|
qs += [self._get_transactions()]
|
||||||
|
|
||||||
return ' '.join(qs)
|
return ' '.join(qs)
|
||||||
|
|
||||||
@@ -726,7 +733,7 @@ class UpdateStatement(AssignmentStatement):
|
|||||||
qs += [self._where]
|
qs += [self._where]
|
||||||
|
|
||||||
if len(self.transactions) > 0:
|
if len(self.transactions) > 0:
|
||||||
qs += [self._transactions]
|
qs += [self._get_transactions()]
|
||||||
|
|
||||||
return ' '.join(qs)
|
return ' '.join(qs)
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,13 @@ from uuid import uuid4
|
|||||||
from mock import patch
|
from mock import patch
|
||||||
from cqlengine.exceptions import ValidationError
|
from cqlengine.exceptions import ValidationError
|
||||||
|
|
||||||
from cqlengine.tests.base import BaseCassEngTestCase
|
from unittest import TestCase
|
||||||
from cqlengine.models import Model
|
from cqlengine.models import Model
|
||||||
from cqlengine import columns
|
from cqlengine import columns
|
||||||
from cqlengine.management import sync_table, drop_table
|
from cqlengine.management import sync_table, drop_table
|
||||||
|
from cqlengine import connection
|
||||||
|
|
||||||
|
connection.setup(['192.168.56.103'], 'test')
|
||||||
|
|
||||||
|
|
||||||
class TestUpdateModel(Model):
|
class TestUpdateModel(Model):
|
||||||
@@ -18,7 +21,7 @@ class TestUpdateModel(Model):
|
|||||||
text = columns.Text(required=False, index=True)
|
text = columns.Text(required=False, index=True)
|
||||||
|
|
||||||
|
|
||||||
class ModelUpdateTests(BaseCassEngTestCase):
|
class ModelUpdateTests(TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -31,6 +34,12 @@ class ModelUpdateTests(BaseCassEngTestCase):
|
|||||||
drop_table(TestUpdateModel)
|
drop_table(TestUpdateModel)
|
||||||
|
|
||||||
def test_transaction_insertion(self):
|
def test_transaction_insertion(self):
|
||||||
m = TestUpdateModel(count=5, text='something').transaction(exists=True)
|
m = TestUpdateModel.objects.create(count=5, text='something')
|
||||||
m.save()
|
all_models = TestUpdateModel.objects.all()
|
||||||
|
for x in all_models:
|
||||||
|
pass
|
||||||
|
m = TestUpdateModel.objects
|
||||||
|
m = m.transaction(not_exists=True)
|
||||||
|
m = m.create(count=5, text='something')
|
||||||
|
m.transaction(count=6).update(text='something else')
|
||||||
x = 10
|
x = 10
|
||||||
Reference in New Issue
Block a user