cqle: rework statement part. key building
PYTHON-535
This commit is contained in:
		@@ -26,7 +26,7 @@ from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterTh
 | 
				
			|||||||
                                           GreaterThanOrEqualOperator, LessThanOperator,
 | 
					                                           GreaterThanOrEqualOperator, LessThanOperator,
 | 
				
			||||||
                                           LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
 | 
					                                           LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
 | 
				
			||||||
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
 | 
					from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
 | 
				
			||||||
                                            UpdateStatement, AssignmentClause, InsertStatement,
 | 
					                                            UpdateStatement, InsertStatement,
 | 
				
			||||||
                                            BaseCQLStatement, MapDeleteClause, ConditionalClause)
 | 
					                                            BaseCQLStatement, MapDeleteClause, ConditionalClause)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,6 +13,7 @@
 | 
				
			|||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from datetime import datetime, timedelta
 | 
					from datetime import datetime, timedelta
 | 
				
			||||||
 | 
					from itertools import ifilter
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import six
 | 
					import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -481,8 +482,6 @@ class MapDeleteClause(BaseDeleteClause):
 | 
				
			|||||||
class BaseCQLStatement(UnicodeMixin):
 | 
					class BaseCQLStatement(UnicodeMixin):
 | 
				
			||||||
    """ The base cql statement class """
 | 
					    """ The base cql statement class """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    parition_key_values = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
 | 
					    def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
 | 
				
			||||||
        super(BaseCQLStatement, self).__init__()
 | 
					        super(BaseCQLStatement, self).__init__()
 | 
				
			||||||
        self.table = table
 | 
					        self.table = table
 | 
				
			||||||
@@ -500,24 +499,19 @@ class BaseCQLStatement(UnicodeMixin):
 | 
				
			|||||||
        for conditional in conditionals or []:
 | 
					        for conditional in conditionals or []:
 | 
				
			||||||
            self.add_conditional_clause(conditional)
 | 
					            self.add_conditional_clause(conditional)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _update_partition_key(self, column, value):
 | 
					    def _update_part_key_values(self, field_index_map, clauses, parts):
 | 
				
			||||||
        if column.partition_key:
 | 
					        for clause in ifilter(lambda c: c.field in field_index_map, clauses):
 | 
				
			||||||
            if self.parition_key_values:
 | 
					            parts[field_index_map[clause.field]] = clause.value
 | 
				
			||||||
                self.parition_key_values.append(value)
 | 
					
 | 
				
			||||||
            else:
 | 
					    def partition_key_values(self, field_index_map):
 | 
				
			||||||
                self.parition_key_values = [value]
 | 
					        parts = [None] * len(field_index_map)
 | 
				
			||||||
            # assert part keys are added in order
 | 
					        self._update_part_key_values(field_index_map, self.where_clauses, parts)
 | 
				
			||||||
            # this is an optimization based on the way statements are constructed in
 | 
					        return parts
 | 
				
			||||||
            # cqlengine.query (columns always iterated in order). If that assumption
 | 
					 | 
				
			||||||
            # goes away we can preallocate the key values list and insert using
 | 
					 | 
				
			||||||
            # self.partition_key_values
 | 
					 | 
				
			||||||
            assert column._partition_key_index == len(self.parition_key_values) - 1
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_where(self, column, operator, value, quote_field=True):
 | 
					    def add_where(self, column, operator, value, quote_field=True):
 | 
				
			||||||
        value = column.to_database(value)
 | 
					        value = column.to_database(value)
 | 
				
			||||||
        clause = WhereClause(column.db_field_name, operator, value, quote_field)
 | 
					        clause = WhereClause(column.db_field_name, operator, value, quote_field)
 | 
				
			||||||
        self._add_where_clause(clause)
 | 
					        self._add_where_clause(clause)
 | 
				
			||||||
        self._update_partition_key(column, value)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _add_where_clause(self, clause):
 | 
					    def _add_where_clause(self, clause):
 | 
				
			||||||
        clause.set_context_id(self.context_counter)
 | 
					        clause.set_context_id(self.context_counter)
 | 
				
			||||||
@@ -682,11 +676,15 @@ class AssignmentStatement(BaseCQLStatement):
 | 
				
			|||||||
            assignment.set_context_id(self.context_counter)
 | 
					            assignment.set_context_id(self.context_counter)
 | 
				
			||||||
            self.context_counter += assignment.get_context_size()
 | 
					            self.context_counter += assignment.get_context_size()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def partition_key_values(self, field_index_map):
 | 
				
			||||||
 | 
					        parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
 | 
				
			||||||
 | 
					        self._update_part_key_values(field_index_map, self.assignments, parts)
 | 
				
			||||||
 | 
					        return parts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_assignment(self, column, value):
 | 
					    def add_assignment(self, column, value):
 | 
				
			||||||
        value = column.to_database(value)
 | 
					        value = column.to_database(value)
 | 
				
			||||||
        clause = AssignmentClause(column.db_field_name, value)
 | 
					        clause = AssignmentClause(column.db_field_name, value)
 | 
				
			||||||
        self._add_assignment_clause(clause)
 | 
					        self._add_assignment_clause(clause)
 | 
				
			||||||
        self._update_partition_key(column, value)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _add_assignment_clause(self, clause):
 | 
					    def _add_assignment_clause(self, clause):
 | 
				
			||||||
        clause.set_context_id(self.context_counter)
 | 
					        clause.set_context_id(self.context_counter)
 | 
				
			||||||
@@ -724,9 +722,6 @@ class InsertStatement(AssignmentStatement):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.if_not_exists = if_not_exists
 | 
					        self.if_not_exists = if_not_exists
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_where_clause(self, clause):
 | 
					 | 
				
			||||||
        raise StatementException("Cannot add where clauses to insert statements")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __unicode__(self):
 | 
					    def __unicode__(self):
 | 
				
			||||||
        qs = ['INSERT INTO {0}'.format(self.table)]
 | 
					        qs = ['INSERT INTO {0}'.format(self.table)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,18 +16,13 @@ try:
 | 
				
			|||||||
except ImportError:
 | 
					except ImportError:
 | 
				
			||||||
    import unittest  # noqa
 | 
					    import unittest  # noqa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from cassandra.cqlengine.columns import Column
 | 
					 | 
				
			||||||
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import six
 | 
					import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InsertStatementTests(unittest.TestCase):
 | 
					from cassandra.cqlengine.columns import Column
 | 
				
			||||||
 | 
					from cassandra.cqlengine.statements import InsertStatement
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_where_clause_failure(self):
 | 
					
 | 
				
			||||||
        """ tests that where clauses cannot be added to Insert statements """
 | 
					class InsertStatementTests(unittest.TestCase):
 | 
				
			||||||
        ist = InsertStatement('table', None)
 | 
					 | 
				
			||||||
        with self.assertRaises(StatementException):
 | 
					 | 
				
			||||||
            ist.add_where_clause('s')
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_statement(self):
 | 
					    def test_statement(self):
 | 
				
			||||||
        ist = InsertStatement('table', None)
 | 
					        ist = InsertStatement('table', None)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user