diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 613fd112..7ac8f362 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -26,7 +26,7 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh GreaterThanOrEqualOperator, LessThanOperator, LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator) from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement, - UpdateStatement, AssignmentClause, InsertStatement, + UpdateStatement, InsertStatement, BaseCQLStatement, MapDeleteClause, ConditionalClause) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 2156751b..1df7ead1 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -13,6 +13,7 @@ # limitations under the License. from datetime import datetime, timedelta +from itertools import ifilter import time import six @@ -481,8 +482,6 @@ class MapDeleteClause(BaseDeleteClause): class BaseCQLStatement(UnicodeMixin): """ The base cql statement class """ - parition_key_values = None - def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None): super(BaseCQLStatement, self).__init__() self.table = table @@ -500,24 +499,19 @@ class BaseCQLStatement(UnicodeMixin): for conditional in conditionals or []: self.add_conditional_clause(conditional) - def _update_partition_key(self, column, value): - if column.partition_key: - if self.parition_key_values: - self.parition_key_values.append(value) - else: - self.parition_key_values = [value] - # assert part keys are added in order - # this is an optimization based on the way statements are constructed in - # cqlengine.query (columns always iterated in order). If that assumption - # goes away we can preallocate the key values list and insert using - # self.partition_key_values - assert column._partition_key_index == len(self.parition_key_values) - 1 + def _update_part_key_values(self, field_index_map, clauses, parts): + for clause in ifilter(lambda c: c.field in field_index_map, clauses): + parts[field_index_map[clause.field]] = clause.value + + def partition_key_values(self, field_index_map): + parts = [None] * len(field_index_map) + self._update_part_key_values(field_index_map, self.where_clauses, parts) + return parts def add_where(self, column, operator, value, quote_field=True): value = column.to_database(value) clause = WhereClause(column.db_field_name, operator, value, quote_field) self._add_where_clause(clause) - self._update_partition_key(column, value) def _add_where_clause(self, clause): clause.set_context_id(self.context_counter) @@ -682,11 +676,15 @@ class AssignmentStatement(BaseCQLStatement): assignment.set_context_id(self.context_counter) self.context_counter += assignment.get_context_size() + def partition_key_values(self, field_index_map): + parts = super(AssignmentStatement, self).partition_key_values(field_index_map) + self._update_part_key_values(field_index_map, self.assignments, parts) + return parts + def add_assignment(self, column, value): value = column.to_database(value) clause = AssignmentClause(column.db_field_name, value) self._add_assignment_clause(clause) - self._update_partition_key(column, value) def _add_assignment_clause(self, clause): clause.set_context_id(self.context_counter) @@ -724,9 +722,6 @@ class InsertStatement(AssignmentStatement): self.if_not_exists = if_not_exists - def add_where_clause(self, clause): - raise StatementException("Cannot add where clauses to insert statements") - def __unicode__(self): qs = ['INSERT INTO {0}'.format(self.table)] diff --git a/tests/integration/cqlengine/statements/test_insert_statement.py b/tests/integration/cqlengine/statements/test_insert_statement.py index df2aa0f5..7cde948a 100644 --- a/tests/integration/cqlengine/statements/test_insert_statement.py +++ b/tests/integration/cqlengine/statements/test_insert_statement.py @@ -16,18 +16,13 @@ try: except ImportError: import unittest # noqa -from cassandra.cqlengine.columns import Column -from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause - import six -class InsertStatementTests(unittest.TestCase): +from cassandra.cqlengine.columns import Column +from cassandra.cqlengine.statements import InsertStatement - def test_where_clause_failure(self): - """ tests that where clauses cannot be added to Insert statements """ - ist = InsertStatement('table', None) - with self.assertRaises(StatementException): - ist.add_where_clause('s') + +class InsertStatementTests(unittest.TestCase): def test_statement(self): ist = InsertStatement('table', None)