Initial attempt…
This commit is contained in:
@@ -74,6 +74,27 @@ class QuerySetDescriptor(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransactionDescriptor(object):
|
||||
"""
|
||||
returns a query set descriptor
|
||||
"""
|
||||
def __get__(self, instance, model):
|
||||
if instance:
|
||||
def transaction_setter(transaction):
|
||||
instance._transaction = transaction
|
||||
return instance
|
||||
return transaction_setter
|
||||
qs = model.__queryset__(model)
|
||||
|
||||
def transaction_setter(transaction):
|
||||
qs._transaction = transaction
|
||||
return instance
|
||||
return transaction_setter
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TTLDescriptor(object):
|
||||
"""
|
||||
returns a query set descriptor
|
||||
@@ -239,6 +260,7 @@ class BaseModel(object):
|
||||
objects = QuerySetDescriptor()
|
||||
ttl = TTLDescriptor()
|
||||
consistency = ConsistencyDescriptor()
|
||||
transaction = TransactionDescriptor()
|
||||
|
||||
# custom timestamps, see USING TIMESTAMP X
|
||||
timestamp = TimestampDescriptor()
|
||||
@@ -303,7 +325,7 @@ class BaseModel(object):
|
||||
self._timestamp = None
|
||||
|
||||
for name, column in self._columns.items():
|
||||
value = values.get(name, None)
|
||||
value = values.get(name, None)
|
||||
if value is not None or isinstance(column, columns.BaseContainerColumn):
|
||||
value = column.to_python(value)
|
||||
value_mngr = column.value_manager(self, column, value)
|
||||
@@ -531,6 +553,10 @@ class BaseModel(object):
|
||||
|
||||
return cls.objects.filter(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def transaction(cls, *args, **kwargs):
|
||||
return cls.objects.transaction(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get(cls, *args, **kwargs):
|
||||
return cls.objects.get(*args, **kwargs)
|
||||
|
||||
@@ -13,7 +13,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix
|
||||
#http://www.datastax.com/docs/1.1/references/cql/index
|
||||
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
|
||||
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
|
||||
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause
|
||||
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause
|
||||
|
||||
|
||||
class QueryException(CQLEngineException): pass
|
||||
@@ -206,6 +206,9 @@ class AbstractQuerySet(object):
|
||||
#Where clause filters
|
||||
self._where = []
|
||||
|
||||
# Transaction clause filters
|
||||
self._transaction = []
|
||||
|
||||
#ordering arguments
|
||||
self._order = []
|
||||
|
||||
@@ -407,6 +410,50 @@ class AbstractQuerySet(object):
|
||||
else:
|
||||
raise QueryException("Can't parse '{}'".format(arg))
|
||||
|
||||
def transaction(self, *args, **kwargs):
|
||||
"""Adds IF statements to queryset"""
|
||||
if len([x for x in kwargs.values() if x is None]):
|
||||
raise CQLEngineException("None values on transaction are not allowed")
|
||||
|
||||
clone = copy.deepcopy(self)
|
||||
for operator in args:
|
||||
if not isinstance(operator, TransactionClause):
|
||||
raise QueryException('{} is not a valid query operator'.format(operator))
|
||||
clone._transaction.append(operator)
|
||||
|
||||
for col_name, val in kwargs.items():
|
||||
try:
|
||||
column = self.model._get_column(col_name)
|
||||
except KeyError:
|
||||
if col_name in ['exists', 'not_exists']:
|
||||
pass
|
||||
elif col_name == 'pk__token':
|
||||
if not isinstance(val, Token):
|
||||
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
|
||||
column = columns._PartitionKeysToken(self.model)
|
||||
quote_field = False
|
||||
else:
|
||||
raise QueryException("Can't resolve column name: '{}'".format(col_name))
|
||||
|
||||
if isinstance(val, Token):
|
||||
if col_name != 'pk__token':
|
||||
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
|
||||
partition_columns = column.partition_columns
|
||||
if len(partition_columns) != len(val.value):
|
||||
raise QueryException(
|
||||
'Token() received {} arguments but model has {} partition keys'.format(
|
||||
len(val.value), len(partition_columns)))
|
||||
val.set_columns(partition_columns)
|
||||
|
||||
if isinstance(val, BaseQueryFunction):
|
||||
query_val = val
|
||||
else:
|
||||
query_val = column.to_database(val)
|
||||
|
||||
clone._transaction.append(TransactionClause(col_name, query_val))
|
||||
|
||||
return clone
|
||||
|
||||
def filter(self, *args, **kwargs):
|
||||
"""
|
||||
Adds WHERE arguments to the queryset, returning a new queryset
|
||||
|
||||
@@ -139,6 +139,20 @@ class AssignmentClause(BaseClause):
|
||||
return self.field, self.context_id
|
||||
|
||||
|
||||
class TransactionClause(BaseClause):
|
||||
""" A single variable transaction statement """
|
||||
|
||||
def __unicode__(self):
|
||||
if self.field == 'exists':
|
||||
return u'EXISTS'
|
||||
if self.field == 'not_exists':
|
||||
return u'NOT EXISTS'
|
||||
return u'"{}" = %({})s'.format(self.field, self.context_id)
|
||||
|
||||
def insert_tuple(self):
|
||||
return self.field, self.context_id
|
||||
|
||||
|
||||
class ContainerUpdateClause(AssignmentClause):
|
||||
|
||||
def __init__(self, field, value, operation=None, previous=None, column=None):
|
||||
@@ -573,7 +587,8 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
consistency=None,
|
||||
where=None,
|
||||
ttl=None,
|
||||
timestamp=None):
|
||||
timestamp=None,
|
||||
transactions=None):
|
||||
super(AssignmentStatement, self).__init__(
|
||||
table,
|
||||
consistency=consistency,
|
||||
@@ -587,6 +602,11 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
for assignment in assignments or []:
|
||||
self.add_assignment_clause(assignment)
|
||||
|
||||
# Add transaction statements
|
||||
self.transactions = []
|
||||
for transaction in transactions or []:
|
||||
self.add_transaction_clause(transaction)
|
||||
|
||||
def update_context_id(self, i):
|
||||
super(AssignmentStatement, self).update_context_id(i)
|
||||
for assignment in self.assignments:
|
||||
@@ -605,6 +625,19 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.assignments.append(clause)
|
||||
|
||||
def add_transaction_clause(self, clause):
|
||||
"""
|
||||
Adds a transaction clause to this statement
|
||||
|
||||
:param clause: The clause that will be added to the transaction statement
|
||||
:type clause: TransactionClause
|
||||
"""
|
||||
if not isinstance(clause, TransactionClause):
|
||||
raise StatementException('only instances of AssignmentClause can be added to statements')
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.transactions.append(clause)
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
return len(self.assignments) == 0
|
||||
@@ -615,6 +648,9 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
clause.update_context(ctx)
|
||||
return ctx
|
||||
|
||||
def _transactions(self):
|
||||
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))
|
||||
|
||||
|
||||
class InsertStatement(AssignmentStatement):
|
||||
""" an cql insert select statement """
|
||||
@@ -660,6 +696,9 @@ class InsertStatement(AssignmentStatement):
|
||||
if self.timestamp:
|
||||
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
|
||||
|
||||
if len(self.transactions) > 0:
|
||||
qs += [self._transactions]
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
|
||||
@@ -686,6 +725,9 @@ class UpdateStatement(AssignmentStatement):
|
||||
if self.where_clauses:
|
||||
qs += [self._where]
|
||||
|
||||
if len(self.transactions) > 0:
|
||||
qs += [self._transactions]
|
||||
|
||||
return ' '.join(qs)
|
||||
|
||||
|
||||
|
||||
36
cqlengine/tests/model/test_transaction_statements.py
Normal file
36
cqlengine/tests/model/test_transaction_statements.py
Normal file
@@ -0,0 +1,36 @@
|
||||
__author__ = 'Tim Martin'
|
||||
from uuid import uuid4
|
||||
|
||||
from mock import patch
|
||||
from cqlengine.exceptions import ValidationError
|
||||
|
||||
from cqlengine.tests.base import BaseCassEngTestCase
|
||||
from cqlengine.models import Model
|
||||
from cqlengine import columns
|
||||
from cqlengine.management import sync_table, drop_table
|
||||
|
||||
|
||||
class TestUpdateModel(Model):
|
||||
__keyspace__ = 'test'
|
||||
partition = columns.UUID(primary_key=True, default=uuid4)
|
||||
cluster = columns.UUID(primary_key=True, default=uuid4)
|
||||
count = columns.Integer(required=False)
|
||||
text = columns.Text(required=False, index=True)
|
||||
|
||||
|
||||
class ModelUpdateTests(BaseCassEngTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ModelUpdateTests, cls).setUpClass()
|
||||
sync_table(TestUpdateModel)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super(ModelUpdateTests, cls).tearDownClass()
|
||||
drop_table(TestUpdateModel)
|
||||
|
||||
def test_transaction_insertion(self):
|
||||
m = TestUpdateModel(count=5, text='something').transaction(exists=True)
|
||||
m.save()
|
||||
x = 10
|
||||
Reference in New Issue
Block a user