cqle: rework statement part. key building
PYTHON-535
This commit is contained in:
@@ -26,7 +26,7 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh
|
|||||||
GreaterThanOrEqualOperator, LessThanOperator,
|
GreaterThanOrEqualOperator, LessThanOperator,
|
||||||
LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
|
LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
|
||||||
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
|
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
|
||||||
UpdateStatement, AssignmentClause, InsertStatement,
|
UpdateStatement, InsertStatement,
|
||||||
BaseCQLStatement, MapDeleteClause, ConditionalClause)
|
BaseCQLStatement, MapDeleteClause, ConditionalClause)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from itertools import ifilter
|
||||||
import time
|
import time
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@@ -481,8 +482,6 @@ class MapDeleteClause(BaseDeleteClause):
|
|||||||
class BaseCQLStatement(UnicodeMixin):
|
class BaseCQLStatement(UnicodeMixin):
|
||||||
""" The base cql statement class """
|
""" The base cql statement class """
|
||||||
|
|
||||||
parition_key_values = None
|
|
||||||
|
|
||||||
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
|
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
|
||||||
super(BaseCQLStatement, self).__init__()
|
super(BaseCQLStatement, self).__init__()
|
||||||
self.table = table
|
self.table = table
|
||||||
@@ -500,24 +499,19 @@ class BaseCQLStatement(UnicodeMixin):
|
|||||||
for conditional in conditionals or []:
|
for conditional in conditionals or []:
|
||||||
self.add_conditional_clause(conditional)
|
self.add_conditional_clause(conditional)
|
||||||
|
|
||||||
def _update_partition_key(self, column, value):
|
def _update_part_key_values(self, field_index_map, clauses, parts):
|
||||||
if column.partition_key:
|
for clause in ifilter(lambda c: c.field in field_index_map, clauses):
|
||||||
if self.parition_key_values:
|
parts[field_index_map[clause.field]] = clause.value
|
||||||
self.parition_key_values.append(value)
|
|
||||||
else:
|
def partition_key_values(self, field_index_map):
|
||||||
self.parition_key_values = [value]
|
parts = [None] * len(field_index_map)
|
||||||
# assert part keys are added in order
|
self._update_part_key_values(field_index_map, self.where_clauses, parts)
|
||||||
# this is an optimization based on the way statements are constructed in
|
return parts
|
||||||
# cqlengine.query (columns always iterated in order). If that assumption
|
|
||||||
# goes away we can preallocate the key values list and insert using
|
|
||||||
# self.partition_key_values
|
|
||||||
assert column._partition_key_index == len(self.parition_key_values) - 1
|
|
||||||
|
|
||||||
def add_where(self, column, operator, value, quote_field=True):
|
def add_where(self, column, operator, value, quote_field=True):
|
||||||
value = column.to_database(value)
|
value = column.to_database(value)
|
||||||
clause = WhereClause(column.db_field_name, operator, value, quote_field)
|
clause = WhereClause(column.db_field_name, operator, value, quote_field)
|
||||||
self._add_where_clause(clause)
|
self._add_where_clause(clause)
|
||||||
self._update_partition_key(column, value)
|
|
||||||
|
|
||||||
def _add_where_clause(self, clause):
|
def _add_where_clause(self, clause):
|
||||||
clause.set_context_id(self.context_counter)
|
clause.set_context_id(self.context_counter)
|
||||||
@@ -682,11 +676,15 @@ class AssignmentStatement(BaseCQLStatement):
|
|||||||
assignment.set_context_id(self.context_counter)
|
assignment.set_context_id(self.context_counter)
|
||||||
self.context_counter += assignment.get_context_size()
|
self.context_counter += assignment.get_context_size()
|
||||||
|
|
||||||
|
def partition_key_values(self, field_index_map):
|
||||||
|
parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
|
||||||
|
self._update_part_key_values(field_index_map, self.assignments, parts)
|
||||||
|
return parts
|
||||||
|
|
||||||
def add_assignment(self, column, value):
|
def add_assignment(self, column, value):
|
||||||
value = column.to_database(value)
|
value = column.to_database(value)
|
||||||
clause = AssignmentClause(column.db_field_name, value)
|
clause = AssignmentClause(column.db_field_name, value)
|
||||||
self._add_assignment_clause(clause)
|
self._add_assignment_clause(clause)
|
||||||
self._update_partition_key(column, value)
|
|
||||||
|
|
||||||
def _add_assignment_clause(self, clause):
|
def _add_assignment_clause(self, clause):
|
||||||
clause.set_context_id(self.context_counter)
|
clause.set_context_id(self.context_counter)
|
||||||
@@ -724,9 +722,6 @@ class InsertStatement(AssignmentStatement):
|
|||||||
|
|
||||||
self.if_not_exists = if_not_exists
|
self.if_not_exists = if_not_exists
|
||||||
|
|
||||||
def add_where_clause(self, clause):
|
|
||||||
raise StatementException("Cannot add where clauses to insert statements")
|
|
||||||
|
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
qs = ['INSERT INTO {0}'.format(self.table)]
|
qs = ['INSERT INTO {0}'.format(self.table)]
|
||||||
|
|
||||||
|
|||||||
@@ -16,18 +16,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
from cassandra.cqlengine.columns import Column
|
|
||||||
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
class InsertStatementTests(unittest.TestCase):
|
from cassandra.cqlengine.columns import Column
|
||||||
|
from cassandra.cqlengine.statements import InsertStatement
|
||||||
|
|
||||||
def test_where_clause_failure(self):
|
|
||||||
""" tests that where clauses cannot be added to Insert statements """
|
class InsertStatementTests(unittest.TestCase):
|
||||||
ist = InsertStatement('table', None)
|
|
||||||
with self.assertRaises(StatementException):
|
|
||||||
ist.add_where_clause('s')
|
|
||||||
|
|
||||||
def test_statement(self):
|
def test_statement(self):
|
||||||
ist = InsertStatement('table', None)
|
ist = InsertStatement('table', None)
|
||||||
|
|||||||
Reference in New Issue
Block a user