Merge "add the ability to inject arguments into tasks at task creation"
This commit is contained in:
@@ -88,6 +88,7 @@ def _build_arg_mapping(task_name, reqs, rebind_args, function, do_infer,
|
||||
for arg in ignore_list:
|
||||
if arg in task_args:
|
||||
task_args.remove(arg)
|
||||
|
||||
result = {}
|
||||
if reqs:
|
||||
result.update((a, a) for a in reqs)
|
||||
@@ -135,10 +136,11 @@ class Atom(object):
|
||||
of this atom).
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, provides=None):
|
||||
def __init__(self, name=None, provides=None, inject=None):
|
||||
self._name = name
|
||||
self.save_as = _save_as_to_mapping(provides)
|
||||
self.version = (1, 0)
|
||||
self.inject = inject
|
||||
|
||||
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
||||
auto_extract=True, ignore_list=None):
|
||||
@@ -180,4 +182,7 @@ class Atom(object):
|
||||
requires and what it produces (since this would be an impossible
|
||||
dependency to satisfy).
|
||||
"""
|
||||
return set(self.rebind.values())
|
||||
requires = set(self.rebind.values())
|
||||
if self.inject:
|
||||
requires = requires - set(six.iterkeys(self.inject))
|
||||
return requires
|
||||
|
||||
@@ -172,6 +172,9 @@ class ActionEngine(base.EngineBase):
|
||||
self.storage.ensure_retry(node.name, version, node.save_as)
|
||||
else:
|
||||
self.storage.ensure_task(node.name, version, node.save_as)
|
||||
if node.inject:
|
||||
self.storage.inject_task_args(node.name, node.inject)
|
||||
|
||||
self._change_state(states.SUSPENDED) # does nothing in PENDING state
|
||||
|
||||
@lock_utils.locked
|
||||
|
||||
@@ -33,7 +33,8 @@ class RetryAction(object):
|
||||
self._notifier = notifier
|
||||
|
||||
def _get_retry_args(self, retry):
|
||||
kwargs = self._storage.fetch_mapped_args(retry.rebind)
|
||||
kwargs = self._storage.fetch_mapped_args(retry.rebind,
|
||||
task_name=retry.name)
|
||||
kwargs['history'] = self._storage.get_retry_history(retry.name)
|
||||
return kwargs
|
||||
|
||||
|
||||
@@ -65,7 +65,8 @@ class TaskAction(object):
|
||||
if not self.change_state(task, states.RUNNING, progress=0.0):
|
||||
raise exceptions.InvalidState("Task %s is in invalid state and"
|
||||
" can't be executed" % task.name)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind,
|
||||
task_name=task.name)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
return self._task_executor.execute_task(task, task_uuid, kwargs,
|
||||
self._on_update_progress)
|
||||
@@ -81,7 +82,8 @@ class TaskAction(object):
|
||||
if not self.change_state(task, states.REVERTING, progress=0.0):
|
||||
raise exceptions.InvalidState("Task %s is in invalid state and"
|
||||
" can't be reverted" % task.name)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind,
|
||||
task_name=task.name)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
task_result = self._storage.get(task.name)
|
||||
failures = self._storage.get_failures()
|
||||
|
||||
@@ -52,6 +52,7 @@ class Storage(object):
|
||||
self._flowdetail = flow_detail
|
||||
self._lock = self._lock_cls()
|
||||
self._transients = {}
|
||||
self._injected_args = {}
|
||||
|
||||
# NOTE(imelnikov): failure serialization looses information,
|
||||
# so we cache failures here, in atom name -> failure mapping.
|
||||
@@ -410,6 +411,10 @@ class Storage(object):
|
||||
if self._reset_atom(ad, state):
|
||||
self._with_connection(self._save_atom_detail, ad)
|
||||
|
||||
def inject_task_args(self, task_name, injected_args):
|
||||
self._injected_args.setdefault(task_name, {})
|
||||
self._injected_args[task_name].update(injected_args)
|
||||
|
||||
def inject(self, pairs, transient=False):
|
||||
"""Add values into storage.
|
||||
|
||||
@@ -516,11 +521,19 @@ class Storage(object):
|
||||
pass
|
||||
return results
|
||||
|
||||
def fetch_mapped_args(self, args_mapping):
|
||||
def fetch_mapped_args(self, args_mapping, task_name=None):
|
||||
"""Fetch arguments for an atom using an atoms arguments mapping."""
|
||||
with self._lock.read_lock():
|
||||
return dict((key, self.fetch(name))
|
||||
for key, name in six.iteritems(args_mapping))
|
||||
injected_args = {}
|
||||
if task_name:
|
||||
injected_args = self._injected_args.get(task_name, {})
|
||||
mapped_args = {}
|
||||
for key, name in six.iteritems(args_mapping):
|
||||
if name in injected_args:
|
||||
mapped_args[key] = injected_args[name]
|
||||
else:
|
||||
mapped_args[key] = self.fetch(name)
|
||||
return mapped_args
|
||||
|
||||
def set_flow_state(self, state):
|
||||
"""Set flow details state and save it."""
|
||||
|
||||
@@ -36,10 +36,10 @@ class BaseTask(atom.Atom):
|
||||
|
||||
TASK_EVENTS = ('update_progress', )
|
||||
|
||||
def __init__(self, name, provides=None):
|
||||
def __init__(self, name, provides=None, inject=None):
|
||||
if name is None:
|
||||
name = reflection.get_class_name(self)
|
||||
super(BaseTask, self).__init__(name, provides)
|
||||
super(BaseTask, self).__init__(name, provides, inject=inject)
|
||||
# Map of events => lists of callbacks to invoke on task events.
|
||||
self._events_listeners = collections.defaultdict(list)
|
||||
|
||||
@@ -172,11 +172,11 @@ class Task(BaseTask):
|
||||
default_provides = None
|
||||
|
||||
def __init__(self, name=None, provides=None, requires=None,
|
||||
auto_extract=True, rebind=None):
|
||||
auto_extract=True, rebind=None, inject=None):
|
||||
"""Initialize task instance."""
|
||||
if provides is None:
|
||||
provides = self.default_provides
|
||||
super(Task, self).__init__(name, provides=provides)
|
||||
super(Task, self).__init__(name, provides=provides, inject=inject)
|
||||
self._build_arg_mapping(self.execute, requires, rebind, auto_extract)
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ class FunctorTask(BaseTask):
|
||||
|
||||
def __init__(self, execute, name=None, provides=None,
|
||||
requires=None, auto_extract=True, rebind=None, revert=None,
|
||||
version=None):
|
||||
version=None, inject=None):
|
||||
assert six.callable(execute), ("Function to use for executing must be"
|
||||
" callable")
|
||||
if revert:
|
||||
@@ -196,7 +196,8 @@ class FunctorTask(BaseTask):
|
||||
" be callable")
|
||||
if name is None:
|
||||
name = reflection.get_callable_name(execute)
|
||||
super(FunctorTask, self).__init__(name, provides=provides)
|
||||
super(FunctorTask, self).__init__(name, provides=provides,
|
||||
inject=inject)
|
||||
self._execute = execute
|
||||
self._revert = revert
|
||||
if version is not None:
|
||||
|
||||
@@ -90,6 +90,36 @@ class ArgumentsPassingTest(utils.EngineTestBase):
|
||||
'result': 30,
|
||||
})
|
||||
|
||||
def test_argument_injection(self):
|
||||
flow = utils.TaskMultiArgOneReturn(provides='result',
|
||||
inject={'x': 1, 'y': 4, 'z': 9})
|
||||
engine = self._make_engine(flow)
|
||||
engine.run()
|
||||
self.assertEqual(engine.storage.fetch_all(), {
|
||||
'result': 14,
|
||||
})
|
||||
|
||||
def test_argument_injection_rebind(self):
|
||||
flow = utils.TaskMultiArgOneReturn(provides='result',
|
||||
rebind=['a', 'b', 'c'],
|
||||
inject={'a': 1, 'b': 4, 'c': 9})
|
||||
engine = self._make_engine(flow)
|
||||
engine.run()
|
||||
self.assertEqual(engine.storage.fetch_all(), {
|
||||
'result': 14,
|
||||
})
|
||||
|
||||
def test_argument_injection_required(self):
|
||||
flow = utils.TaskMultiArgOneReturn(provides='result',
|
||||
requires=['a', 'b', 'c'],
|
||||
inject={'x': 1, 'y': 4, 'z': 9,
|
||||
'a': 0, 'b': 0, 'c': 0})
|
||||
engine = self._make_engine(flow)
|
||||
engine.run()
|
||||
self.assertEqual(engine.storage.fetch_all(), {
|
||||
'result': 14,
|
||||
})
|
||||
|
||||
def test_all_arguments_mapping(self):
|
||||
flow = utils.TaskMultiArgOneReturn(provides='result',
|
||||
rebind=['a', 'b', 'c'])
|
||||
|
||||
Reference in New Issue
Block a user