cqle: rework statement part. key building

PYTHON-535
This commit is contained in:
Adam Holmberg
2016-04-01 14:39:42 -05:00
parent 1ffd4dd0bb
commit f239392747
3 changed files with 19 additions and 29 deletions

View File

@@ -26,7 +26,7 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh
GreaterThanOrEqualOperator, LessThanOperator,
LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
UpdateStatement, AssignmentClause, InsertStatement,
UpdateStatement, InsertStatement,
BaseCQLStatement, MapDeleteClause, ConditionalClause)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from datetime import datetime, timedelta
from itertools import ifilter
import time
import six
@@ -481,8 +482,6 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """
parition_key_values = None
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
super(BaseCQLStatement, self).__init__()
self.table = table
@@ -500,24 +499,19 @@ class BaseCQLStatement(UnicodeMixin):
for conditional in conditionals or []:
self.add_conditional_clause(conditional)
def _update_partition_key(self, column, value):
if column.partition_key:
if self.parition_key_values:
self.parition_key_values.append(value)
else:
self.parition_key_values = [value]
# assert part keys are added in order
# this is an optimization based on the way statements are constructed in
# cqlengine.query (columns always iterated in order). If that assumption
# goes away we can preallocate the key values list and insert using
# self.partition_key_values
assert column._partition_key_index == len(self.parition_key_values) - 1
def _update_part_key_values(self, field_index_map, clauses, parts):
for clause in ifilter(lambda c: c.field in field_index_map, clauses):
parts[field_index_map[clause.field]] = clause.value
def partition_key_values(self, field_index_map):
parts = [None] * len(field_index_map)
self._update_part_key_values(field_index_map, self.where_clauses, parts)
return parts
def add_where(self, column, operator, value, quote_field=True):
value = column.to_database(value)
clause = WhereClause(column.db_field_name, operator, value, quote_field)
self._add_where_clause(clause)
self._update_partition_key(column, value)
def _add_where_clause(self, clause):
clause.set_context_id(self.context_counter)
@@ -682,11 +676,15 @@ class AssignmentStatement(BaseCQLStatement):
assignment.set_context_id(self.context_counter)
self.context_counter += assignment.get_context_size()
def partition_key_values(self, field_index_map):
parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
self._update_part_key_values(field_index_map, self.assignments, parts)
return parts
def add_assignment(self, column, value):
value = column.to_database(value)
clause = AssignmentClause(column.db_field_name, value)
self._add_assignment_clause(clause)
self._update_partition_key(column, value)
def _add_assignment_clause(self, clause):
clause.set_context_id(self.context_counter)
@@ -724,9 +722,6 @@ class InsertStatement(AssignmentStatement):
self.if_not_exists = if_not_exists
def add_where_clause(self, clause):
raise StatementException("Cannot add where clauses to insert statements")
def __unicode__(self):
qs = ['INSERT INTO {0}'.format(self.table)]

View File

@@ -16,18 +16,13 @@ try:
except ImportError:
import unittest # noqa
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
import six
class InsertStatementTests(unittest.TestCase):
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement
def test_where_clause_failure(self):
""" tests that where clauses cannot be added to Insert statements """
ist = InsertStatement('table', None)
with self.assertRaises(StatementException):
ist.add_where_clause('s')
class InsertStatementTests(unittest.TestCase):
def test_statement(self):
ist = InsertStatement('table', None)