cqle: refactor transactions to base CQL statement

This commit is contained in:
Adam Holmberg
2016-03-17 16:10:15 -05:00
parent 0f3132d31a
commit c4db355311

View File

@@ -471,7 +471,7 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None):
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, transactions=None):
super(BaseCQLStatement, self).__init__()
self.table = table
self.consistency = consistency
@@ -484,6 +484,10 @@ class BaseCQLStatement(UnicodeMixin):
for clause in where or []:
self.add_where_clause(clause)
self.transactions = []
for transaction in transactions or []:
self.add_transaction_clause(transaction)
def add_where_clause(self, clause):
"""
adds a where clause to this statement
@@ -506,6 +510,22 @@ class BaseCQLStatement(UnicodeMixin):
clause.update_context(ctx)
return ctx
def add_transaction_clause(self, clause):
"""
Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement
:type clause: TransactionClause
"""
if not isinstance(clause, TransactionClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.transactions.append(clause)
def _get_transactions(self):
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
def get_context_size(self):
return len(self.get_context())
@@ -616,11 +636,13 @@ class AssignmentStatement(BaseCQLStatement):
consistency=None,
where=None,
ttl=None,
timestamp=None):
timestamp=None,
transactions=None):
super(AssignmentStatement, self).__init__(
table,
consistency=consistency,
where=where,
transactions=transactions
)
self.ttl = ttl
self.timestamp = timestamp
@@ -722,12 +744,8 @@ class UpdateStatement(AssignmentStatement):
consistency=consistency,
where=where,
ttl=ttl,
timestamp=timestamp)
# Add iff statements
self.transactions = []
for transaction in transactions or []:
self.add_transaction_clause(transaction)
timestamp=timestamp,
transactions=transactions)
self.if_exists = if_exists
@@ -759,28 +777,12 @@ class UpdateStatement(AssignmentStatement):
return ' '.join(qs)
def add_transaction_clause(self, clause):
"""
Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement
:type clause: TransactionClause
"""
if not isinstance(clause, TransactionClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.transactions.append(clause)
def get_context(self):
ctx = super(UpdateStatement, self).get_context()
for clause in self.transactions or []:
clause.update_context(ctx)
return ctx
def _get_transactions(self):
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
def update_context_id(self, i):
super(UpdateStatement, self).update_context_id(i)
for transaction in self.transactions:
@@ -796,7 +798,8 @@ class DeleteStatement(BaseCQLStatement):
table,
consistency=consistency,
where=where,
timestamp=timestamp
timestamp=timestamp,
transactions=transactions
)
self.fields = []
if isinstance(fields, six.string_types):
@@ -804,10 +807,6 @@ class DeleteStatement(BaseCQLStatement):
for field in fields or []:
self.add_field(field)
self.transactions = []
for transaction in transactions or []:
self.add_transaction_clause(transaction)
self.if_exists = if_exists
def update_context_id(self, i):
@@ -833,22 +832,6 @@ class DeleteStatement(BaseCQLStatement):
self.context_counter += field.get_context_size()
self.fields.append(field)
def add_transaction_clause(self, clause):
"""
Adds a iff clause to this statement
:param clause: The clause that will be added to the iff statement
:type clause: TransactionClause
"""
if not isinstance(clause, TransactionClause):
raise StatementException('only instances of AssignmentClause can be added to statements')
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.transactions.append(clause)
def _get_transactions(self):
return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.transactions]))
def __unicode__(self):
qs = ['DELETE']
if self.fields: