Bug fixes

This commit is contained in:
timmartin19
2014-10-08 14:48:35 -04:00
parent 09427f4d75
commit 874a63cdbc
4 changed files with 45 additions and 21 deletions

View File

@@ -80,9 +80,14 @@ class TransactionDescriptor(object):
"""
def __get__(self, instance, model):
if instance:
def transaction_setter(transaction):
instance._transaction = transaction
def transaction_setter(*prepared_transaction, **unprepared_transactions):
if len(prepared_transaction) > 0:
transactions = prepared_transaction[0]
else:
transactions = instance.objects.transaction(**unprepared_transactions)._transaction
instance._transaction = transactions
return instance
return transaction_setter
qs = model.__queryset__(model)
@@ -323,6 +328,7 @@ class BaseModel(object):
self._values = {}
self._ttl = self.__default_ttl__
self._timestamp = None
self._transaction = None
for name, column in self._columns.items():
value = values.get(name, None)
@@ -553,10 +559,6 @@ class BaseModel(object):
return cls.objects.filter(*args, **kwargs)
@classmethod
def transaction(cls, *args, **kwargs):
return cls.objects.transaction(*args, **kwargs)
@classmethod
def get(cls, *args, **kwargs):
return cls.objects.get(*args, **kwargs)
@@ -576,7 +578,8 @@ class BaseModel(object):
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__,
if_not_exists=self._if_not_exists).save()
if_not_exists=self._if_not_exists,
transaction=self._transaction).save()
#reset the value managers
for v in self._values.values():
@@ -614,7 +617,8 @@ class BaseModel(object):
batch=self._batch,
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__).update()
consistency=self.__consistency__,
transaction=self._transaction).update()
#reset the value managers
for v in self._values.values():

View File

@@ -422,11 +422,12 @@ class AbstractQuerySet(object):
clone._transaction.append(operator)
for col_name, val in kwargs.items():
exists = False
try:
column = self.model._get_column(col_name)
except KeyError:
if col_name in ['exists', 'not_exists']:
pass
exists = True
elif col_name == 'pk__token':
if not isinstance(val, Token):
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
@@ -445,7 +446,7 @@ class AbstractQuerySet(object):
len(val.value), len(partition_columns)))
val.set_columns(partition_columns)
if isinstance(val, BaseQueryFunction):
if isinstance(val, BaseQueryFunction) or exists is True:
query_val = val
else:
query_val = column.to_database(val)
@@ -834,7 +835,7 @@ class DMLQuery(object):
_timestamp = None
_if_not_exists = False
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False):
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
self.model = model
self.column_family_name = self.model.column_family_name()
self.instance = instance
@@ -843,6 +844,7 @@ class DMLQuery(object):
self._consistency = consistency
self._timestamp = timestamp
self._if_not_exists = if_not_exists
self._transaction = transaction
def _execute(self, q):
if self._batch:
@@ -897,7 +899,8 @@ class DMLQuery(object):
raise CQLEngineException("DML Query intance attribute is None")
assert type(self.instance) == self.model
static_update_only = True
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp)
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction)
#get defined fields and their column names
for name, col in self.model._columns.items():
if not col.is_primary_key:
@@ -960,7 +963,8 @@ class DMLQuery(object):
if self.instance._has_counter or self.instance._can_update():
return self.update()
else:
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists, transactions=self._transaction)
for name, col in self.instance._columns.items():
val = getattr(self.instance, name, None)
if col._val_is_null(val):

View File

@@ -152,6 +152,11 @@ class TransactionClause(BaseClause):
def insert_tuple(self):
return self.field, self.context_id
def update_context(self, ctx):
if self.field not in ['exists', 'not_exists']:
return super(TransactionClause, self).update_context(ctx)
return ctx
class ContainerUpdateClause(AssignmentClause):
@@ -646,10 +651,12 @@ class AssignmentStatement(BaseCQLStatement):
ctx = super(AssignmentStatement, self).get_context()
for clause in self.assignments:
clause.update_context(ctx)
for clause in self.transactions or []:
clause.update_context(ctx)
return ctx
def _transactions(self):
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))
def _get_transactions(self):
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
class InsertStatement(AssignmentStatement):
@@ -697,7 +704,7 @@ class InsertStatement(AssignmentStatement):
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
if len(self.transactions) > 0:
qs += [self._transactions]
qs += [self._get_transactions()]
return ' '.join(qs)
@@ -726,7 +733,7 @@ class UpdateStatement(AssignmentStatement):
qs += [self._where]
if len(self.transactions) > 0:
qs += [self._transactions]
qs += [self._get_transactions()]
return ' '.join(qs)

View File

@@ -4,10 +4,13 @@ from uuid import uuid4
from mock import patch
from cqlengine.exceptions import ValidationError
from cqlengine.tests.base import BaseCassEngTestCase
from unittest import TestCase
from cqlengine.models import Model
from cqlengine import columns
from cqlengine.management import sync_table, drop_table
from cqlengine import connection
connection.setup(['192.168.56.103'], 'test')
class TestUpdateModel(Model):
@@ -18,7 +21,7 @@ class TestUpdateModel(Model):
text = columns.Text(required=False, index=True)
class ModelUpdateTests(BaseCassEngTestCase):
class ModelUpdateTests(TestCase):
@classmethod
def setUpClass(cls):
@@ -31,6 +34,12 @@ class ModelUpdateTests(BaseCassEngTestCase):
drop_table(TestUpdateModel)
def test_transaction_insertion(self):
m = TestUpdateModel(count=5, text='something').transaction(exists=True)
m.save()
m = TestUpdateModel.objects.create(count=5, text='something')
all_models = TestUpdateModel.objects.all()
for x in all_models:
pass
m = TestUpdateModel.objects
m = m.transaction(not_exists=True)
m = m.create(count=5, text='something')
m.transaction(count=6).update(text='something else')
x = 10