Merge pull request #576 from datastax/151

PYTHON-151 - fail fast if batch is too large
This commit is contained in:
Adam Holmberg
2016-05-09 12:23:26 -05:00
3 changed files with 109 additions and 14 deletions

View File

@@ -699,6 +699,19 @@ class BatchStatement(Statement):
Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level,
serial_consistency_level=serial_consistency_level, custom_payload=custom_payload)
def clear(self):
"""
This is a convenience method to clear a batch statement for reuse.
*Note:* it should not be used concurrently with uncompleted execution futures executing the same
``BatchStatement``.
"""
del self._statements_and_parameters[:]
self.keyspace = None
self.routing_key = None
if self.custom_payload:
self.custom_payload.clear()
def add(self, statement, parameters=None):
"""
Adds a :class:`.Statement` and optional sequence of parameters
@@ -711,21 +724,19 @@ class BatchStatement(Statement):
if parameters:
encoder = Encoder() if self._session is None else self._session.encoder
statement = bind_params(statement, parameters, encoder)
self._statements_and_parameters.append((False, statement, ()))
self._add_statement_and_params(False, statement, ())
elif isinstance(statement, PreparedStatement):
query_id = statement.query_id
bound_statement = statement.bind(() if parameters is None else parameters)
self._update_state(bound_statement)
self._statements_and_parameters.append(
(True, query_id, bound_statement.values))
self._add_statement_and_params(True, query_id, bound_statement.values)
elif isinstance(statement, BoundStatement):
if parameters:
raise ValueError(
"Parameters cannot be passed with a BoundStatement "
"to BatchStatement.add()")
self._update_state(statement)
self._statements_and_parameters.append(
(True, statement.prepared_statement.query_id, statement.values))
self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values)
else:
# it must be a SimpleStatement
query_string = statement.query_string
@@ -733,17 +744,22 @@ class BatchStatement(Statement):
encoder = Encoder() if self._session is None else self._session.encoder
query_string = bind_params(query_string, parameters, encoder)
self._update_state(statement)
self._statements_and_parameters.append((False, query_string, ()))
self._add_statement_and_params(False, query_string, ())
return self
def add_all(self, statements, parameters):
"""
Adds a sequence of :class:`.Statement` objects and a matching sequence
of parameters to the batch. :const:`None` can be used in place of
parameters when no parameters are needed.
of parameters to the batch. Statement and parameter sequences must be of equal length or
one will be truncated. :const:`None` can be used in the parameters position where are needed.
"""
for statement, value in zip(statements, parameters):
self.add(statement, parameters)
self.add(statement, value)
def _add_statement_and_params(self, is_prepared, statement, parameters):
if len(self._statements_and_parameters) >= 0xFFFF:
raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF)
self._statements_and_parameters.append((is_prepared, statement, parameters))
def _maybe_set_routing_attributes(self, statement):
if self.routing_key is None:

View File

@@ -471,11 +471,6 @@ class BatchStatementTests(BasicSharedKeyspaceUnitTestCase):
self.session.execute(batch)
self.confirm_results()
def test_no_parameters_many_times(self):
for i in range(1000):
self.test_no_parameters()
self.session.execute("TRUNCATE test3rf.test")
def test_unicode(self):
ddl = '''
CREATE TABLE test3rf.testtext (
@@ -491,6 +486,22 @@ class BatchStatementTests(BasicSharedKeyspaceUnitTestCase):
finally:
self.session.execute("DROP TABLE test3rf.testtext")
def test_too_many_statements(self):
max_statements = 0xFFFF
ss = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 0)")
b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE)
# max works
b.add_all([ss] * max_statements, [None] * max_statements)
self.session.execute(b)
# max + 1 raises
self.assertRaises(ValueError, b.add, ss)
# also would have bombed trying to encode
b._statements_and_parameters.append((False, ss.query_string, ()))
self.assertRaises(NoHostAvailable, self.session.execute, b)
class SerialConsistencyTests(unittest.TestCase):
def setUp(self):

68
tests/unit/test_query.py Normal file
View File

@@ -0,0 +1,68 @@
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import unittest2 as unittest
except ImportError:
import unittest # noqa
import six
from cassandra.query import BatchStatement, SimpleStatement
class BatchStatementTest(unittest.TestCase):
# TODO: this suite could be expanded; for now just adding a test covering a PR
def test_clear(self):
keyspace = 'keyspace'
routing_key = 'routing_key'
custom_payload = {'key': six.b('value')}
ss = SimpleStatement('whatever', keyspace=keyspace, routing_key=routing_key, custom_payload=custom_payload)
batch = BatchStatement()
batch.add(ss)
self.assertTrue(batch._statements_and_parameters)
self.assertEqual(batch.keyspace, keyspace)
self.assertEqual(batch.routing_key, routing_key)
self.assertEqual(batch.custom_payload, custom_payload)
batch.clear()
self.assertFalse(batch._statements_and_parameters)
self.assertIsNone(batch.keyspace)
self.assertIsNone(batch.routing_key)
self.assertFalse(batch.custom_payload)
batch.add(ss)
def test_clear_empty(self):
batch = BatchStatement()
batch.clear()
self.assertFalse(batch._statements_and_parameters)
self.assertIsNone(batch.keyspace)
self.assertIsNone(batch.routing_key)
self.assertFalse(batch.custom_payload)
batch.add('something')
def test_add_all(self):
batch = BatchStatement()
statements = ['%s'] * 10
parameters = [(i,) for i in range(10)]
batch.add_all(statements, parameters)
bound_statements = [t[1] for t in batch._statements_and_parameters]
str_parameters = [str(i) for i in range(10)]
self.assertEqual(bound_statements, str_parameters)