diff --git a/taskflow/atom.py b/taskflow/atom.py index d4840e50..e3ed8b34 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -88,6 +88,7 @@ def _build_arg_mapping(task_name, reqs, rebind_args, function, do_infer, for arg in ignore_list: if arg in task_args: task_args.remove(arg) + result = {} if reqs: result.update((a, a) for a in reqs) @@ -135,10 +136,11 @@ class Atom(object): of this atom). """ - def __init__(self, name=None, provides=None): + def __init__(self, name=None, provides=None, inject=None): self._name = name self.save_as = _save_as_to_mapping(provides) self.version = (1, 0) + self.inject = inject def _build_arg_mapping(self, executor, requires=None, rebind=None, auto_extract=True, ignore_list=None): @@ -180,4 +182,7 @@ class Atom(object): requires and what it produces (since this would be an impossible dependency to satisfy). """ - return set(self.rebind.values()) + requires = set(self.rebind.values()) + if self.inject: + requires = requires - set(six.iterkeys(self.inject)) + return requires diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index eecba801..3291024a 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -172,6 +172,9 @@ class ActionEngine(base.EngineBase): self.storage.ensure_retry(node.name, version, node.save_as) else: self.storage.ensure_task(node.name, version, node.save_as) + if node.inject: + self.storage.inject_task_args(node.name, node.inject) + self._change_state(states.SUSPENDED) # does nothing in PENDING state @lock_utils.locked diff --git a/taskflow/engines/action_engine/retry_action.py b/taskflow/engines/action_engine/retry_action.py index a860f698..eaedf04b 100644 --- a/taskflow/engines/action_engine/retry_action.py +++ b/taskflow/engines/action_engine/retry_action.py @@ -33,7 +33,8 @@ class RetryAction(object): self._notifier = notifier def _get_retry_args(self, retry): - kwargs = self._storage.fetch_mapped_args(retry.rebind) + kwargs = self._storage.fetch_mapped_args(retry.rebind, + task_name=retry.name) kwargs['history'] = self._storage.get_retry_history(retry.name) return kwargs diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index 32c0a179..9ab8c460 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -65,7 +65,8 @@ class TaskAction(object): if not self.change_state(task, states.RUNNING, progress=0.0): raise exceptions.InvalidState("Task %s is in invalid state and" " can't be executed" % task.name) - kwargs = self._storage.fetch_mapped_args(task.rebind) + kwargs = self._storage.fetch_mapped_args(task.rebind, + task_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name) return self._task_executor.execute_task(task, task_uuid, kwargs, self._on_update_progress) @@ -81,7 +82,8 @@ class TaskAction(object): if not self.change_state(task, states.REVERTING, progress=0.0): raise exceptions.InvalidState("Task %s is in invalid state and" " can't be reverted" % task.name) - kwargs = self._storage.fetch_mapped_args(task.rebind) + kwargs = self._storage.fetch_mapped_args(task.rebind, + task_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) failures = self._storage.get_failures() diff --git a/taskflow/storage.py b/taskflow/storage.py index 35ba7d0e..e3a208a4 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -52,6 +52,7 @@ class Storage(object): self._flowdetail = flow_detail self._lock = self._lock_cls() self._transients = {} + self._injected_args = {} # NOTE(imelnikov): failure serialization looses information, # so we cache failures here, in atom name -> failure mapping. @@ -410,6 +411,10 @@ class Storage(object): if self._reset_atom(ad, state): self._with_connection(self._save_atom_detail, ad) + def inject_task_args(self, task_name, injected_args): + self._injected_args.setdefault(task_name, {}) + self._injected_args[task_name].update(injected_args) + def inject(self, pairs, transient=False): """Add values into storage. @@ -516,11 +521,19 @@ class Storage(object): pass return results - def fetch_mapped_args(self, args_mapping): + def fetch_mapped_args(self, args_mapping, task_name=None): """Fetch arguments for an atom using an atoms arguments mapping.""" with self._lock.read_lock(): - return dict((key, self.fetch(name)) - for key, name in six.iteritems(args_mapping)) + injected_args = {} + if task_name: + injected_args = self._injected_args.get(task_name, {}) + mapped_args = {} + for key, name in six.iteritems(args_mapping): + if name in injected_args: + mapped_args[key] = injected_args[name] + else: + mapped_args[key] = self.fetch(name) + return mapped_args def set_flow_state(self, state): """Set flow details state and save it.""" diff --git a/taskflow/task.py b/taskflow/task.py index 9f68710d..e66b435c 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -36,10 +36,10 @@ class BaseTask(atom.Atom): TASK_EVENTS = ('update_progress', ) - def __init__(self, name, provides=None): + def __init__(self, name, provides=None, inject=None): if name is None: name = reflection.get_class_name(self) - super(BaseTask, self).__init__(name, provides) + super(BaseTask, self).__init__(name, provides, inject=inject) # Map of events => lists of callbacks to invoke on task events. self._events_listeners = collections.defaultdict(list) @@ -172,11 +172,11 @@ class Task(BaseTask): default_provides = None def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None): + auto_extract=True, rebind=None, inject=None): """Initialize task instance.""" if provides is None: provides = self.default_provides - super(Task, self).__init__(name, provides=provides) + super(Task, self).__init__(name, provides=provides, inject=inject) self._build_arg_mapping(self.execute, requires, rebind, auto_extract) @@ -188,7 +188,7 @@ class FunctorTask(BaseTask): def __init__(self, execute, name=None, provides=None, requires=None, auto_extract=True, rebind=None, revert=None, - version=None): + version=None, inject=None): assert six.callable(execute), ("Function to use for executing must be" " callable") if revert: @@ -196,7 +196,8 @@ class FunctorTask(BaseTask): " be callable") if name is None: name = reflection.get_callable_name(execute) - super(FunctorTask, self).__init__(name, provides=provides) + super(FunctorTask, self).__init__(name, provides=provides, + inject=inject) self._execute = execute self._revert = revert if version is not None: diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py index 4e8d5bb6..0281c1ff 100644 --- a/taskflow/tests/unit/test_arguments_passing.py +++ b/taskflow/tests/unit/test_arguments_passing.py @@ -90,6 +90,36 @@ class ArgumentsPassingTest(utils.EngineTestBase): 'result': 30, }) + def test_argument_injection(self): + flow = utils.TaskMultiArgOneReturn(provides='result', + inject={'x': 1, 'y': 4, 'z': 9}) + engine = self._make_engine(flow) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'result': 14, + }) + + def test_argument_injection_rebind(self): + flow = utils.TaskMultiArgOneReturn(provides='result', + rebind=['a', 'b', 'c'], + inject={'a': 1, 'b': 4, 'c': 9}) + engine = self._make_engine(flow) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'result': 14, + }) + + def test_argument_injection_required(self): + flow = utils.TaskMultiArgOneReturn(provides='result', + requires=['a', 'b', 'c'], + inject={'x': 1, 'y': 4, 'z': 9, + 'a': 0, 'b': 0, 'c': 0}) + engine = self._make_engine(flow) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'result': 14, + }) + def test_all_arguments_mapping(self): flow = utils.TaskMultiArgOneReturn(provides='result', rebind=['a', 'b', 'c'])