Map optional arguments as well as required arguments
Optional arguments that are not explicitly required are being ignored when arguments are being mapped based on inference from atoms' execute method signatures. This patch adds support for mapping optional arguments in addition to required arguments. Change-Id: I440c02dcd901a563df512e33754b13e3c05d4155
This commit is contained in:
parent
eae693406e
commit
7f0c457e72
@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern.
|
||||
|
||||
Task/retry arguments
|
||||
Set of names of task/retry arguments available as the ``requires``
|
||||
property of the task/retry instance. When a task or retry object is
|
||||
about to be executed values with these names are retrieved from storage
|
||||
and passed to the ``execute`` method of the task/retry.
|
||||
and/or ``optional`` property of the task/retry instance. When a task or
|
||||
retry object is about to be executed values with these names are
|
||||
retrieved from storage and passed to the ``execute`` method of the
|
||||
task/retry. If any names in the ``requires`` property cannot be
|
||||
found in storage, an exception will be thrown. Any names in the
|
||||
``optional`` property that cannot be found are ignored.
|
||||
|
||||
Task/retry results
|
||||
Set of names of task/retry results (what task/retry provides) available
|
||||
@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object).
|
||||
.. doctest::
|
||||
|
||||
>>> class MyTask(task.Task):
|
||||
... def execute(self, spam, eggs):
|
||||
... def execute(self, spam, eggs, bacon=None):
|
||||
... return spam + eggs
|
||||
...
|
||||
>>> sorted(MyTask().requires)
|
||||
['eggs', 'spam']
|
||||
>>> sorted(MyTask().optional)
|
||||
['bacon']
|
||||
|
||||
Inference from the method signature is the ''simplest'' way to specify
|
||||
arguments. Optional arguments (with default values), and special arguments like
|
||||
``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these
|
||||
names have special meaning/usage in python).
|
||||
arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are
|
||||
ignored during inference (as these names have special meaning/usage in python).
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> class MyTask(task.Task):
|
||||
... def execute(self, spam, eggs=()):
|
||||
... return spam + eggs
|
||||
...
|
||||
>>> MyTask().requires
|
||||
set(['spam'])
|
||||
>>>
|
||||
>>> class UniTask(task.Task):
|
||||
... def execute(self, *args, **kwargs):
|
||||
... pass
|
||||
...
|
||||
>>> UniTask().requires
|
||||
set([])
|
||||
frozenset([])
|
||||
|
||||
.. make vim sphinx highlighter* happy**
|
||||
|
||||
|
@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
|
||||
well as verify that the final argument mapping does not have missing or
|
||||
extra arguments (where applicable).
|
||||
"""
|
||||
atom_args = reflection.get_callable_args(function, required_only=True)
|
||||
|
||||
# build a list of required arguments based on function signature
|
||||
req_args = reflection.get_callable_args(function, required_only=True)
|
||||
all_args = reflection.get_callable_args(function, required_only=False)
|
||||
|
||||
# remove arguments that are part of ignore_list
|
||||
if ignore_list:
|
||||
for arg in ignore_list:
|
||||
if arg in atom_args:
|
||||
atom_args.remove(arg)
|
||||
if arg in req_args:
|
||||
req_args.remove(arg)
|
||||
else:
|
||||
ignore_list = []
|
||||
|
||||
result = {}
|
||||
required = {}
|
||||
# add reqs to required mappings
|
||||
if reqs:
|
||||
result.update((a, a) for a in reqs)
|
||||
required.update((a, a) for a in reqs)
|
||||
|
||||
# add req_args to required mappings if do_infer is set
|
||||
if do_infer:
|
||||
result.update((a, a) for a in atom_args)
|
||||
result.update(_build_rebind_dict(atom_args, rebind_args))
|
||||
required.update((a, a) for a in req_args)
|
||||
|
||||
# update required mappings based on rebind_args
|
||||
required.update(_build_rebind_dict(req_args, rebind_args))
|
||||
|
||||
if do_infer:
|
||||
opt_args = set(all_args) - set(required) - set(ignore_list)
|
||||
optional = dict((a, a) for a in opt_args)
|
||||
else:
|
||||
optional = {}
|
||||
|
||||
if not reflection.accepts_kwargs(function):
|
||||
all_args = reflection.get_callable_args(function, required_only=False)
|
||||
extra_args = set(result) - set(all_args)
|
||||
extra_args = set(required) - set(all_args)
|
||||
if extra_args:
|
||||
extra_args_str = ', '.join(sorted(extra_args))
|
||||
raise ValueError('Extra arguments given to atom %s: %s'
|
||||
% (atom_name, extra_args_str))
|
||||
|
||||
# NOTE(imelnikov): don't use set to preserve order in error message
|
||||
missing_args = [arg for arg in atom_args if arg not in result]
|
||||
missing_args = [arg for arg in req_args if arg not in required]
|
||||
if missing_args:
|
||||
raise ValueError('Missing arguments for atom %s: %s'
|
||||
% (atom_name, ' ,'.join(missing_args)))
|
||||
return result
|
||||
return required, optional
|
||||
|
||||
|
||||
class Atom(object):
|
||||
@ -146,6 +163,13 @@ class Atom(object):
|
||||
commences (this allows for providing atom *local* values that
|
||||
do not need to be provided by other atoms/dependents).
|
||||
:ivar inject: See parameter ``inject``.
|
||||
:ivar requires: Any inputs this atom requires to function (if applicable).
|
||||
NOTE(harlowja): there can be no intersection between what
|
||||
this atom requires and what it produces (since this would
|
||||
be an impossible dependency to satisfy).
|
||||
:ivar optional: Any inputs that are optional for this atom's execute
|
||||
method.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, provides=None, inject=None):
|
||||
@ -153,11 +177,27 @@ class Atom(object):
|
||||
self.save_as = _save_as_to_mapping(provides)
|
||||
self.version = (1, 0)
|
||||
self.inject = inject
|
||||
self.requires = frozenset()
|
||||
self.optional = frozenset()
|
||||
|
||||
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
||||
auto_extract=True, ignore_list=None):
|
||||
self.rebind = _build_arg_mapping(self.name, requires, rebind,
|
||||
executor, auto_extract, ignore_list)
|
||||
req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind,
|
||||
executor, auto_extract,
|
||||
ignore_list)
|
||||
|
||||
self.rebind = {}
|
||||
if opt_arg:
|
||||
self.rebind.update(opt_arg)
|
||||
if req_arg:
|
||||
self.rebind.update(req_arg)
|
||||
self.requires = frozenset(req_arg.values())
|
||||
self.optional = frozenset(opt_arg.values())
|
||||
if self.inject:
|
||||
inject_set = set(six.iterkeys(self.inject))
|
||||
self.requires -= inject_set
|
||||
self.optional -= inject_set
|
||||
|
||||
out_of_order = self.provides.intersection(self.requires)
|
||||
if out_of_order:
|
||||
raise exceptions.DependencyFailure(
|
||||
@ -185,16 +225,3 @@ class Atom(object):
|
||||
dependency to satisfy).
|
||||
"""
|
||||
return set(self.save_as)
|
||||
|
||||
@property
|
||||
def requires(self):
|
||||
"""Any inputs this atom requires to function (if applicable).
|
||||
|
||||
NOTE(harlowja): there can be no intersection between what this atom
|
||||
requires and what it produces (since this would be an impossible
|
||||
dependency to satisfy).
|
||||
"""
|
||||
requires = set(self.rebind.values())
|
||||
if self.inject:
|
||||
requires = requires - set(six.iterkeys(self.inject))
|
||||
return requires
|
||||
|
@ -54,9 +54,12 @@ class RetryAction(base.Action):
|
||||
|
||||
def _get_retry_args(self, retry, addons=None):
|
||||
scope_walker = self._walker_factory(retry)
|
||||
arguments = self._storage.fetch_mapped_args(retry.rebind,
|
||||
atom_name=retry.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
retry.rebind,
|
||||
atom_name=retry.name,
|
||||
scope_walker=scope_walker,
|
||||
optional_args=retry.optional
|
||||
)
|
||||
history = self._storage.get_retry_history(retry.name)
|
||||
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
|
||||
if addons:
|
||||
|
@ -101,9 +101,12 @@ class TaskAction(base.Action):
|
||||
def schedule_execution(self, task):
|
||||
self.change_state(task, states.RUNNING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker,
|
||||
optional_args=task.optional
|
||||
)
|
||||
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
|
||||
progress_callback = functools.partial(self._on_update_progress,
|
||||
task)
|
||||
@ -124,9 +127,12 @@ class TaskAction(base.Action):
|
||||
def schedule_reversion(self, task):
|
||||
self.change_state(task, states.REVERTING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker,
|
||||
optional_args=task.optional
|
||||
)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
task_result = self._storage.get(task.name)
|
||||
failures = self._storage.get_failures()
|
||||
|
93
taskflow/examples/optional_arguments.py
Normal file
93
taskflow/examples/optional_arguments.py
Normal file
@ -0,0 +1,93 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2015 Hewlett-Packard Development Company, L.P.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from taskflow import engines
|
||||
from taskflow.patterns import linear_flow
|
||||
from taskflow import task
|
||||
|
||||
|
||||
class TestTask(task.Task):
|
||||
def execute(self, a, b=5):
|
||||
result = a * b
|
||||
return result
|
||||
|
||||
flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result'))
|
||||
flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result',
|
||||
inject={'a': 10}))
|
||||
flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result',
|
||||
inject={'b': 1000}))
|
||||
|
||||
ASSERT = True
|
||||
|
||||
|
||||
class MyTest(unittest.TestCase):
|
||||
def test_my_test(self):
|
||||
print("Expected result = 15")
|
||||
result = engines.run(flow_no_inject, store={'a': 3})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(result,
|
||||
{'a': 3, 'result': 15}
|
||||
)
|
||||
|
||||
print("Expected result = 39")
|
||||
result = engines.run(flow_no_inject, store={'a': 3, 'b': 7})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'b': 7, 'result': 21}
|
||||
)
|
||||
|
||||
print("Expected result = 200")
|
||||
result = engines.run(flow_inject_a, store={'a': 3})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'result': 50}
|
||||
)
|
||||
|
||||
print("Expected result = 400")
|
||||
result = engines.run(flow_inject_a, store={'a': 3, 'b': 7})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'b': 7, 'result': 70}
|
||||
)
|
||||
|
||||
print("Expected result = 40")
|
||||
result = engines.run(flow_inject_b, store={'a': 3})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'result': 3000}
|
||||
)
|
||||
|
||||
print("Expected result = 40")
|
||||
result = engines.run(flow_inject_b, store={'a': 3, 'b': 7})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'b': 7, 'result': 3000}
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -635,7 +635,8 @@ class Storage(object):
|
||||
return results
|
||||
|
||||
def fetch_mapped_args(self, args_mapping,
|
||||
atom_name=None, scope_walker=None):
|
||||
atom_name=None, scope_walker=None,
|
||||
optional_args=None):
|
||||
"""Fetch arguments for an atom using an atoms argument mapping."""
|
||||
|
||||
def _get_results(looking_for, provider):
|
||||
@ -674,10 +675,14 @@ class Storage(object):
|
||||
return []
|
||||
|
||||
with self._lock.read_lock():
|
||||
if optional_args is None:
|
||||
optional_args = []
|
||||
|
||||
if atom_name and atom_name not in self._atom_name_to_uuid:
|
||||
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
|
||||
if not args_mapping:
|
||||
return {}
|
||||
|
||||
# The order of lookup is the following:
|
||||
#
|
||||
# 1. Injected atom specific arguments.
|
||||
@ -711,6 +716,8 @@ class Storage(object):
|
||||
try:
|
||||
possible_providers = self._reverse_mapping[name]
|
||||
except KeyError:
|
||||
if bound_name in optional_args:
|
||||
continue
|
||||
raise exceptions.NotFound("Name %r is not mapped as a"
|
||||
" produced output by any"
|
||||
" providers" % name)
|
||||
|
@ -354,6 +354,20 @@ class StorageTestMixin(object):
|
||||
self.assertRaises(exceptions.NotFound,
|
||||
s.fetch_mapped_args, {'viking': 'helmet'})
|
||||
|
||||
def test_fetch_optional_args_found(self):
|
||||
s = self._get_storage()
|
||||
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
||||
self.assertEqual(s.fetch_mapped_args({'viking': 'spam'},
|
||||
optional_args=set(['viking'])),
|
||||
{'viking': 'eggs'})
|
||||
|
||||
def test_fetch_optional_args_not_found(self):
|
||||
s = self._get_storage()
|
||||
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
||||
self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'},
|
||||
optional_args=set(['viking'])),
|
||||
{})
|
||||
|
||||
def test_set_and_get_task_state(self):
|
||||
s = self._get_storage()
|
||||
state = states.PENDING
|
||||
|
@ -120,10 +120,19 @@ class TaskTest(test.TestCase):
|
||||
def test_requires_ignores_optional(self):
|
||||
my_task = DefaultArgTask()
|
||||
self.assertEqual(my_task.requires, set(['spam']))
|
||||
self.assertEqual(my_task.optional, set(['eggs']))
|
||||
|
||||
def test_requires_allows_optional(self):
|
||||
my_task = DefaultArgTask(requires=('spam', 'eggs'))
|
||||
self.assertEqual(my_task.requires, set(['spam', 'eggs']))
|
||||
self.assertEqual(my_task.optional, set())
|
||||
|
||||
def test_rebind_includes_optional(self):
|
||||
my_task = DefaultArgTask()
|
||||
self.assertEqual(my_task.rebind, {
|
||||
'spam': 'spam',
|
||||
'eggs': 'eggs',
|
||||
})
|
||||
|
||||
def test_rebind_all_args(self):
|
||||
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})
|
||||
|
Loading…
x
Reference in New Issue
Block a user