# Copyright 2013-2017 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import warnings

import sure

from cassandra.cqlengine import columns
from cassandra.cqlengine.management import drop_table, sync_table
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import BatchQuery
from tests.integration.cqlengine.base import BaseCassEngTestCase

from mock import patch


class TestMultiKeyModel(Model):
    """Model with two primary-key columns used as the fixture table for the batch tests."""

    # Both columns are declared with primary_key=True.
    partition = columns.Integer(primary_key=True)
    cluster = columns.Integer(primary_key=True)
    # Optional payload columns written and read back by the tests.
    count = columns.Integer(required=False)
    text = columns.Text(required=False)


class BatchQueryTests(BaseCassEngTestCase):
    """Checks that operations attached to a BatchQuery only take effect on execute()."""

    @classmethod
    def setUpClass(cls):
        """Recreate the fixture table so the class starts from an empty table."""
        super(BatchQueryTests, cls).setUpClass()
        drop_table(TestMultiKeyModel)
        sync_table(TestMultiKeyModel)

    @classmethod
    def tearDownClass(cls):
        """Drop the fixture table once every test in the class has run."""
        super(BatchQueryTests, cls).tearDownClass()
        drop_table(TestMultiKeyModel)

    def setUp(self):
        """Delete every row stored under the partition key the tests write to."""
        super(BatchQueryTests, self).setUp()
        self.pkey = 1
        for obj in TestMultiKeyModel.filter(partition=self.pkey):
            obj.delete()

    def test_insert_success_case(self):
        """A batched create is not readable before execute() and is readable after."""
        b = BatchQuery()
        TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=2, count=3, text='4')

        # The batch has not been executed, so the row must not exist yet.
        with self.assertRaises(TestMultiKeyModel.DoesNotExist):
            TestMultiKeyModel.get(partition=self.pkey, cluster=2)

        b.execute()

        # Raises DoesNotExist (failing the test) if the insert was not applied.
        TestMultiKeyModel.get(partition=self.pkey, cluster=2)

    def test_update_success_case(self):
        """A batched save leaves the stored value untouched until execute()."""
        inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4')

        b = BatchQuery()

        inst.count = 4
        inst.batch(b).save()

        # Still the value written by create(): the batch has not run.
        inst2 = TestMultiKeyModel.get(partition=self.pkey, cluster=2)
        self.assertEqual(inst2.count, 3)

        b.execute()

        inst3 = TestMultiKeyModel.get(partition=self.pkey, cluster=2)
        self.assertEqual(inst3.count, 4)

    def test_delete_success_case(self):
        """A batched delete leaves the row in place until execute()."""
        inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4')

        b = BatchQuery()

        inst.batch(b).delete()

        # Must not raise: the delete is only queued at this point.
        TestMultiKeyModel.get(partition=self.pkey, cluster=2)

        b.execute()

        with self.assertRaises(TestMultiKeyModel.DoesNotExist):
            TestMultiKeyModel.get(partition=self.pkey, cluster=2)

    def test_context_manager(self):
        """Rows created inside a ``with BatchQuery()`` block appear only after the block exits."""
        with BatchQuery() as b:
            for i in range(5):
                TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=i, count=3, text='4')

            # Still inside the context manager: nothing has been written.
            for i in range(5):
                with self.assertRaises(TestMultiKeyModel.DoesNotExist):
                    TestMultiKeyModel.get(partition=self.pkey, cluster=i)

        # After the block every row must be readable.
        for i in range(5):
            TestMultiKeyModel.get(partition=self.pkey, cluster=i)

    def test_bulk_delete_success_case(self):
        """A batched queryset delete removes the rows only once the batch has run."""
        # Writes five rows into partition 0 (not self.pkey), hence the explicit
        # cleanup at the end of the test.
        for i in range(1):
            for j in range(5):
                TestMultiKeyModel.create(partition=i, cluster=j, count=i*j, text='{0}:{1}'.format(i, j))

        with BatchQuery() as b:
            TestMultiKeyModel.objects.batch(b).filter(partition=0).delete()
            # The delete is queued, so all five rows are still counted.
            self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 5)

        self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 0)

        # cleanup
        for m in TestMultiKeyModel.all():
            m.delete()

    def test_empty_batch(self):
        """Executing a batch with no queued statements must not raise."""
        b = BatchQuery()
        b.execute()

        with BatchQuery() as b:
            pass


class BatchQueryCallbacksTests(BaseCassEngTestCase):
    """Tests for callbacks registered on a BatchQuery through add_callback()."""

    def test_API_managing_callbacks(self):
        """add_callback() stores (callable, args, kwargs) tuples in registration order."""

        def my_callback(*args, **kwargs):
            pass

        # Callbacks are registered after construction through add_callback().
        batch = BatchQuery()

        batch.add_callback(my_callback)
        batch.add_callback(my_callback, 2, named_arg='value')
        batch.add_callback(my_callback, 1, 3)

        self.assertEqual(batch._callbacks, [
            (my_callback, (), {}),
            (my_callback, (2,), {'named_arg': 'value'}),
            (my_callback, (1, 3), {})
        ])

    def test_callbacks_properly_execute_callables_and_tuples(self):
        """execute() invokes each callback with the positional args it was registered with."""
        call_history = []

        def my_callback(*args, **kwargs):
            call_history.append(args)

        # Callbacks are registered after construction through add_callback().
        batch = BatchQuery()

        batch.add_callback(my_callback)
        batch.add_callback(my_callback, 'more', 'args')

        batch.execute()

        self.assertEqual(len(call_history), 2)
        self.assertEqual([(), ('more', 'args')], call_history)

    def test_callbacks_tied_to_execute(self):
        """Batch callbacks should NOT fire if batch is not executed in context manager mode"""
        call_history = []

        def my_callback(*args, **kwargs):
            call_history.append(args)

        with BatchQuery() as batch:
            batch.add_callback(my_callback)

        self.assertEqual(len(call_history), 1)

        class SomeError(Exception):
            pass

        with self.assertRaises(SomeError):
            with BatchQuery() as batch:
                batch.add_callback(my_callback)
                # this error bubbling up through context manager
                # should prevent callback runs (along with b.execute())
                raise SomeError

        # still same call history. Nothing added
        self.assertEqual(len(call_history), 1)

        # but if execute ran, even with an error bubbling through
        # the callbacks also would have fired
        with self.assertRaises(SomeError):
            with BatchQuery(execute_on_exception=True) as batch:
                batch.add_callback(my_callback)
                raise SomeError

        # updated call history
        self.assertEqual(len(call_history), 2)

    def test_callbacks_work_multiple_times(self):
        """
        Tests that multiple executions of execute on a batch statement
        logs a warning, and that we don't encounter an attribute error.

        @since 3.1
        @jira_ticket PYTHON-445
        @expected_result warning message is logged

        @test_category object_mapper
        """
        call_history = []

        def my_callback(*args, **kwargs):
            call_history.append(args)

        with warnings.catch_warnings(record=True) as w:
            with BatchQuery() as batch:
                batch.add_callback(my_callback)
                batch.execute()
            batch.execute()
        self.assertEqual(len(w), 2)  # package filter setup to warn always

        # assertRegexpMatches was removed in Python 3.12 and assertRegex does
        # not exist on Python 2.7; re.search is what both delegate to, so this
        # form is equivalent and works on every version.
        message = str(w[0].message)
        self.assertIsNotNone(
            re.search(r"^Batch.*multiple.*", message),
            "unexpected warning text: {0!r}".format(message))

    def test_disable_multiple_callback_warning(self):
        """
        Tests that multiple executions of a batch statement
        don't log a warning when warn_multiple_exec flag is set, and
        that we don't encounter an attribute error.

        @since 3.1
        @jira_ticket PYTHON-445
        @expected_result no warning message is logged

        @test_category object_mapper
        """
        call_history = []

        def my_callback(*args, **kwargs):
            call_history.append(args)

        # Patch the class attribute to False for the duration of the block so
        # repeated execute() calls are expected to stay silent.
        with patch('cassandra.cqlengine.query.BatchQuery.warn_multiple_exec', False):
            with warnings.catch_warnings(record=True) as w:
                with BatchQuery() as batch:
                    batch.add_callback(my_callback)
                    batch.execute()
                batch.execute()
            self.assertFalse(w)