cqle: make statements collect primary key values

initially for DML path; still need to address queryset

PYTHON-535
This commit is contained in:
Adam Holmberg
2016-04-01 10:14:48 -05:00
parent f03e6381e7
commit 1ffd4dd0bb
8 changed files with 83 additions and 114 deletions

View File

@@ -1239,11 +1239,7 @@ class DMLQuery(object):
if deleted_fields: if deleted_fields:
for name, col in self.model._primary_keys.items(): for name, col in self.model._primary_keys.items():
ds.add_where_clause(WhereClause( ds.add_where(col, EqualsOperator(), getattr(self.instance, name))
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
self._execute(ds) self._execute(ds)
def update(self): def update(self):
@@ -1285,11 +1281,7 @@ class DMLQuery(object):
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error # only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
if (null_clustering_key or static_changed_only) and (not col.partition_key): if (null_clustering_key or static_changed_only) and (not col.partition_key):
continue continue
statement.add_where_clause(WhereClause( statement.add_where(col, EqualsOperator(), getattr(self.instance, name))
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
self._execute(statement) self._execute(statement)
if not null_clustering_key: if not null_clustering_key:
@@ -1324,10 +1316,7 @@ class DMLQuery(object):
if self.instance._values[name].changed: if self.instance._values[name].changed:
nulled_fields.add(col.db_field_name) nulled_fields.add(col.db_field_name)
continue continue
insert.add_assignment_clause(AssignmentClause( insert.add_assignment(col, getattr(self.instance, name, None))
col.db_field_name,
col.to_database(getattr(self.instance, name, None))
))
# skip query execution if it's empty # skip query execution if it's empty
# caused by pointless update queries # caused by pointless update queries
@@ -1344,12 +1333,8 @@ class DMLQuery(object):
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.model._primary_keys.items(): for name, col in self.model._primary_keys.items():
if (not col.partition_key) and (getattr(self.instance, name) is None): val = getattr(self.instance, name)
if val is None and not col.parition_key:
continue continue
ds.add_where(col, EqualsOperator(), val)
ds.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
self._execute(ds) self._execute(ds)

View File

@@ -481,6 +481,8 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin): class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """ """ The base cql statement class """
parition_key_values = None
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None): def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
super(BaseCQLStatement, self).__init__() super(BaseCQLStatement, self).__init__()
self.table = table self.table = table
@@ -492,20 +494,32 @@ class BaseCQLStatement(UnicodeMixin):
self.where_clauses = [] self.where_clauses = []
for clause in where or []: for clause in where or []:
self.add_where_clause(clause) self._add_where_clause(clause)
self.conditionals = [] self.conditionals = []
for conditional in conditionals or []: for conditional in conditionals or []:
self.add_conditional_clause(conditional) self.add_conditional_clause(conditional)
def add_where_clause(self, clause): def _update_partition_key(self, column, value):
""" if column.partition_key:
adds a where clause to this statement if self.parition_key_values:
:param clause: the clause to add self.parition_key_values.append(value)
:type clause: WhereClause else:
""" self.parition_key_values = [value]
if not isinstance(clause, WhereClause): # assert part keys are added in order
raise StatementException("only instances of WhereClause can be added to statements") # this is an optimization based on the way statements are constructed in
# cqlengine.query (columns always iterated in order). If that assumption
# goes away we can preallocate the key values list and insert using
# self.partition_key_values
assert column._partition_key_index == len(self.parition_key_values) - 1
def add_where(self, column, operator, value, quote_field=True):
value = column.to_database(value)
clause = WhereClause(column.db_field_name, operator, value, quote_field)
self._add_where_clause(clause)
self._update_partition_key(column, value)
def _add_where_clause(self, clause):
clause.set_context_id(self.context_counter) clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size() self.context_counter += clause.get_context_size()
self.where_clauses.append(clause) self.where_clauses.append(clause)
@@ -660,7 +674,7 @@ class AssignmentStatement(BaseCQLStatement):
# add assignments # add assignments
self.assignments = [] self.assignments = []
for assignment in assignments or []: for assignment in assignments or []:
self.add_assignment_clause(assignment) self._add_assignment_clause(assignment)
def update_context_id(self, i): def update_context_id(self, i):
super(AssignmentStatement, self).update_context_id(i) super(AssignmentStatement, self).update_context_id(i)
@@ -668,14 +682,13 @@ class AssignmentStatement(BaseCQLStatement):
assignment.set_context_id(self.context_counter) assignment.set_context_id(self.context_counter)
self.context_counter += assignment.get_context_size() self.context_counter += assignment.get_context_size()
def add_assignment_clause(self, clause): def add_assignment(self, column, value):
""" value = column.to_database(value)
adds an assignment clause to this statement clause = AssignmentClause(column.db_field_name, value)
:param clause: the clause to add self._add_assignment_clause(clause)
:type clause: AssignmentClause self._update_partition_key(column, value)
"""
if not isinstance(clause, AssignmentClause): def _add_assignment_clause(self, clause):
raise StatementException("only instances of AssignmentClause can be added to statements")
clause.set_context_id(self.context_counter) clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size() self.context_counter += clause.get_context_size()
self.assignments.append(clause) self.assignments.append(clause)
@@ -811,7 +824,7 @@ class UpdateStatement(AssignmentStatement):
else: else:
clause = AssignmentClause(column.db_field_name, value) clause = AssignmentClause(column.db_field_name, value)
if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates
self.add_assignment_clause(clause) self._add_assignment_clause(clause)
class DeleteStatement(BaseCQLStatement): class DeleteStatement(BaseCQLStatement):

View File

@@ -1,28 +0,0 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from cassandra.cqlengine.statements import AssignmentStatement, StatementException
class AssignmentStatementTest(unittest.TestCase):
def test_add_assignment_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = AssignmentStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_assignment_clause('x=5')

View File

@@ -17,17 +17,11 @@ except ImportError:
import unittest # noqa import unittest # noqa
from cassandra.query import FETCH_SIZE_UNSET from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException from cassandra.cqlengine.statements import BaseCQLStatement
class BaseStatementTest(unittest.TestCase): class BaseStatementTest(unittest.TestCase):
def test_where_clause_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = BaseCQLStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_where_clause('x=5')
def test_fetch_size(self): def test_fetch_size(self):
""" tests that fetch_size is correctly set """ """ tests that fetch_size is correctly set """
stmt = BaseCQLStatement('table', None, fetch_size=1000) stmt = BaseCQLStatement('table', None, fetch_size=1000)

View File

@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
from unittest import TestCase from unittest import TestCase
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
from cassandra.cqlengine.operators import * from cassandra.cqlengine.operators import *
import six import six
@@ -45,13 +47,13 @@ class DeleteStatementTests(TestCase):
def test_where_clause_rendering(self): def test_where_clause_rendering(self):
ds = DeleteStatement('table', None) ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds)) self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds))
def test_context_update(self): def test_context_update(self):
ds = DeleteStatement('table', None) ds = DeleteStatement('table', None)
ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4})) ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4}))
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.update_context_id(7) ds.update_context_id(7)
self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s') self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s')
@@ -59,19 +61,19 @@ class DeleteStatementTests(TestCase):
def test_context(self): def test_context(self):
ds = DeleteStatement('table', None) ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ds.get_context(), {'0': 'b'}) self.assertEqual(ds.get_context(), {'0': 'b'})
def test_range_deletion_rendering(self): def test_range_deletion_rendering(self):
ds = DeleteStatement('table', None) ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.add_where_clause(WhereClause('created_at', GreaterThanOrEqualOperator(), '0')) ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0')
ds.add_where_clause(WhereClause('created_at', LessThanOrEqualOperator(), '10')) ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10')
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds)) self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds))
ds = DeleteStatement('table', None) ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20'])) ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20'])
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds)) self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds))
def test_delete_conditional(self): def test_delete_conditional(self):

View File

@@ -16,6 +16,7 @@ try:
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
import six import six
@@ -30,8 +31,8 @@ class InsertStatementTests(unittest.TestCase):
def test_statement(self): def test_statement(self):
ist = InsertStatement('table', None) ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b')) ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment_clause(AssignmentClause('c', 'd')) ist.add_assignment(Column(db_field='c'), 'd')
self.assertEqual( self.assertEqual(
six.text_type(ist), six.text_type(ist),
@@ -40,8 +41,8 @@ class InsertStatementTests(unittest.TestCase):
def test_context_update(self): def test_context_update(self):
ist = InsertStatement('table', None) ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b')) ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment_clause(AssignmentClause('c', 'd')) ist.add_assignment(Column(db_field='c'), 'd')
ist.update_context_id(4) ist.update_context_id(4)
self.assertEqual( self.assertEqual(
@@ -53,6 +54,6 @@ class InsertStatementTests(unittest.TestCase):
def test_additional_rendering(self): def test_additional_rendering(self):
ist = InsertStatement('table', ttl=60) ist = InsertStatement('table', ttl=60)
ist.add_assignment_clause(AssignmentClause('a', 'b')) ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment_clause(AssignmentClause('c', 'd')) ist.add_assignment(Column(db_field='c'), 'd')
self.assertIn('USING TTL 60', six.text_type(ist)) self.assertIn('USING TTL 60', six.text_type(ist))

View File

@@ -16,6 +16,7 @@ try:
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import SelectStatement, WhereClause from cassandra.cqlengine.statements import SelectStatement, WhereClause
from cassandra.cqlengine.operators import * from cassandra.cqlengine.operators import *
import six import six
@@ -46,19 +47,19 @@ class SelectStatementTests(unittest.TestCase):
def test_where_clause_rendering(self): def test_where_clause_rendering(self):
ss = SelectStatement('table') ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss)) self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss))
def test_count(self): def test_count(self):
ss = SelectStatement('table', count=True, limit=10, order_by='d') ss = SelectStatement('table', count=True, limit=10, order_by='d')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss)) self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss))
self.assertIn('LIMIT', six.text_type(ss)) self.assertIn('LIMIT', six.text_type(ss))
self.assertNotIn('ORDER', six.text_type(ss)) self.assertNotIn('ORDER', six.text_type(ss))
def test_distinct(self): def test_distinct(self):
ss = SelectStatement('table', distinct_fields=['field2']) ss = SelectStatement('table', distinct_fields=['field2'])
ss.add_where_clause(WhereClause('field1', EqualsOperator(), 'b')) ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss)) self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss))
ss = SelectStatement('table', distinct_fields=['field1', 'field2']) ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
@@ -69,13 +70,13 @@ class SelectStatementTests(unittest.TestCase):
def test_context(self): def test_context(self):
ss = SelectStatement('table') ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ss.get_context(), {'0': 'b'}) self.assertEqual(ss.get_context(), {'0': 'b'})
def test_context_id_update(self): def test_context_id_update(self):
""" tests that the right things happen the the context id """ """ tests that the right things happen the the context id """
ss = SelectStatement('table') ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ss.get_context(), {'0': 'b'}) self.assertEqual(ss.get_context(), {'0': 'b'})
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s') self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s')

View File

@@ -16,6 +16,7 @@ try:
except ImportError: except ImportError:
import unittest # noqa import unittest # noqa
from cassandra.cqlengine.columns import Column, Set, List, Text
from cassandra.cqlengine.operators import * from cassandra.cqlengine.operators import *
from cassandra.cqlengine.statements import (UpdateStatement, WhereClause, from cassandra.cqlengine.statements import (UpdateStatement, WhereClause,
AssignmentClause, SetUpdateClause, AssignmentClause, SetUpdateClause,
@@ -33,54 +34,54 @@ class UpdateStatementTests(unittest.TestCase):
def test_rendering(self): def test_rendering(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b')) us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment_clause(AssignmentClause('c', 'd')) us.add_assignment(Column(db_field='c'), 'd')
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us)) self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us))
def test_context(self): def test_context(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b')) us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment_clause(AssignmentClause('c', 'd')) us.add_assignment(Column(db_field='c'), 'd')
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'}) self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'})
def test_context_update(self): def test_context_update(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b')) us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment_clause(AssignmentClause('c', 'd')) us.add_assignment(Column(db_field='c'), 'd')
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
us.update_context_id(3) us.update_context_id(3)
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s') self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s')
self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'}) self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'})
def test_additional_rendering(self): def test_additional_rendering(self):
us = UpdateStatement('table', ttl=60) us = UpdateStatement('table', ttl=60)
us.add_assignment_clause(AssignmentClause('a', 'b')) us.add_assignment(Column(db_field='a'), 'b')
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertIn('USING TTL 60', six.text_type(us)) self.assertIn('USING TTL 60', six.text_type(us))
def test_update_set_add(self): def test_update_set_add(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set((1,)), operation='add')) us.add_update(Set(Text, db_field='a'), set((1,)), 'add')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
def test_update_empty_set_add_does_not_assign(self): def test_update_empty_set_add_does_not_assign(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set(), operation='add')) us.add_update(Set(Text, db_field='a'), set(), 'add')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') self.assertFalse(us.assignments)
def test_update_empty_set_removal_does_not_assign(self): def test_update_empty_set_removal_does_not_assign(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove')) us.add_update(Set(Text, db_field='a'), set(), 'remove')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s') self.assertFalse(us.assignments)
def test_update_list_prepend_with_empty_list(self): def test_update_list_prepend_with_empty_list(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend')) us.add_update(List(Text, db_field='a'), [], 'prepend')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"') self.assertFalse(us.assignments)
def test_update_list_append_with_empty_list(self): def test_update_list_append_with_empty_list(self):
us = UpdateStatement('table') us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', [], operation='append')) us.add_update(List(Text, db_field='a'), [], 'append')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') self.assertFalse(us.assignments)