Bug fixes

This commit is contained in:
timmartin19
2014-10-08 14:48:35 -04:00
parent 09427f4d75
commit 874a63cdbc
4 changed files with 45 additions and 21 deletions

View File

@@ -80,9 +80,14 @@ class TransactionDescriptor(object):
""" """
def __get__(self, instance, model): def __get__(self, instance, model):
if instance: if instance:
def transaction_setter(transaction): def transaction_setter(*prepared_transaction, **unprepared_transactions):
instance._transaction = transaction if len(prepared_transaction) > 0:
transactions = prepared_transaction[0]
else:
transactions = instance.objects.transaction(**unprepared_transactions)._transaction
instance._transaction = transactions
return instance return instance
return transaction_setter return transaction_setter
qs = model.__queryset__(model) qs = model.__queryset__(model)
@@ -323,6 +328,7 @@ class BaseModel(object):
self._values = {} self._values = {}
self._ttl = self.__default_ttl__ self._ttl = self.__default_ttl__
self._timestamp = None self._timestamp = None
self._transaction = None
for name, column in self._columns.items(): for name, column in self._columns.items():
value = values.get(name, None) value = values.get(name, None)
@@ -553,10 +559,6 @@ class BaseModel(object):
return cls.objects.filter(*args, **kwargs) return cls.objects.filter(*args, **kwargs)
@classmethod
def transaction(cls, *args, **kwargs):
return cls.objects.transaction(*args, **kwargs)
@classmethod @classmethod
def get(cls, *args, **kwargs): def get(cls, *args, **kwargs):
return cls.objects.get(*args, **kwargs) return cls.objects.get(*args, **kwargs)
@@ -576,7 +578,8 @@ class BaseModel(object):
ttl=self._ttl, ttl=self._ttl,
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__, consistency=self.__consistency__,
if_not_exists=self._if_not_exists).save() if_not_exists=self._if_not_exists,
transaction=self._transaction).save()
#reset the value managers #reset the value managers
for v in self._values.values(): for v in self._values.values():
@@ -614,7 +617,8 @@ class BaseModel(object):
batch=self._batch, batch=self._batch,
ttl=self._ttl, ttl=self._ttl,
timestamp=self._timestamp, timestamp=self._timestamp,
consistency=self.__consistency__).update() consistency=self.__consistency__,
transaction=self._transaction).update()
#reset the value managers #reset the value managers
for v in self._values.values(): for v in self._values.values():

View File

@@ -422,11 +422,12 @@ class AbstractQuerySet(object):
clone._transaction.append(operator) clone._transaction.append(operator)
for col_name, val in kwargs.items(): for col_name, val in kwargs.items():
exists = False
try: try:
column = self.model._get_column(col_name) column = self.model._get_column(col_name)
except KeyError: except KeyError:
if col_name in ['exists', 'not_exists']: if col_name in ['exists', 'not_exists']:
pass exists = True
elif col_name == 'pk__token': elif col_name == 'pk__token':
if not isinstance(val, Token): if not isinstance(val, Token):
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
@@ -445,7 +446,7 @@ class AbstractQuerySet(object):
len(val.value), len(partition_columns))) len(val.value), len(partition_columns)))
val.set_columns(partition_columns) val.set_columns(partition_columns)
if isinstance(val, BaseQueryFunction): if isinstance(val, BaseQueryFunction) or exists is True:
query_val = val query_val = val
else: else:
query_val = column.to_database(val) query_val = column.to_database(val)
@@ -834,7 +835,7 @@ class DMLQuery(object):
_timestamp = None _timestamp = None
_if_not_exists = False _if_not_exists = False
def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False): def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, if_not_exists=False, transaction=None):
self.model = model self.model = model
self.column_family_name = self.model.column_family_name() self.column_family_name = self.model.column_family_name()
self.instance = instance self.instance = instance
@@ -843,6 +844,7 @@ class DMLQuery(object):
self._consistency = consistency self._consistency = consistency
self._timestamp = timestamp self._timestamp = timestamp
self._if_not_exists = if_not_exists self._if_not_exists = if_not_exists
self._transaction = transaction
def _execute(self, q): def _execute(self, q):
if self._batch: if self._batch:
@@ -897,7 +899,8 @@ class DMLQuery(object):
raise CQLEngineException("DML Query intance attribute is None") raise CQLEngineException("DML Query intance attribute is None")
assert type(self.instance) == self.model assert type(self.instance) == self.model
static_update_only = True static_update_only = True
statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction)
#get defined fields and their column names #get defined fields and their column names
for name, col in self.model._columns.items(): for name, col in self.model._columns.items():
if not col.is_primary_key: if not col.is_primary_key:
@@ -960,7 +963,8 @@ class DMLQuery(object):
if self.instance._has_counter or self.instance._can_update(): if self.instance._has_counter or self.instance._can_update():
return self.update() return self.update()
else: else:
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists, transactions=self._transaction)
for name, col in self.instance._columns.items(): for name, col in self.instance._columns.items():
val = getattr(self.instance, name, None) val = getattr(self.instance, name, None)
if col._val_is_null(val): if col._val_is_null(val):

View File

@@ -152,6 +152,11 @@ class TransactionClause(BaseClause):
def insert_tuple(self): def insert_tuple(self):
return self.field, self.context_id return self.field, self.context_id
def update_context(self, ctx):
if self.field not in ['exists', 'not_exists']:
return super(TransactionClause, self).update_context(ctx)
return ctx
class ContainerUpdateClause(AssignmentClause): class ContainerUpdateClause(AssignmentClause):
@@ -646,10 +651,12 @@ class AssignmentStatement(BaseCQLStatement):
ctx = super(AssignmentStatement, self).get_context() ctx = super(AssignmentStatement, self).get_context()
for clause in self.assignments: for clause in self.assignments:
clause.update_context(ctx) clause.update_context(ctx)
for clause in self.transactions or []:
clause.update_context(ctx)
return ctx return ctx
def _transactions(self): def _get_transactions(self):
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
class InsertStatement(AssignmentStatement): class InsertStatement(AssignmentStatement):
@@ -697,7 +704,7 @@ class InsertStatement(AssignmentStatement):
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)] qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
if len(self.transactions) > 0: if len(self.transactions) > 0:
qs += [self._transactions] qs += [self._get_transactions()]
return ' '.join(qs) return ' '.join(qs)
@@ -726,7 +733,7 @@ class UpdateStatement(AssignmentStatement):
qs += [self._where] qs += [self._where]
if len(self.transactions) > 0: if len(self.transactions) > 0:
qs += [self._transactions] qs += [self._get_transactions()]
return ' '.join(qs) return ' '.join(qs)

View File

@@ -4,10 +4,13 @@ from uuid import uuid4
from mock import patch from mock import patch
from cqlengine.exceptions import ValidationError from cqlengine.exceptions import ValidationError
from cqlengine.tests.base import BaseCassEngTestCase from unittest import TestCase
from cqlengine.models import Model from cqlengine.models import Model
from cqlengine import columns from cqlengine import columns
from cqlengine.management import sync_table, drop_table from cqlengine.management import sync_table, drop_table
from cqlengine import connection
connection.setup(['192.168.56.103'], 'test')
class TestUpdateModel(Model): class TestUpdateModel(Model):
@@ -18,7 +21,7 @@ class TestUpdateModel(Model):
text = columns.Text(required=False, index=True) text = columns.Text(required=False, index=True)
class ModelUpdateTests(BaseCassEngTestCase): class ModelUpdateTests(TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
@@ -31,6 +34,12 @@ class ModelUpdateTests(BaseCassEngTestCase):
drop_table(TestUpdateModel) drop_table(TestUpdateModel)
def test_transaction_insertion(self): def test_transaction_insertion(self):
m = TestUpdateModel(count=5, text='something').transaction(exists=True) m = TestUpdateModel.objects.create(count=5, text='something')
m.save() all_models = TestUpdateModel.objects.all()
for x in all_models:
pass
m = TestUpdateModel.objects
m = m.transaction(not_exists=True)
m = m.create(count=5, text='something')
m.transaction(count=6).update(text='something else')
x = 10 x = 10