diff --git a/cassandra/query.py b/cassandra/query.py index 1fb4d4cc..cefdeb53 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -699,6 +699,19 @@ class BatchStatement(Statement): Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + def clear(self): + """ + This is a convenience method to clear a batch statement for reuse. + + *Note:* it should not be used concurrently with uncompleted execution futures executing the same + ``BatchStatement``. + """ + del self._statements_and_parameters[:] + self.keyspace = None + self.routing_key = None + if self.custom_payload: + self.custom_payload.clear() + def add(self, statement, parameters=None): """ Adds a :class:`.Statement` and optional sequence of parameters @@ -711,21 +724,19 @@ class BatchStatement(Statement): if parameters: encoder = Encoder() if self._session is None else self._session.encoder statement = bind_params(statement, parameters, encoder) - self._statements_and_parameters.append((False, statement, ())) + self._add_statement_and_params(False, statement, ()) elif isinstance(statement, PreparedStatement): query_id = statement.query_id bound_statement = statement.bind(() if parameters is None else parameters) self._update_state(bound_statement) - self._statements_and_parameters.append( - (True, query_id, bound_statement.values)) + self._add_statement_and_params(True, query_id, bound_statement.values) elif isinstance(statement, BoundStatement): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " "to BatchStatement.add()") self._update_state(statement) - self._statements_and_parameters.append( - (True, statement.prepared_statement.query_id, statement.values)) + self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) else: # it must be a SimpleStatement query_string = statement.query_string @@ -733,17 +744,22 @@ class BatchStatement(Statement): encoder = Encoder() if self._session is None else self._session.encoder query_string = bind_params(query_string, parameters, encoder) self._update_state(statement) - self._statements_and_parameters.append((False, query_string, ())) + self._add_statement_and_params(False, query_string, ()) return self def add_all(self, statements, parameters): """ Adds a sequence of :class:`.Statement` objects and a matching sequence - of parameters to the batch. :const:`None` can be used in place of - parameters when no parameters are needed. + of parameters to the batch. Statement and parameter sequences must be of equal length or + one will be truncated. :const:`None` can be used in the parameters position where are needed. """ for statement, value in zip(statements, parameters): - self.add(statement, parameters) + self.add(statement, value) + + def _add_statement_and_params(self, is_prepared, statement, parameters): + if len(self._statements_and_parameters) >= 0xFFFF: + raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) + self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): if self.routing_key is None: diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index dff9715a..881eed54 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -471,11 +471,6 @@ class BatchStatementTests(BasicSharedKeyspaceUnitTestCase): self.session.execute(batch) self.confirm_results() - def test_no_parameters_many_times(self): - for i in range(1000): - self.test_no_parameters() - self.session.execute("TRUNCATE test3rf.test") - def test_unicode(self): ddl = ''' CREATE TABLE test3rf.testtext ( @@ -491,6 +486,22 @@ class BatchStatementTests(BasicSharedKeyspaceUnitTestCase): finally: self.session.execute("DROP TABLE test3rf.testtext") + def test_too_many_statements(self): + max_statements = 0xFFFF + ss = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") + b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE) + + # max works + b.add_all([ss] * max_statements, [None] * max_statements) + self.session.execute(b) + + # max + 1 raises + self.assertRaises(ValueError, b.add, ss) + + # also would have bombed trying to encode + b._statements_and_parameters.append((False, ss.query_string, ())) + self.assertRaises(NoHostAvailable, self.session.execute, b) + class SerialConsistencyTests(unittest.TestCase): def setUp(self): diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py new file mode 100644 index 00000000..687bca7c --- /dev/null +++ b/tests/unit/test_query.py @@ -0,0 +1,68 @@ +# Copyright 2013-2016 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +import six + +from cassandra.query import BatchStatement, SimpleStatement + + +class BatchStatementTest(unittest.TestCase): + # TODO: this suite could be expanded; for now just adding a test covering a PR + + def test_clear(self): + keyspace = 'keyspace' + routing_key = 'routing_key' + custom_payload = {'key': six.b('value')} + + ss = SimpleStatement('whatever', keyspace=keyspace, routing_key=routing_key, custom_payload=custom_payload) + + batch = BatchStatement() + batch.add(ss) + + self.assertTrue(batch._statements_and_parameters) + self.assertEqual(batch.keyspace, keyspace) + self.assertEqual(batch.routing_key, routing_key) + self.assertEqual(batch.custom_payload, custom_payload) + + batch.clear() + self.assertFalse(batch._statements_and_parameters) + self.assertIsNone(batch.keyspace) + self.assertIsNone(batch.routing_key) + self.assertFalse(batch.custom_payload) + + batch.add(ss) + + def test_clear_empty(self): + batch = BatchStatement() + batch.clear() + self.assertFalse(batch._statements_and_parameters) + self.assertIsNone(batch.keyspace) + self.assertIsNone(batch.routing_key) + self.assertFalse(batch.custom_payload) + + batch.add('something') + + def test_add_all(self): + batch = BatchStatement() + statements = ['%s'] * 10 + parameters = [(i,) for i in range(10)] + batch.add_all(statements, parameters) + bound_statements = [t[1] for t in batch._statements_and_parameters] + str_parameters = [str(i) for i in range(10)] + self.assertEqual(bound_statements, str_parameters)