diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 9613589e..bb34f66f 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -471,7 +471,7 @@ class MapDeleteClause(BaseDeleteClause): class BaseCQLStatement(UnicodeMixin): """ The base cql statement class """ - def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None): + def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, transactions=None): super(BaseCQLStatement, self).__init__() self.table = table self.consistency = consistency @@ -484,6 +484,10 @@ class BaseCQLStatement(UnicodeMixin): for clause in where or []: self.add_where_clause(clause) + self.transactions = [] + for transaction in transactions or []: + self.add_transaction_clause(transaction) + def add_where_clause(self, clause): """ adds a where clause to this statement @@ -506,6 +510,22 @@ class BaseCQLStatement(UnicodeMixin): clause.update_context(ctx) return ctx + def add_transaction_clause(self, clause): + """ + Adds a iff clause to this statement + + :param clause: The clause that will be added to the iff statement + :type clause: TransactionClause + """ + if not isinstance(clause, TransactionClause): + raise StatementException('only instances of AssignmentClause can be added to statements') + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.transactions.append(clause) + + def _get_transactions(self): + return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) + def get_context_size(self): return len(self.get_context()) @@ -616,11 +636,13 @@ class AssignmentStatement(BaseCQLStatement): consistency=None, where=None, ttl=None, - timestamp=None): + timestamp=None, + transactions=None): super(AssignmentStatement, self).__init__( table, consistency=consistency, where=where, + transactions=transactions ) self.ttl = ttl self.timestamp = timestamp @@ -722,12 +744,8 @@ class UpdateStatement(AssignmentStatement): consistency=consistency, where=where, ttl=ttl, - timestamp=timestamp) - - # Add iff statements - self.transactions = [] - for transaction in transactions or []: - self.add_transaction_clause(transaction) + timestamp=timestamp, + transactions=transactions) self.if_exists = if_exists @@ -759,28 +777,12 @@ class UpdateStatement(AssignmentStatement): return ' '.join(qs) - def add_transaction_clause(self, clause): - """ - Adds a iff clause to this statement - - :param clause: The clause that will be added to the iff statement - :type clause: TransactionClause - """ - if not isinstance(clause, TransactionClause): - raise StatementException('only instances of AssignmentClause can be added to statements') - clause.set_context_id(self.context_counter) - self.context_counter += clause.get_context_size() - self.transactions.append(clause) - def get_context(self): ctx = super(UpdateStatement, self).get_context() for clause in self.transactions or []: clause.update_context(ctx) return ctx - def _get_transactions(self): - return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) - def update_context_id(self, i): super(UpdateStatement, self).update_context_id(i) for transaction in self.transactions: @@ -796,7 +798,8 @@ class DeleteStatement(BaseCQLStatement): table, consistency=consistency, where=where, - timestamp=timestamp + timestamp=timestamp, + transactions=transactions ) self.fields = [] if isinstance(fields, six.string_types): @@ -804,10 +807,6 @@ class DeleteStatement(BaseCQLStatement): for field in fields or []: self.add_field(field) - self.transactions = [] - for transaction in transactions or []: - self.add_transaction_clause(transaction) - self.if_exists = if_exists def update_context_id(self, i): @@ -833,22 +832,6 @@ class DeleteStatement(BaseCQLStatement): self.context_counter += field.get_context_size() self.fields.append(field) - def add_transaction_clause(self, clause): - """ - Adds a iff clause to this statement - - :param clause: The clause that will be added to the iff statement - :type clause: TransactionClause - """ - if not isinstance(clause, TransactionClause): - raise StatementException('only instances of AssignmentClause can be added to statements') - clause.set_context_id(self.context_counter) - self.context_counter += clause.get_context_size() - self.transactions.append(clause) - - def _get_transactions(self): - return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions])) - def __unicode__(self): qs = ['DELETE'] if self.fields: