cqle: rework statement part. key building

PYTHON-535
This commit is contained in:
Adam Holmberg
2016-04-01 14:39:42 -05:00
parent 1ffd4dd0bb
commit f239392747
3 changed files with 19 additions and 29 deletions

View File

@@ -26,7 +26,7 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh
GreaterThanOrEqualOperator, LessThanOperator, GreaterThanOrEqualOperator, LessThanOperator,
LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator) LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement, from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
UpdateStatement, AssignmentClause, InsertStatement, UpdateStatement, InsertStatement,
BaseCQLStatement, MapDeleteClause, ConditionalClause) BaseCQLStatement, MapDeleteClause, ConditionalClause)

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from datetime import datetime, timedelta from datetime import datetime, timedelta
from itertools import ifilter
import time import time
import six import six
@@ -481,8 +482,6 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin): class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """ """ The base cql statement class """
parition_key_values = None
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None): def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
super(BaseCQLStatement, self).__init__() super(BaseCQLStatement, self).__init__()
self.table = table self.table = table
@@ -500,24 +499,19 @@ class BaseCQLStatement(UnicodeMixin):
for conditional in conditionals or []: for conditional in conditionals or []:
self.add_conditional_clause(conditional) self.add_conditional_clause(conditional)
def _update_partition_key(self, column, value): def _update_part_key_values(self, field_index_map, clauses, parts):
if column.partition_key: for clause in ifilter(lambda c: c.field in field_index_map, clauses):
if self.parition_key_values: parts[field_index_map[clause.field]] = clause.value
self.parition_key_values.append(value)
else: def partition_key_values(self, field_index_map):
self.parition_key_values = [value] parts = [None] * len(field_index_map)
# assert part keys are added in order self._update_part_key_values(field_index_map, self.where_clauses, parts)
# this is an optimization based on the way statements are constructed in return parts
# cqlengine.query (columns always iterated in order). If that assumption
# goes away we can preallocate the key values list and insert using
# self.partition_key_values
assert column._partition_key_index == len(self.parition_key_values) - 1
def add_where(self, column, operator, value, quote_field=True): def add_where(self, column, operator, value, quote_field=True):
value = column.to_database(value) value = column.to_database(value)
clause = WhereClause(column.db_field_name, operator, value, quote_field) clause = WhereClause(column.db_field_name, operator, value, quote_field)
self._add_where_clause(clause) self._add_where_clause(clause)
self._update_partition_key(column, value)
def _add_where_clause(self, clause): def _add_where_clause(self, clause):
clause.set_context_id(self.context_counter) clause.set_context_id(self.context_counter)
@@ -682,11 +676,15 @@ class AssignmentStatement(BaseCQLStatement):
assignment.set_context_id(self.context_counter) assignment.set_context_id(self.context_counter)
self.context_counter += assignment.get_context_size() self.context_counter += assignment.get_context_size()
def partition_key_values(self, field_index_map):
parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
self._update_part_key_values(field_index_map, self.assignments, parts)
return parts
def add_assignment(self, column, value): def add_assignment(self, column, value):
value = column.to_database(value) value = column.to_database(value)
clause = AssignmentClause(column.db_field_name, value) clause = AssignmentClause(column.db_field_name, value)
self._add_assignment_clause(clause) self._add_assignment_clause(clause)
self._update_partition_key(column, value)
def _add_assignment_clause(self, clause): def _add_assignment_clause(self, clause):
clause.set_context_id(self.context_counter) clause.set_context_id(self.context_counter)
@@ -724,9 +722,6 @@ class InsertStatement(AssignmentStatement):
self.if_not_exists = if_not_exists self.if_not_exists = if_not_exists
def add_where_clause(self, clause):
raise StatementException("Cannot add where clauses to insert statements")
def __unicode__(self): def __unicode__(self):
qs = ['INSERT INTO {0}'.format(self.table)] qs = ['INSERT INTO {0}'.format(self.table)]

View File

@@ -16,18 +16,13 @@ try:
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
import six import six
class InsertStatementTests(unittest.TestCase): from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement
def test_where_clause_failure(self):
""" tests that where clauses cannot be added to Insert statements """ class InsertStatementTests(unittest.TestCase):
ist = InsertStatement('table', None)
with self.assertRaises(StatementException):
ist.add_where_clause('s')
def test_statement(self): def test_statement(self):
ist = InsertStatement('table', None) ist = InsertStatement('table', None)