cqle: make statements collect primary key values
initially for DML path; still need to address queryset PYTHON-535
This commit is contained in:
@@ -1239,11 +1239,7 @@ class DMLQuery(object):
|
||||
|
||||
if deleted_fields:
|
||||
for name, col in self.model._primary_keys.items():
|
||||
ds.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
EqualsOperator(),
|
||||
col.to_database(getattr(self.instance, name))
|
||||
))
|
||||
ds.add_where(col, EqualsOperator(), getattr(self.instance, name))
|
||||
self._execute(ds)
|
||||
|
||||
def update(self):
|
||||
@@ -1285,11 +1281,7 @@ class DMLQuery(object):
|
||||
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
|
||||
if (null_clustering_key or static_changed_only) and (not col.partition_key):
|
||||
continue
|
||||
statement.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
EqualsOperator(),
|
||||
col.to_database(getattr(self.instance, name))
|
||||
))
|
||||
statement.add_where(col, EqualsOperator(), getattr(self.instance, name))
|
||||
self._execute(statement)
|
||||
|
||||
if not null_clustering_key:
|
||||
@@ -1324,10 +1316,7 @@ class DMLQuery(object):
|
||||
if self.instance._values[name].changed:
|
||||
nulled_fields.add(col.db_field_name)
|
||||
continue
|
||||
insert.add_assignment_clause(AssignmentClause(
|
||||
col.db_field_name,
|
||||
col.to_database(getattr(self.instance, name, None))
|
||||
))
|
||||
insert.add_assignment(col, getattr(self.instance, name, None))
|
||||
|
||||
# skip query execution if it's empty
|
||||
# caused by pointless update queries
|
||||
@@ -1344,12 +1333,8 @@ class DMLQuery(object):
|
||||
|
||||
ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
|
||||
for name, col in self.model._primary_keys.items():
|
||||
if (not col.partition_key) and (getattr(self.instance, name) is None):
|
||||
val = getattr(self.instance, name)
|
||||
if val is None and not col.parition_key:
|
||||
continue
|
||||
|
||||
ds.add_where_clause(WhereClause(
|
||||
col.db_field_name,
|
||||
EqualsOperator(),
|
||||
col.to_database(getattr(self.instance, name))
|
||||
))
|
||||
ds.add_where(col, EqualsOperator(), val)
|
||||
self._execute(ds)
|
||||
|
||||
@@ -481,6 +481,8 @@ class MapDeleteClause(BaseDeleteClause):
|
||||
class BaseCQLStatement(UnicodeMixin):
|
||||
""" The base cql statement class """
|
||||
|
||||
parition_key_values = None
|
||||
|
||||
def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None):
|
||||
super(BaseCQLStatement, self).__init__()
|
||||
self.table = table
|
||||
@@ -492,20 +494,32 @@ class BaseCQLStatement(UnicodeMixin):
|
||||
|
||||
self.where_clauses = []
|
||||
for clause in where or []:
|
||||
self.add_where_clause(clause)
|
||||
self._add_where_clause(clause)
|
||||
|
||||
self.conditionals = []
|
||||
for conditional in conditionals or []:
|
||||
self.add_conditional_clause(conditional)
|
||||
|
||||
def add_where_clause(self, clause):
|
||||
"""
|
||||
adds a where clause to this statement
|
||||
:param clause: the clause to add
|
||||
:type clause: WhereClause
|
||||
"""
|
||||
if not isinstance(clause, WhereClause):
|
||||
raise StatementException("only instances of WhereClause can be added to statements")
|
||||
def _update_partition_key(self, column, value):
|
||||
if column.partition_key:
|
||||
if self.parition_key_values:
|
||||
self.parition_key_values.append(value)
|
||||
else:
|
||||
self.parition_key_values = [value]
|
||||
# assert part keys are added in order
|
||||
# this is an optimization based on the way statements are constructed in
|
||||
# cqlengine.query (columns always iterated in order). If that assumption
|
||||
# goes away we can preallocate the key values list and insert using
|
||||
# self.partition_key_values
|
||||
assert column._partition_key_index == len(self.parition_key_values) - 1
|
||||
|
||||
def add_where(self, column, operator, value, quote_field=True):
|
||||
value = column.to_database(value)
|
||||
clause = WhereClause(column.db_field_name, operator, value, quote_field)
|
||||
self._add_where_clause(clause)
|
||||
self._update_partition_key(column, value)
|
||||
|
||||
def _add_where_clause(self, clause):
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.where_clauses.append(clause)
|
||||
@@ -660,7 +674,7 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
# add assignments
|
||||
self.assignments = []
|
||||
for assignment in assignments or []:
|
||||
self.add_assignment_clause(assignment)
|
||||
self._add_assignment_clause(assignment)
|
||||
|
||||
def update_context_id(self, i):
|
||||
super(AssignmentStatement, self).update_context_id(i)
|
||||
@@ -668,14 +682,13 @@ class AssignmentStatement(BaseCQLStatement):
|
||||
assignment.set_context_id(self.context_counter)
|
||||
self.context_counter += assignment.get_context_size()
|
||||
|
||||
def add_assignment_clause(self, clause):
|
||||
"""
|
||||
adds an assignment clause to this statement
|
||||
:param clause: the clause to add
|
||||
:type clause: AssignmentClause
|
||||
"""
|
||||
if not isinstance(clause, AssignmentClause):
|
||||
raise StatementException("only instances of AssignmentClause can be added to statements")
|
||||
def add_assignment(self, column, value):
|
||||
value = column.to_database(value)
|
||||
clause = AssignmentClause(column.db_field_name, value)
|
||||
self._add_assignment_clause(clause)
|
||||
self._update_partition_key(column, value)
|
||||
|
||||
def _add_assignment_clause(self, clause):
|
||||
clause.set_context_id(self.context_counter)
|
||||
self.context_counter += clause.get_context_size()
|
||||
self.assignments.append(clause)
|
||||
@@ -811,7 +824,7 @@ class UpdateStatement(AssignmentStatement):
|
||||
else:
|
||||
clause = AssignmentClause(column.db_field_name, value)
|
||||
if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates
|
||||
self.add_assignment_clause(clause)
|
||||
self._add_assignment_clause(clause)
|
||||
|
||||
|
||||
class DeleteStatement(BaseCQLStatement):
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# Copyright 2013-2016 DataStax, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.statements import AssignmentStatement, StatementException
|
||||
|
||||
|
||||
class AssignmentStatementTest(unittest.TestCase):
|
||||
|
||||
def test_add_assignment_type_checking(self):
|
||||
""" tests that only assignment clauses can be added to queries """
|
||||
stmt = AssignmentStatement('table', [])
|
||||
with self.assertRaises(StatementException):
|
||||
stmt.add_assignment_clause('x=5')
|
||||
@@ -17,17 +17,11 @@ except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.query import FETCH_SIZE_UNSET
|
||||
from cassandra.cqlengine.statements import BaseCQLStatement, StatementException
|
||||
from cassandra.cqlengine.statements import BaseCQLStatement
|
||||
|
||||
|
||||
class BaseStatementTest(unittest.TestCase):
|
||||
|
||||
def test_where_clause_type_checking(self):
|
||||
""" tests that only assignment clauses can be added to queries """
|
||||
stmt = BaseCQLStatement('table', [])
|
||||
with self.assertRaises(StatementException):
|
||||
stmt.add_where_clause('x=5')
|
||||
|
||||
def test_fetch_size(self):
|
||||
""" tests that fetch_size is correctly set """
|
||||
stmt = BaseCQLStatement('table', None, fetch_size=1000)
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause
|
||||
from cassandra.cqlengine.operators import *
|
||||
import six
|
||||
@@ -45,13 +47,13 @@ class DeleteStatementTests(TestCase):
|
||||
|
||||
def test_where_clause_rendering(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds))
|
||||
|
||||
def test_context_update(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4}))
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
|
||||
ds.update_context_id(7)
|
||||
self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s')
|
||||
@@ -59,19 +61,19 @@ class DeleteStatementTests(TestCase):
|
||||
|
||||
def test_context(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(ds.get_context(), {'0': 'b'})
|
||||
|
||||
def test_range_deletion_rendering(self):
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where_clause(WhereClause('created_at', GreaterThanOrEqualOperator(), '0'))
|
||||
ds.add_where_clause(WhereClause('created_at', LessThanOrEqualOperator(), '10'))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0')
|
||||
ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10')
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds))
|
||||
|
||||
ds = DeleteStatement('table', None)
|
||||
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20']))
|
||||
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20'])
|
||||
self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds))
|
||||
|
||||
def test_delete_conditional(self):
|
||||
|
||||
@@ -16,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause
|
||||
|
||||
import six
|
||||
@@ -30,8 +31,8 @@ class InsertStatementTests(unittest.TestCase):
|
||||
|
||||
def test_statement(self):
|
||||
ist = InsertStatement('table', None)
|
||||
ist.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
ist.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
ist.add_assignment(Column(db_field='a'), 'b')
|
||||
ist.add_assignment(Column(db_field='c'), 'd')
|
||||
|
||||
self.assertEqual(
|
||||
six.text_type(ist),
|
||||
@@ -40,8 +41,8 @@ class InsertStatementTests(unittest.TestCase):
|
||||
|
||||
def test_context_update(self):
|
||||
ist = InsertStatement('table', None)
|
||||
ist.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
ist.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
ist.add_assignment(Column(db_field='a'), 'b')
|
||||
ist.add_assignment(Column(db_field='c'), 'd')
|
||||
|
||||
ist.update_context_id(4)
|
||||
self.assertEqual(
|
||||
@@ -53,6 +54,6 @@ class InsertStatementTests(unittest.TestCase):
|
||||
|
||||
def test_additional_rendering(self):
|
||||
ist = InsertStatement('table', ttl=60)
|
||||
ist.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
ist.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
ist.add_assignment(Column(db_field='a'), 'b')
|
||||
ist.add_assignment(Column(db_field='c'), 'd')
|
||||
self.assertIn('USING TTL 60', six.text_type(ist))
|
||||
|
||||
@@ -16,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.columns import Column
|
||||
from cassandra.cqlengine.statements import SelectStatement, WhereClause
|
||||
from cassandra.cqlengine.operators import *
|
||||
import six
|
||||
@@ -46,19 +47,19 @@ class SelectStatementTests(unittest.TestCase):
|
||||
|
||||
def test_where_clause_rendering(self):
|
||||
ss = SelectStatement('table')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss))
|
||||
|
||||
def test_count(self):
|
||||
ss = SelectStatement('table', count=True, limit=10, order_by='d')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss))
|
||||
self.assertIn('LIMIT', six.text_type(ss))
|
||||
self.assertNotIn('ORDER', six.text_type(ss))
|
||||
|
||||
def test_distinct(self):
|
||||
ss = SelectStatement('table', distinct_fields=['field2'])
|
||||
ss.add_where_clause(WhereClause('field1', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b')
|
||||
self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss))
|
||||
|
||||
ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
|
||||
@@ -69,13 +70,13 @@ class SelectStatementTests(unittest.TestCase):
|
||||
|
||||
def test_context(self):
|
||||
ss = SelectStatement('table')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(ss.get_context(), {'0': 'b'})
|
||||
|
||||
def test_context_id_update(self):
|
||||
""" tests that the right things happen the the context id """
|
||||
ss = SelectStatement('table')
|
||||
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
|
||||
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
|
||||
self.assertEqual(ss.get_context(), {'0': 'b'})
|
||||
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s')
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra.cqlengine.columns import Column, Set, List, Text
|
||||
from cassandra.cqlengine.operators import *
|
||||
from cassandra.cqlengine.statements import (UpdateStatement, WhereClause,
|
||||
AssignmentClause, SetUpdateClause,
|
||||
@@ -33,54 +34,54 @@ class UpdateStatementTests(unittest.TestCase):
|
||||
|
||||
def test_rendering(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_assignment(Column(db_field='c'), 'd')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us))
|
||||
|
||||
def test_context(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_assignment(Column(db_field='c'), 'd')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'})
|
||||
|
||||
def test_context_update(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_assignment_clause(AssignmentClause('c', 'd'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_assignment(Column(db_field='c'), 'd')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
us.update_context_id(3)
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s')
|
||||
self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'})
|
||||
|
||||
def test_additional_rendering(self):
|
||||
us = UpdateStatement('table', ttl=60)
|
||||
us.add_assignment_clause(AssignmentClause('a', 'b'))
|
||||
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
|
||||
us.add_assignment(Column(db_field='a'), 'b')
|
||||
us.add_where(Column(db_field='a'), EqualsOperator(), 'x')
|
||||
self.assertIn('USING TTL 60', six.text_type(us))
|
||||
|
||||
def test_update_set_add(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', set((1,)), operation='add'))
|
||||
us.add_update(Set(Text, db_field='a'), set((1,)), 'add')
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
|
||||
def test_update_empty_set_add_does_not_assign(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', set(), operation='add'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
us.add_update(Set(Text, db_field='a'), set(), 'add')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
def test_update_empty_set_removal_does_not_assign(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s')
|
||||
us.add_update(Set(Text, db_field='a'), set(), 'remove')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
def test_update_list_prepend_with_empty_list(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"')
|
||||
us.add_update(List(Text, db_field='a'), [], 'prepend')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
def test_update_list_append_with_empty_list(self):
|
||||
us = UpdateStatement('table')
|
||||
us.add_assignment_clause(ListUpdateClause('a', [], operation='append'))
|
||||
self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s')
|
||||
us.add_update(List(Text, db_field='a'), [], 'append')
|
||||
self.assertFalse(us.assignments)
|
||||
|
||||
Reference in New Issue
Block a user