cqle: make statements collect primary key values

initially for DML path; still need to address queryset

PYTHON-535
This commit is contained in:
Adam Holmberg
2016-04-01 10:14:48 -05:00
parent f03e6381e7
commit 1ffd4dd0bb
8 changed files with 83 additions and 114 deletions

View File

@@ -1239,11 +1239,7 @@ class DMLQuery(object):
if deleted_fields:
for name, col in self.model._primary_keys.items():
ds.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
ds.add_where(col, EqualsOperator(), getattr(self.instance, name))
self._execute(ds)
def update(self):
@@ -1285,11 +1281,7 @@ class DMLQuery(object):
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
if (null_clustering_key or static_changed_only) and (not col.partition_key):
continue
statement.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
statement.add_where(col, EqualsOperator(), getattr(self.instance, name))
self._execute(statement)
if not null_clustering_key:
@@ -1324,10 +1316,7 @@ class DMLQuery(object):
if self.instance._values[name].changed:
nulled_fields.add(col.db_field_name)
continue
insert.add_assignment_clause(AssignmentClause(
col.db_field_name,
col.to_database(getattr(self.instance, name, None))
))
insert.add_assignment(col, getattr(self.instance, name, None))
# skip query execution if it's empty
# caused by pointless update queries
@@ -1344,12 +1333,8 @@ class DMLQuery(object):
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
for name, col in self.model._primary_keys.items():
if (not col.partition_key) and (getattr(self.instance, name) is None):
val = getattr(self.instance, name)
if val is None and not col.parition_key:
continue
ds.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
col.to_database(getattr(self.instance, name))
))
ds.add_where(col, EqualsOperator(), val)
self._execute(ds)

View File

@@ -481,6 +481,8 @@ class MapDeleteClause(BaseDeleteClause):
class BaseCQLStatement(UnicodeMixin):
""" The base cql statement class """
parition_key_values = None
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
super(BaseCQLStatement, self).__init__()
self.table = table
@@ -492,20 +494,32 @@ class BaseCQLStatement(UnicodeMixin):
self.where_clauses = []
for clause in where or []:
self.add_where_clause(clause)
self._add_where_clause(clause)
self.conditionals = []
for conditional in conditionals or []:
self.add_conditional_clause(conditional)
def add_where_clause(self, clause):
"""
adds a where clause to this statement
:param clause: the clause to add
:type clause: WhereClause
"""
if not isinstance(clause, WhereClause):
raise StatementException("only instances of WhereClause can be added to statements")
def _update_partition_key(self, column, value):
if column.partition_key:
if self.parition_key_values:
self.parition_key_values.append(value)
else:
self.parition_key_values = [value]
# assert part keys are added in order
# this is an optimization based on the way statements are constructed in
# cqlengine.query (columns always iterated in order). If that assumption
# goes away we can preallocate the key values list and insert using
# self.partition_key_values
assert column._partition_key_index == len(self.parition_key_values) - 1
def add_where(self, column, operator, value, quote_field=True):
value = column.to_database(value)
clause = WhereClause(column.db_field_name, operator, value, quote_field)
self._add_where_clause(clause)
self._update_partition_key(column, value)
def _add_where_clause(self, clause):
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.where_clauses.append(clause)
@@ -660,7 +674,7 @@ class AssignmentStatement(BaseCQLStatement):
# add assignments
self.assignments = []
for assignment in assignments or []:
self.add_assignment_clause(assignment)
self._add_assignment_clause(assignment)
def update_context_id(self, i):
super(AssignmentStatement, self).update_context_id(i)
@@ -668,14 +682,13 @@ class AssignmentStatement(BaseCQLStatement):
assignment.set_context_id(self.context_counter)
self.context_counter += assignment.get_context_size()
def add_assignment_clause(self, clause):
"""
adds an assignment clause to this statement
:param clause: the clause to add
:type clause: AssignmentClause
"""
if not isinstance(clause, AssignmentClause):
raise StatementException("only instances of AssignmentClause can be added to statements")
def add_assignment(self, column, value):
value = column.to_database(value)
clause = AssignmentClause(column.db_field_name, value)
self._add_assignment_clause(clause)
self._update_partition_key(column, value)
def _add_assignment_clause(self, clause):
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.assignments.append(clause)
@@ -811,7 +824,7 @@ class UpdateStatement(AssignmentStatement):
else:
clause = AssignmentClause(column.db_field_name, value)
if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates
self.add_assignment_clause(clause)
self._add_assignment_clause(clause)
class DeleteStatement(BaseCQLStatement):

View File

@@ -1,28 +0,0 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
from cassandra.cqlengine.statements import AssignmentStatement, StatementException
class AssignmentStatementTest(unittest.TestCase):
def test_add_assignment_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = AssignmentStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_assignment_clause('x=5')

View File

@@ -17,17 +17,11 @@ except ImportError:
import unittest # noqa
from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
from cassandra.cqlengine.statements import BaseCQLStatement
class BaseStatementTest(unittest.TestCase):
def test_where_clause_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = BaseCQLStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_where_clause('x=5')
def test_fetch_size(self):
""" tests that fetch_size is correctly set """
stmt = BaseCQLStatement('table', None, fetch_size=1000)

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from unittest import TestCase
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
from cassandra.cqlengine.operators import *
import six
@@ -45,13 +47,13 @@ class DeleteStatementTests(TestCase):
def test_where_clause_rendering(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds))
def test_context_update(self):
ds = DeleteStatement('table', None)
ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4}))
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.update_context_id(7)
self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s')
@@ -59,19 +61,19 @@ class DeleteStatementTests(TestCase):
def test_context(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ds.get_context(), {'0': 'b'})
def test_range_deletion_rendering(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where_clause(WhereClause('created_at', GreaterThanOrEqualOperator(), '0'))
ds.add_where_clause(WhereClause('created_at', LessThanOrEqualOperator(), '10'))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0')
ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10')
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds))
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20']))
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20'])
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds))
def test_delete_conditional(self):

View File

@@ -16,6 +16,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
import six
@@ -30,8 +31,8 @@ class InsertStatementTests(unittest.TestCase):
def test_statement(self):
ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment(Column(db_field='c'), 'd')
self.assertEqual(
six.text_type(ist),
@@ -40,8 +41,8 @@ class InsertStatementTests(unittest.TestCase):
def test_context_update(self):
ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment(Column(db_field='c'), 'd')
ist.update_context_id(4)
self.assertEqual(
@@ -53,6 +54,6 @@ class InsertStatementTests(unittest.TestCase):
def test_additional_rendering(self):
ist = InsertStatement('table', ttl=60)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment(Column(db_field='c'), 'd')
self.assertIn('USING TTL 60', six.text_type(ist))

View File

@@ -16,6 +16,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import SelectStatement, WhereClause
from cassandra.cqlengine.operators import *
import six
@@ -46,19 +47,19 @@ class SelectStatementTests(unittest.TestCase):
def test_where_clause_rendering(self):
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss))
def test_count(self):
ss = SelectStatement('table', count=True, limit=10, order_by='d')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss))
self.assertIn('LIMIT', six.text_type(ss))
self.assertNotIn('ORDER', six.text_type(ss))
def test_distinct(self):
ss = SelectStatement('table', distinct_fields=['field2'])
ss.add_where_clause(WhereClause('field1', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b')
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss))
ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
@@ -69,13 +70,13 @@ class SelectStatementTests(unittest.TestCase):
def test_context(self):
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ss.get_context(), {'0': 'b'})
def test_context_id_update(self):
""" tests that the right things happen the the context id """
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
self.assertEqual(ss.get_context(), {'0': 'b'})
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s')

View File

@@ -16,6 +16,7 @@ try:
except ImportError:
import unittest # noqa
from cassandra.cqlengine.columns import Column, Set, List, Text
from cassandra.cqlengine.operators import *
from cassandra.cqlengine.statements import (UpdateStatement, WhereClause,
AssignmentClause, SetUpdateClause,
@@ -33,54 +34,54 @@ class UpdateStatementTests(unittest.TestCase):
def test_rendering(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment(Column(db_field='c'), 'd')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us))
def test_context(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment(Column(db_field='c'), 'd')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'})
def test_context_update(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_assignment(Column(db_field='c'), 'd')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
us.update_context_id(3)
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s')
self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'})
def test_additional_rendering(self):
us = UpdateStatement('table', ttl=60)
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.add_assignment(Column(db_field='a'), 'b')
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
self.assertIn('USING TTL 60', six.text_type(us))
def test_update_set_add(self):
us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set((1,)), operation='add'))
us.add_update(Set(Text, db_field='a'), set((1,)), 'add')
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
def test_update_empty_set_add_does_not_assign(self):
us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set(), operation='add'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
us.add_update(Set(Text, db_field='a'), set(), 'add')
self.assertFalse(us.assignments)
def test_update_empty_set_removal_does_not_assign(self):
us = UpdateStatement('table')
us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s')
us.add_update(Set(Text, db_field='a'), set(), 'remove')
self.assertFalse(us.assignments)
def test_update_list_prepend_with_empty_list(self):
us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"')
us.add_update(List(Text, db_field='a'), [], 'prepend')
self.assertFalse(us.assignments)
def test_update_list_append_with_empty_list(self):
us = UpdateStatement('table')
us.add_assignment_clause(ListUpdateClause('a', [], operation='append'))
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
us.add_update(List(Text, db_field='a'), [], 'append')
self.assertFalse(us.assignments)