diff --git a/taskflow/task.py b/taskflow/task.py index 073adbff..1796ba9b 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -204,13 +204,19 @@ class Task(BaseTask): Adds following features to Task: - auto-generates name from type of self - adds all execute argument names to task requirements + - items provided by the task may be specified via + 'default_provides' class attribute or property """ + default_provides = None + def __init__(self, name=None, provides=None, requires=None, auto_extract=True, rebind=None): """Initialize task instance""" if name is None: name = reflection.get_callable_name(self) + if provides is None: + provides = self.default_provides super(Task, self).__init__(name, provides=provides) self.rebind = _build_arg_mapping(self.name, requires, rebind, diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index 5c45659e..e661aa40 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -31,6 +31,13 @@ class KwargsTask(task.Task): pass +class DefaultProvidesTask(task.Task): + default_provides = 'def' + + def execute(self): + return None + + class TaskTestCase(test.TestCase): def test_passed_name(self): @@ -149,3 +156,13 @@ class TaskTestCase(test.TestCase): def test_rebind_list_bad_value(self): with self.assertRaisesRegexp(TypeError, '^Invalid rebind value:'): MyTask(rebind=object()) + + def test_default_provides(self): + task = DefaultProvidesTask() + self.assertEquals(task.provides, set(['def'])) + self.assertEquals(task.save_as, {'def': None}) + + def test_default_provides_can_be_overridden(self): + task = DefaultProvidesTask(provides=('spam', 'eggs')) + self.assertEquals(task.provides, set(['spam', 'eggs'])) + self.assertEquals(task.save_as, {'spam': 0, 'eggs': 1})