diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index f7668059..b85514b4 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -65,6 +65,14 @@ class SeparateRevertOptionalTask(task.Task): pass +class RevertKwargsTask(task.Task): + def execute(self, execute_arg1, execute_arg2): + pass + + def revert(self, execute_arg1, *args, **kwargs): + pass + + class TaskTest(test.TestCase): def test_passed_name(self): @@ -185,12 +193,12 @@ class TaskTest(test.TestCase): MyTask, rebind={'foo': 'bar'}) def test_rebind_unknown_kwargs(self): - task = KwargsTask(rebind={'foo': 'bar'}) + my_task = KwargsTask(rebind={'foo': 'bar'}) expected = { 'foo': 'bar', 'spam': 'spam' } - self.assertEqual(expected, task.rebind) + self.assertEqual(expected, my_task.rebind) def test_rebind_list_all(self): my_task = MyTask(rebind=('a', 'b', 'c')) @@ -219,29 +227,29 @@ class TaskTest(test.TestCase): MyTask, rebind=('a', 'b', 'c', 'd')) def test_rebind_list_more_kwargs(self): - task = KwargsTask(rebind=('a', 'b', 'c')) + my_task = KwargsTask(rebind=('a', 'b', 'c')) expected = { 'spam': 'a', 'b': 'b', 'c': 'c' } - self.assertEqual(expected, task.rebind) + self.assertEqual(expected, my_task.rebind) self.assertEqual(set(['a', 'b', 'c']), - task.requires) + my_task.requires) def test_rebind_list_bad_value(self): self.assertRaisesRegexp(TypeError, '^Invalid rebind value', MyTask, rebind=object()) def test_default_provides(self): - task = DefaultProvidesTask() - self.assertEqual(set(['def']), task.provides) - self.assertEqual({'def': None}, task.save_as) + my_task = DefaultProvidesTask() + self.assertEqual(set(['def']), my_task.provides) + self.assertEqual({'def': None}, my_task.save_as) def test_default_provides_can_be_overridden(self): - task = DefaultProvidesTask(provides=('spam', 'eggs')) - self.assertEqual(set(['spam', 'eggs']), task.provides) - self.assertEqual({'spam': 0, 'eggs': 1}, task.save_as) + my_task = DefaultProvidesTask(provides=('spam', 'eggs')) + self.assertEqual(set(['spam', 'eggs']), my_task.provides) + self.assertEqual({'spam': 0, 'eggs': 1}, my_task.save_as) def test_update_progress_within_bounds(self): values = [0.0, 0.5, 1.0] @@ -355,24 +363,34 @@ class TaskTest(test.TestCase): self.assertEqual(0, len(a_task.notifier)) def test_separate_revert_args(self): - task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',)) - self.assertEqual({'execute_arg': 'a'}, task.rebind) - self.assertEqual({'revert_arg': 'b'}, task.revert_rebind) + my_task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',)) + self.assertEqual({'execute_arg': 'a'}, my_task.rebind) + self.assertEqual({'revert_arg': 'b'}, my_task.revert_rebind) self.assertEqual(set(['a', 'b']), - task.requires) + my_task.requires) - task = SeparateRevertTask(requires='execute_arg', - revert_requires='revert_arg') + my_task = SeparateRevertTask(requires='execute_arg', + revert_requires='revert_arg') - self.assertEqual({'execute_arg': 'execute_arg'}, task.rebind) - self.assertEqual({'revert_arg': 'revert_arg'}, task.revert_rebind) + self.assertEqual({'execute_arg': 'execute_arg'}, my_task.rebind) + self.assertEqual({'revert_arg': 'revert_arg'}, my_task.revert_rebind) self.assertEqual(set(['execute_arg', 'revert_arg']), - task.requires) + my_task.requires) def test_separate_revert_optional_args(self): - task = SeparateRevertOptionalTask() - self.assertEqual(set(['execute_arg']), task.optional) - self.assertEqual(set(['revert_arg']), task.revert_optional) + my_task = SeparateRevertOptionalTask() + self.assertEqual(set(['execute_arg']), my_task.optional) + self.assertEqual(set(['revert_arg']), my_task.revert_optional) + + def test_revert_kwargs(self): + my_task = RevertKwargsTask() + expected_rebind = {'execute_arg1': 'execute_arg1', + 'execute_arg2': 'execute_arg2'} + self.assertEqual(expected_rebind, my_task.rebind) + expected_rebind = {'execute_arg1': 'execute_arg1'} + self.assertEqual(expected_rebind, my_task.revert_rebind) + self.assertEqual(set(['execute_arg1', 'execute_arg2']), + my_task.requires) class FunctorTaskTest(test.TestCase):