Initial attempt…

This commit is contained in:
timmartin19
2014-10-08 12:24:47 -04:00
parent 66aa65be6e
commit 09427f4d75
4 changed files with 154 additions and 3 deletions

View File

@@ -74,6 +74,27 @@ class QuerySetDescriptor(object):
raise NotImplementedError
class TransactionDescriptor(object):
"""
returns a query set descriptor
"""
def __get__(self, instance, model):
if instance:
def transaction_setter(transaction):
instance._transaction = transaction
return instance
return transaction_setter
qs = model.__queryset__(model)
def transaction_setter(transaction):
qs._transaction = transaction
return instance
return transaction_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class TTLDescriptor(object):
"""
returns a query set descriptor
@@ -239,6 +260,7 @@ class BaseModel(object):
objects = QuerySetDescriptor()
ttl = TTLDescriptor()
consistency = ConsistencyDescriptor()
transaction = TransactionDescriptor()
# custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor()
@@ -303,7 +325,7 @@ class BaseModel(object):
self._timestamp = None
for name, column in self._columns.items():
value = values.get(name, None)
value = values.get(name, None)
if value is not None or isinstance(column, columns.BaseContainerColumn):
value = column.to_python(value)
value_mngr = column.value_manager(self, column, value)
@@ -531,6 +553,10 @@ class BaseModel(object):
return cls.objects.filter(*args, **kwargs)
@classmethod
def transaction(cls, *args, **kwargs):
return cls.objects.transaction(*args, **kwargs)
@classmethod
def get(cls, *args, **kwargs):
return cls.objects.get(*args, **kwargs)

View File

@@ -13,7 +13,7 @@ from cqlengine.functions import Token, BaseQueryFunction, QueryValue, UnicodeMix
#http://www.datastax.com/docs/1.1/references/cql/index
from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator
from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause
from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause, TransactionClause
class QueryException(CQLEngineException): pass
@@ -206,6 +206,9 @@ class AbstractQuerySet(object):
#Where clause filters
self._where = []
# Transaction clause filters
self._transaction = []
#ordering arguments
self._order = []
@@ -407,6 +410,50 @@ class AbstractQuerySet(object):
else:
raise QueryException("Can't parse '{}'".format(arg))
def transaction(self, *args, **kwargs):
"""Adds IF statements to queryset"""
if len([x for x in kwargs.values() if x is None]):
raise CQLEngineException("None values on transaction are not allowed")
clone = copy.deepcopy(self)
for operator in args:
if not isinstance(operator, TransactionClause):
raise QueryException('{} is not a valid query operator'.format(operator))
clone._transaction.append(operator)
for col_name, val in kwargs.items():
try:
column = self.model._get_column(col_name)
except KeyError:
if col_name in ['exists', 'not_exists']:
pass
elif col_name == 'pk__token':
if not isinstance(val, Token):
raise QueryException("Virtual column 'pk__token' may only be compared to Token() values")
column = columns._PartitionKeysToken(self.model)
quote_field = False
else:
raise QueryException("Can't resolve column name: '{}'".format(col_name))
if isinstance(val, Token):
if col_name != 'pk__token':
raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")
partition_columns = column.partition_columns
if len(partition_columns) != len(val.value):
raise QueryException(
'Token() received {} arguments but model has {} partition keys'.format(
len(val.value), len(partition_columns)))
val.set_columns(partition_columns)
if isinstance(val, BaseQueryFunction):
query_val = val
else:
query_val = column.to_database(val)
clone._transaction.append(TransactionClause(col_name, query_val))
return clone
def filter(self, *args, **kwargs):
"""
Adds WHERE arguments to the queryset, returning a new queryset

View File

@@ -139,6 +139,20 @@ class AssignmentClause(BaseClause):
return self.field, self.context_id
class TransactionClause(BaseClause):
""" A single variable transaction statement """
def __unicode__(self):
if self.field == 'exists':
return u'EXISTS'
if self.field == 'not_exists':
return u'NOT EXISTS'
return u'"{}" = %({})s'.format(self.field, self.context_id)
def insert_tuple(self):
return self.field, self.context_id
class ContainerUpdateClause(AssignmentClause):
def __init__(self, field, value, operation=None, previous=None, column=None):
@@ -573,7 +587,8 @@ class AssignmentStatement(BaseCQLStatement):
consistency=None,
where=None,
ttl=None,
timestamp=None):
timestamp=None,
transactions=None):
super(AssignmentStatement, self).__init__(
table,
consistency=consistency,
@@ -587,6 +602,11 @@ class AssignmentStatement(BaseCQLStatement):
for assignment in assignments or []:
self.add_assignment_clause(assignment)
# Add transaction statements
self.transactions = []
for transaction in transactions or []:
self.add_transaction_clause(transaction)
def update_context_id(self, i):
super(AssignmentStatement, self).update_context_id(i)
for assignment in self.assignments:
@@ -605,6 +625,19 @@ class AssignmentStatement(BaseCQLStatement):
self.context_counter += clause.get_context_size()
self.assignments.append(clause)
def add_transaction_clause(self, clause):
"""
Adds a transaction clause to this statement
:param clause: The clause that will be added to the transaction statement
:type clause: TransactionClause
"""
if not isinstance(clause, TransactionClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.transactions.append(clause)
@property
def is_empty(self):
return len(self.assignments) == 0
@@ -615,6 +648,9 @@ class AssignmentStatement(BaseCQLStatement):
clause.update_context(ctx)
return ctx
def _transactions(self):
return 'IF {}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))
class InsertStatement(AssignmentStatement):
""" an cql insert select statement """
@@ -660,6 +696,9 @@ class InsertStatement(AssignmentStatement):
if self.timestamp:
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
if len(self.transactions) > 0:
qs += [self._transactions]
return ' '.join(qs)
@@ -686,6 +725,9 @@ class UpdateStatement(AssignmentStatement):
if self.where_clauses:
qs += [self._where]
if len(self.transactions) > 0:
qs += [self._transactions]
return ' '.join(qs)

View File

@@ -0,0 +1,36 @@
__author__ = 'Tim Martin'
from uuid import uuid4
from mock import patch
from cqlengine.exceptions import ValidationError
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from cqlengine import columns
from cqlengine.management import sync_table, drop_table
class TestUpdateModel(Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.UUID(primary_key=True, default=uuid4)
count = columns.Integer(required=False)
text = columns.Text(required=False, index=True)
class ModelUpdateTests(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(ModelUpdateTests, cls).setUpClass()
sync_table(TestUpdateModel)
@classmethod
def tearDownClass(cls):
super(ModelUpdateTests, cls).tearDownClass()
drop_table(TestUpdateModel)
def test_transaction_insertion(self):
m = TestUpdateModel(count=5, text='something').transaction(exists=True)
m.save()
x = 10