cqle: rework statement part. key building
PYTHON-535
This commit is contained in:
@@ -26,7 +26,7 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh
|
||||
GreaterThanOrEqualOperator, LessThanOperator,
|
||||
LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
|
||||
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
|
||||
UpdateStatement, AssignmentClause, InsertStatement,
|
||||
UpdateStatement, InsertStatement,
|
||||
BaseCQLStatement, MapDeleteClause, ConditionalClause)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import ifilter
|
||||
import time
|
||||
import six
|
||||
|
||||
@@ -481,8 +482,6 @@ class MapDeleteClause(BaseDeleteClause):
|
||||
class BaseCQLStatement(UnicodeMixin):
|
||||
""" The base cql statement class """
|
||||
|
||||
parition_key_values = None
|
||||
|
||||
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
|
||||
super(BaseCQLStatement, self).__init__()
|
||||
self.table = table
|
||||
@@ -500,24 +499,19 @@ class BaseCQLStatement(UnicodeMixin):
|
||||
for conditional in conditionals or []:
|
||||
self.add_conditional_clause(conditional)
|
||||
|
||||
def _update_partition_key(self, column, value):
|
||||
if column.partition_key:
|
||||
if self.parition_key_values:
|
||||
self.parition_key_values.append(value)
|
||||
else:
|
||||
self.parition_key_values = [value]
|
||||
# assert part keys are added in order
|
||||
# this is an optimization based on the way statements are constructed in
|
||||
# cqlengine.query (columns always iterated in order). If that assumption
|
||||
# goes away we can preallocate the key values list and insert using
|
||||
# self.partition_key_values
|
||||
assert column._partition_key_index == len(self.parition_key_values) - 1
|
||||
def _update_part_key_values(self, field_index_map, clauses, parts):
|
||||
for clause in ifilter(lambda c: c.field in field_index_map, clauses):
|
||||
parts[field_index_map[clause.field]] = clause.value
|
||||
|
||||
def partition_key_values(self, field_index_map):
|
||||
parts = [None] * len(field_index_map)
|
||||
self._update_part_key_values(field_index_map, self.where_clauses, parts)
|
||||
return parts
|
||||
|
||||
def add_where(self, column, operator, value, quote_field=True):
|
||||
value = column.to_database(value)
|
||||
clause = WhereClause(column.db_field_name, operator, value, quote_field)
|
||||
self._add_where_clause(clause)
|
||||
self._update_partition_key(column, value)
|
||||
|
||||
def _add_where_clause(self, clause):
|
||||
clause.set_context_id(self.context_counter)
|
||||
@@ -682,11 +676,15 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
assignment.set_context_id(self.context_counter)
|
||||
self.context_counter += assignment.get_context_size()
|
||||
|
||||
def partition_key_values(self, field_index_map):
|
||||
parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
|
||||
self._update_part_key_values(field_index_map, self.assignments, parts)
|
||||
return parts
|
||||
|
||||
def add_assignment(self, column, value):
|
||||
value = column.to_database(value)
|
||||
clause = AssignmentClause(column.db_field_name, value)
|
||||
self._add_assignment_clause(clause)
|
||||
self._update_partition_key(column, value)
|
||||
|
||||
def _add_assignment_clause(self, clause):
|
||||
clause.set_context_id(self.context_counter)
|
||||
@@ -724,9 +722,6 @@ class InsertStatement(AssignmentStatement):
|
||||
|
||||
self.if_not_exists = if_not_exists
|
||||
|
||||
def add_where_clause(self, clause):
|
||||
raise StatementException("Cannot add where clauses to insert statements")
|
||||
|
||||
def __unicode__(self):
|
||||
qs = ['INSERT INTO {0}'.format(self.table)]
|
||||
|
||||
|
||||
@@ -16,18 +16,13 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
|
||||
|
||||
import six
|
||||
|
||||
class InsertStatementTests(unittest.TestCase):
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import InsertStatement
|
||||
|
||||
def test_where_clause_failure(self):
|
||||
""" tests that where clauses cannot be added to Insert statements """
|
||||
ist = InsertStatement('table', None)
|
||||
with self.assertRaises(StatementException):
|
||||
ist.add_where_clause('s')
|
||||
|
||||
class InsertStatementTests(unittest.TestCase):
|
||||
|
||||
def test_statement(self):
|
||||
ist = InsertStatement('table', None)
|
||||
|
||||
Reference in New Issue
Block a user