diff --git a/doc/source/arguments_and_results.rst b/doc/source/arguments_and_results.rst index cb2c87610..998db889a 100644 --- a/doc/source/arguments_and_results.rst +++ b/doc/source/arguments_and_results.rst @@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern. Task/retry arguments Set of names of task/retry arguments available as the ``requires`` - property of the task/retry instance. When a task or retry object is - about to be executed values with these names are retrieved from storage - and passed to the ``execute`` method of the task/retry. + and/or ``optional`` property of the task/retry instance. When a task or + retry object is about to be executed values with these names are + retrieved from storage and passed to the ``execute`` method of the + task/retry. If any names in the ``requires`` property cannot be + found in storage, an exception will be thrown. Any names in the + ``optional`` property that cannot be found are ignored. Task/retry results Set of names of task/retry results (what task/retry provides) available @@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object). .. doctest:: >>> class MyTask(task.Task): - ... def execute(self, spam, eggs): + ... def execute(self, spam, eggs, bacon=None): ... return spam + eggs ... >>> sorted(MyTask().requires) ['eggs', 'spam'] + >>> sorted(MyTask().optional) + ['bacon'] Inference from the method signature is the ''simplest'' way to specify -arguments. Optional arguments (with default values), and special arguments like -``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these -names have special meaning/usage in python). +arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are +ignored during inference (as these names have special meaning/usage in python). .. doctest:: - >>> class MyTask(task.Task): - ... def execute(self, spam, eggs=()): - ... return spam + eggs - ... - >>> MyTask().requires - set(['spam']) - >>> >>> class UniTask(task.Task): ... def execute(self, *args, **kwargs): ... pass ... >>> UniTask().requires - set([]) + frozenset([]) .. make vim sphinx highlighter* happy** diff --git a/taskflow/atom.py b/taskflow/atom.py index d236ff902..3ece83fc9 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer, well as verify that the final argument mapping does not have missing or extra arguments (where applicable). """ - atom_args = reflection.get_callable_args(function, required_only=True) + + # build a list of required arguments based on function signature + req_args = reflection.get_callable_args(function, required_only=True) + all_args = reflection.get_callable_args(function, required_only=False) + + # remove arguments that are part of ignore_list if ignore_list: for arg in ignore_list: - if arg in atom_args: - atom_args.remove(arg) + if arg in req_args: + req_args.remove(arg) + else: + ignore_list = [] - result = {} + required = {} + # add reqs to required mappings if reqs: - result.update((a, a) for a in reqs) + required.update((a, a) for a in reqs) + + # add req_args to required mappings if do_infer is set if do_infer: - result.update((a, a) for a in atom_args) - result.update(_build_rebind_dict(atom_args, rebind_args)) + required.update((a, a) for a in req_args) + + # update required mappings based on rebind_args + required.update(_build_rebind_dict(req_args, rebind_args)) + + if do_infer: + opt_args = set(all_args) - set(required) - set(ignore_list) + optional = dict((a, a) for a in opt_args) + else: + optional = {} if not reflection.accepts_kwargs(function): - all_args = reflection.get_callable_args(function, required_only=False) - extra_args = set(result) - set(all_args) + extra_args = set(required) - set(all_args) if extra_args: extra_args_str = ', '.join(sorted(extra_args)) raise ValueError('Extra arguments given to atom %s: %s' % (atom_name, extra_args_str)) # NOTE(imelnikov): don't use set to preserve order in error message - missing_args = [arg for arg in atom_args if arg not in result] + missing_args = [arg for arg in req_args if arg not in required] if missing_args: raise ValueError('Missing arguments for atom %s: %s' % (atom_name, ' ,'.join(missing_args))) - return result + return required, optional class Atom(object): @@ -146,6 +163,13 @@ class Atom(object): commences (this allows for providing atom *local* values that do not need to be provided by other atoms/dependents). :ivar inject: See parameter ``inject``. + :ivar requires: Any inputs this atom requires to function (if applicable). + NOTE(harlowja): there can be no intersection between what + this atom requires and what it produces (since this would + be an impossible dependency to satisfy). + :ivar optional: Any inputs that are optional for this atom's execute + method. + """ def __init__(self, name=None, provides=None, inject=None): @@ -153,11 +177,27 @@ class Atom(object): self.save_as = _save_as_to_mapping(provides) self.version = (1, 0) self.inject = inject + self.requires = frozenset() + self.optional = frozenset() def _build_arg_mapping(self, executor, requires=None, rebind=None, auto_extract=True, ignore_list=None): - self.rebind = _build_arg_mapping(self.name, requires, rebind, - executor, auto_extract, ignore_list) + req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind, + executor, auto_extract, + ignore_list) + + self.rebind = {} + if opt_arg: + self.rebind.update(opt_arg) + if req_arg: + self.rebind.update(req_arg) + self.requires = frozenset(req_arg.values()) + self.optional = frozenset(opt_arg.values()) + if self.inject: + inject_set = set(six.iterkeys(self.inject)) + self.requires -= inject_set + self.optional -= inject_set + out_of_order = self.provides.intersection(self.requires) if out_of_order: raise exceptions.DependencyFailure( @@ -185,16 +225,3 @@ class Atom(object): dependency to satisfy). """ return set(self.save_as) - - @property - def requires(self): - """Any inputs this atom requires to function (if applicable). - - NOTE(harlowja): there can be no intersection between what this atom - requires and what it produces (since this would be an impossible - dependency to satisfy). - """ - requires = set(self.rebind.values()) - if self.inject: - requires = requires - set(six.iterkeys(self.inject)) - return requires diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py index bd96c8992..05496d96c 100644 --- a/taskflow/engines/action_engine/actions/retry.py +++ b/taskflow/engines/action_engine/actions/retry.py @@ -54,9 +54,12 @@ class RetryAction(base.Action): def _get_retry_args(self, retry, addons=None): scope_walker = self._walker_factory(retry) - arguments = self._storage.fetch_mapped_args(retry.rebind, - atom_name=retry.name, - scope_walker=scope_walker) + arguments = self._storage.fetch_mapped_args( + retry.rebind, + atom_name=retry.name, + scope_walker=scope_walker, + optional_args=retry.optional + ) history = self._storage.get_retry_history(retry.name) arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history if addons: diff --git a/taskflow/engines/action_engine/actions/task.py b/taskflow/engines/action_engine/actions/task.py index 607b26d5d..8c64931ab 100644 --- a/taskflow/engines/action_engine/actions/task.py +++ b/taskflow/engines/action_engine/actions/task.py @@ -101,9 +101,12 @@ class TaskAction(base.Action): def schedule_execution(self, task): self.change_state(task, states.RUNNING, progress=0.0) scope_walker = self._walker_factory(task) - arguments = self._storage.fetch_mapped_args(task.rebind, - atom_name=task.name, - scope_walker=scope_walker) + arguments = self._storage.fetch_mapped_args( + task.rebind, + atom_name=task.name, + scope_walker=scope_walker, + optional_args=task.optional + ) if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): progress_callback = functools.partial(self._on_update_progress, task) @@ -124,9 +127,12 @@ class TaskAction(base.Action): def schedule_reversion(self, task): self.change_state(task, states.REVERTING, progress=0.0) scope_walker = self._walker_factory(task) - arguments = self._storage.fetch_mapped_args(task.rebind, - atom_name=task.name, - scope_walker=scope_walker) + arguments = self._storage.fetch_mapped_args( + task.rebind, + atom_name=task.name, + scope_walker=scope_walker, + optional_args=task.optional + ) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) failures = self._storage.get_failures() diff --git a/taskflow/examples/optional_arguments.py b/taskflow/examples/optional_arguments.py new file mode 100644 index 000000000..66a4d380d --- /dev/null +++ b/taskflow/examples/optional_arguments.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2015 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +from taskflow import engines +from taskflow.patterns import linear_flow +from taskflow import task + + +class TestTask(task.Task): + def execute(self, a, b=5): + result = a * b + return result + +flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result')) +flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result', + inject={'a': 10})) +flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result', + inject={'b': 1000})) + +ASSERT = True + + +class MyTest(unittest.TestCase): + def test_my_test(self): + print("Expected result = 15") + result = engines.run(flow_no_inject, store={'a': 3}) + print(result) + if ASSERT: + self.assertEqual(result, + {'a': 3, 'result': 15} + ) + + print("Expected result = 39") + result = engines.run(flow_no_inject, store={'a': 3, 'b': 7}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'b': 7, 'result': 21} + ) + + print("Expected result = 200") + result = engines.run(flow_inject_a, store={'a': 3}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'result': 50} + ) + + print("Expected result = 400") + result = engines.run(flow_inject_a, store={'a': 3, 'b': 7}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'b': 7, 'result': 70} + ) + + print("Expected result = 40") + result = engines.run(flow_inject_b, store={'a': 3}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'result': 3000} + ) + + print("Expected result = 40") + result = engines.run(flow_inject_b, store={'a': 3, 'b': 7}) + print(result) + if ASSERT: + self.assertEqual( + result, + {'a': 3, 'b': 7, 'result': 3000} + ) + +if __name__ == '__main__': + unittest.main() diff --git a/taskflow/storage.py b/taskflow/storage.py index 8734b2a0d..e96874acb 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -629,7 +629,8 @@ class Storage(object): return results def fetch_mapped_args(self, args_mapping, - atom_name=None, scope_walker=None): + atom_name=None, scope_walker=None, + optional_args=None): """Fetch arguments for an atom using an atoms argument mapping.""" def _get_results(looking_for, provider): @@ -668,10 +669,14 @@ class Storage(object): return [] with self._lock.read_lock(): + if optional_args is None: + optional_args = [] + if atom_name and atom_name not in self._atom_name_to_uuid: raise exceptions.NotFound("Unknown atom name: %s" % atom_name) if not args_mapping: return {} + # The order of lookup is the following: # # 1. Injected atom specific arguments. @@ -705,6 +710,8 @@ class Storage(object): try: possible_providers = self._reverse_mapping[name] except KeyError: + if bound_name in optional_args: + continue raise exceptions.NotFound("Name %r is not mapped as a" " produced output by any" " providers" % name) diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py index af6afb2e4..a27a811cb 100644 --- a/taskflow/tests/unit/test_storage.py +++ b/taskflow/tests/unit/test_storage.py @@ -351,6 +351,20 @@ class StorageTestMixin(object): self.assertRaises(exceptions.NotFound, s.fetch_mapped_args, {'viking': 'helmet'}) + def test_fetch_optional_args_found(self): + s = self._get_storage() + s.inject({'foo': 'bar', 'spam': 'eggs'}) + self.assertEqual(s.fetch_mapped_args({'viking': 'spam'}, + optional_args=set(['viking'])), + {'viking': 'eggs'}) + + def test_fetch_optional_args_not_found(self): + s = self._get_storage() + s.inject({'foo': 'bar', 'spam': 'eggs'}) + self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'}, + optional_args=set(['viking'])), + {}) + def test_set_and_get_task_state(self): s = self._get_storage() state = states.PENDING diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index 50e783f33..9a9ae1c90 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -120,10 +120,19 @@ class TaskTest(test.TestCase): def test_requires_ignores_optional(self): my_task = DefaultArgTask() self.assertEqual(my_task.requires, set(['spam'])) + self.assertEqual(my_task.optional, set(['eggs'])) def test_requires_allows_optional(self): my_task = DefaultArgTask(requires=('spam', 'eggs')) self.assertEqual(my_task.requires, set(['spam', 'eggs'])) + self.assertEqual(my_task.optional, set()) + + def test_rebind_includes_optional(self): + my_task = DefaultArgTask() + self.assertEqual(my_task.rebind, { + 'spam': 'spam', + 'eggs': 'eggs', + }) def test_rebind_all_args(self): my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})