From 46fc1dd9ee0ad40d3f9bf4c589061566d745294a Mon Sep 17 00:00:00 2001 From: Greg Hill Date: Thu, 22 May 2014 17:23:57 -0500 Subject: [PATCH] add the ability to inject arguments into tasks at task creation Similar to the rebind functionality that lets you rename parameters from those in the store, inject lets you inject arbitrary key/value pairs that will be sent to your task at task creation time. This allows for flow and flow factories to reuse tasks with differing parameters without jumping through a lot of hoops. Change-Id: If167962811d22054b89d7d35a33d4ec5cb2cd648 Implements: blueprint wbe-workers-endpoints-constructors --- taskflow/atom.py | 9 ++++-- taskflow/engines/action_engine/engine.py | 3 ++ .../engines/action_engine/retry_action.py | 3 +- taskflow/engines/action_engine/task_action.py | 6 ++-- taskflow/storage.py | 19 ++++++++++-- taskflow/task.py | 13 ++++---- taskflow/tests/unit/test_arguments_passing.py | 30 +++++++++++++++++++ 7 files changed, 69 insertions(+), 14 deletions(-) diff --git a/taskflow/atom.py b/taskflow/atom.py index d4840e50..e3ed8b34 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -88,6 +88,7 @@ def _build_arg_mapping(task_name, reqs, rebind_args, function, do_infer, for arg in ignore_list: if arg in task_args: task_args.remove(arg) + result = {} if reqs: result.update((a, a) for a in reqs) @@ -135,10 +136,11 @@ class Atom(object): of this atom). """ - def __init__(self, name=None, provides=None): + def __init__(self, name=None, provides=None, inject=None): self._name = name self.save_as = _save_as_to_mapping(provides) self.version = (1, 0) + self.inject = inject def _build_arg_mapping(self, executor, requires=None, rebind=None, auto_extract=True, ignore_list=None): @@ -180,4 +182,7 @@ class Atom(object): requires and what it produces (since this would be an impossible dependency to satisfy). """ - return set(self.rebind.values()) + requires = set(self.rebind.values()) + if self.inject: + requires = requires - set(six.iterkeys(self.inject)) + return requires diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index eecba801..3291024a 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -172,6 +172,9 @@ class ActionEngine(base.EngineBase): self.storage.ensure_retry(node.name, version, node.save_as) else: self.storage.ensure_task(node.name, version, node.save_as) + if node.inject: + self.storage.inject_task_args(node.name, node.inject) + self._change_state(states.SUSPENDED) # does nothing in PENDING state @lock_utils.locked diff --git a/taskflow/engines/action_engine/retry_action.py b/taskflow/engines/action_engine/retry_action.py index a860f698..eaedf04b 100644 --- a/taskflow/engines/action_engine/retry_action.py +++ b/taskflow/engines/action_engine/retry_action.py @@ -33,7 +33,8 @@ class RetryAction(object): self._notifier = notifier def _get_retry_args(self, retry): - kwargs = self._storage.fetch_mapped_args(retry.rebind) + kwargs = self._storage.fetch_mapped_args(retry.rebind, + task_name=retry.name) kwargs['history'] = self._storage.get_retry_history(retry.name) return kwargs diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index 32c0a179..9ab8c460 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -65,7 +65,8 @@ class TaskAction(object): if not self.change_state(task, states.RUNNING, progress=0.0): raise exceptions.InvalidState("Task %s is in invalid state and" " can't be executed" % task.name) - kwargs = self._storage.fetch_mapped_args(task.rebind) + kwargs = self._storage.fetch_mapped_args(task.rebind, + task_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name) return self._task_executor.execute_task(task, task_uuid, kwargs, self._on_update_progress) @@ -81,7 +82,8 @@ class TaskAction(object): if not self.change_state(task, states.REVERTING, progress=0.0): raise exceptions.InvalidState("Task %s is in invalid state and" " can't be reverted" % task.name) - kwargs = self._storage.fetch_mapped_args(task.rebind) + kwargs = self._storage.fetch_mapped_args(task.rebind, + task_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) failures = self._storage.get_failures() diff --git a/taskflow/storage.py b/taskflow/storage.py index 35ba7d0e..e3a208a4 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -52,6 +52,7 @@ class Storage(object): self._flowdetail = flow_detail self._lock = self._lock_cls() self._transients = {} + self._injected_args = {} # NOTE(imelnikov): failure serialization looses information, # so we cache failures here, in atom name -> failure mapping. @@ -410,6 +411,10 @@ class Storage(object): if self._reset_atom(ad, state): self._with_connection(self._save_atom_detail, ad) + def inject_task_args(self, task_name, injected_args): + self._injected_args.setdefault(task_name, {}) + self._injected_args[task_name].update(injected_args) + def inject(self, pairs, transient=False): """Add values into storage. @@ -516,11 +521,19 @@ class Storage(object): pass return results - def fetch_mapped_args(self, args_mapping): + def fetch_mapped_args(self, args_mapping, task_name=None): """Fetch arguments for an atom using an atoms arguments mapping.""" with self._lock.read_lock(): - return dict((key, self.fetch(name)) - for key, name in six.iteritems(args_mapping)) + injected_args = {} + if task_name: + injected_args = self._injected_args.get(task_name, {}) + mapped_args = {} + for key, name in six.iteritems(args_mapping): + if name in injected_args: + mapped_args[key] = injected_args[name] + else: + mapped_args[key] = self.fetch(name) + return mapped_args def set_flow_state(self, state): """Set flow details state and save it.""" diff --git a/taskflow/task.py b/taskflow/task.py index 9f68710d..e66b435c 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -36,10 +36,10 @@ class BaseTask(atom.Atom): TASK_EVENTS = ('update_progress', ) - def __init__(self, name, provides=None): + def __init__(self, name, provides=None, inject=None): if name is None: name = reflection.get_class_name(self) - super(BaseTask, self).__init__(name, provides) + super(BaseTask, self).__init__(name, provides, inject=inject) # Map of events => lists of callbacks to invoke on task events. self._events_listeners = collections.defaultdict(list) @@ -172,11 +172,11 @@ class Task(BaseTask): default_provides = None def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None): + auto_extract=True, rebind=None, inject=None): """Initialize task instance.""" if provides is None: provides = self.default_provides - super(Task, self).__init__(name, provides=provides) + super(Task, self).__init__(name, provides=provides, inject=inject) self._build_arg_mapping(self.execute, requires, rebind, auto_extract) @@ -188,7 +188,7 @@ class FunctorTask(BaseTask): def __init__(self, execute, name=None, provides=None, requires=None, auto_extract=True, rebind=None, revert=None, - version=None): + version=None, inject=None): assert six.callable(execute), ("Function to use for executing must be" " callable") if revert: @@ -196,7 +196,8 @@ class FunctorTask(BaseTask): " be callable") if name is None: name = reflection.get_callable_name(execute) - super(FunctorTask, self).__init__(name, provides=provides) + super(FunctorTask, self).__init__(name, provides=provides, + inject=inject) self._execute = execute self._revert = revert if version is not None: diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py index 4e8d5bb6..0281c1ff 100644 --- a/taskflow/tests/unit/test_arguments_passing.py +++ b/taskflow/tests/unit/test_arguments_passing.py @@ -90,6 +90,36 @@ class ArgumentsPassingTest(utils.EngineTestBase): 'result': 30, }) + def test_argument_injection(self): + flow = utils.TaskMultiArgOneReturn(provides='result', + inject={'x': 1, 'y': 4, 'z': 9}) + engine = self._make_engine(flow) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'result': 14, + }) + + def test_argument_injection_rebind(self): + flow = utils.TaskMultiArgOneReturn(provides='result', + rebind=['a', 'b', 'c'], + inject={'a': 1, 'b': 4, 'c': 9}) + engine = self._make_engine(flow) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'result': 14, + }) + + def test_argument_injection_required(self): + flow = utils.TaskMultiArgOneReturn(provides='result', + requires=['a', 'b', 'c'], + inject={'x': 1, 'y': 4, 'z': 9, + 'a': 0, 'b': 0, 'c': 0}) + engine = self._make_engine(flow) + engine.run() + self.assertEqual(engine.storage.fetch_all(), { + 'result': 14, + }) + def test_all_arguments_mapping(self): flow = utils.TaskMultiArgOneReturn(provides='result', rebind=['a', 'b', 'c'])