Merge "Map optional arguments as well as required arguments"
This commit is contained in:
commit
761321dec7
@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern.
|
||||
|
||||
Task/retry arguments
|
||||
Set of names of task/retry arguments available as the ``requires``
|
||||
property of the task/retry instance. When a task or retry object is
|
||||
about to be executed values with these names are retrieved from storage
|
||||
and passed to the ``execute`` method of the task/retry.
|
||||
and/or ``optional`` property of the task/retry instance. When a task or
|
||||
retry object is about to be executed values with these names are
|
||||
retrieved from storage and passed to the ``execute`` method of the
|
||||
task/retry. If any names in the ``requires`` property cannot be
|
||||
found in storage, an exception will be thrown. Any names in the
|
||||
``optional`` property that cannot be found are ignored.
|
||||
|
||||
Task/retry results
|
||||
Set of names of task/retry results (what task/retry provides) available
|
||||
@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object).
|
||||
.. doctest::
|
||||
|
||||
>>> class MyTask(task.Task):
|
||||
... def execute(self, spam, eggs):
|
||||
... def execute(self, spam, eggs, bacon=None):
|
||||
... return spam + eggs
|
||||
...
|
||||
>>> sorted(MyTask().requires)
|
||||
['eggs', 'spam']
|
||||
>>> sorted(MyTask().optional)
|
||||
['bacon']
|
||||
|
||||
Inference from the method signature is the ''simplest'' way to specify
|
||||
arguments. Optional arguments (with default values), and special arguments like
|
||||
``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these
|
||||
names have special meaning/usage in python).
|
||||
arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are
|
||||
ignored during inference (as these names have special meaning/usage in python).
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> class MyTask(task.Task):
|
||||
... def execute(self, spam, eggs=()):
|
||||
... return spam + eggs
|
||||
...
|
||||
>>> MyTask().requires
|
||||
set(['spam'])
|
||||
>>>
|
||||
>>> class UniTask(task.Task):
|
||||
... def execute(self, *args, **kwargs):
|
||||
... pass
|
||||
...
|
||||
>>> UniTask().requires
|
||||
set([])
|
||||
frozenset([])
|
||||
|
||||
.. make vim sphinx highlighter* happy**
|
||||
|
||||
|
@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
|
||||
well as verify that the final argument mapping does not have missing or
|
||||
extra arguments (where applicable).
|
||||
"""
|
||||
atom_args = reflection.get_callable_args(function, required_only=True)
|
||||
|
||||
# build a list of required arguments based on function signature
|
||||
req_args = reflection.get_callable_args(function, required_only=True)
|
||||
all_args = reflection.get_callable_args(function, required_only=False)
|
||||
|
||||
# remove arguments that are part of ignore_list
|
||||
if ignore_list:
|
||||
for arg in ignore_list:
|
||||
if arg in atom_args:
|
||||
atom_args.remove(arg)
|
||||
if arg in req_args:
|
||||
req_args.remove(arg)
|
||||
else:
|
||||
ignore_list = []
|
||||
|
||||
result = {}
|
||||
required = {}
|
||||
# add reqs to required mappings
|
||||
if reqs:
|
||||
result.update((a, a) for a in reqs)
|
||||
required.update((a, a) for a in reqs)
|
||||
|
||||
# add req_args to required mappings if do_infer is set
|
||||
if do_infer:
|
||||
result.update((a, a) for a in atom_args)
|
||||
result.update(_build_rebind_dict(atom_args, rebind_args))
|
||||
required.update((a, a) for a in req_args)
|
||||
|
||||
# update required mappings based on rebind_args
|
||||
required.update(_build_rebind_dict(req_args, rebind_args))
|
||||
|
||||
if do_infer:
|
||||
opt_args = set(all_args) - set(required) - set(ignore_list)
|
||||
optional = dict((a, a) for a in opt_args)
|
||||
else:
|
||||
optional = {}
|
||||
|
||||
if not reflection.accepts_kwargs(function):
|
||||
all_args = reflection.get_callable_args(function, required_only=False)
|
||||
extra_args = set(result) - set(all_args)
|
||||
extra_args = set(required) - set(all_args)
|
||||
if extra_args:
|
||||
extra_args_str = ', '.join(sorted(extra_args))
|
||||
raise ValueError('Extra arguments given to atom %s: %s'
|
||||
% (atom_name, extra_args_str))
|
||||
|
||||
# NOTE(imelnikov): don't use set to preserve order in error message
|
||||
missing_args = [arg for arg in atom_args if arg not in result]
|
||||
missing_args = [arg for arg in req_args if arg not in required]
|
||||
if missing_args:
|
||||
raise ValueError('Missing arguments for atom %s: %s'
|
||||
% (atom_name, ' ,'.join(missing_args)))
|
||||
return result
|
||||
return required, optional
|
||||
|
||||
|
||||
class Atom(object):
|
||||
@ -146,6 +163,13 @@ class Atom(object):
|
||||
commences (this allows for providing atom *local* values that
|
||||
do not need to be provided by other atoms/dependents).
|
||||
:ivar inject: See parameter ``inject``.
|
||||
:ivar requires: Any inputs this atom requires to function (if applicable).
|
||||
NOTE(harlowja): there can be no intersection between what
|
||||
this atom requires and what it produces (since this would
|
||||
be an impossible dependency to satisfy).
|
||||
:ivar optional: Any inputs that are optional for this atom's execute
|
||||
method.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, provides=None, inject=None):
|
||||
@ -153,11 +177,27 @@ class Atom(object):
|
||||
self.save_as = _save_as_to_mapping(provides)
|
||||
self.version = (1, 0)
|
||||
self.inject = inject
|
||||
self.requires = frozenset()
|
||||
self.optional = frozenset()
|
||||
|
||||
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
||||
auto_extract=True, ignore_list=None):
|
||||
self.rebind = _build_arg_mapping(self.name, requires, rebind,
|
||||
executor, auto_extract, ignore_list)
|
||||
req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind,
|
||||
executor, auto_extract,
|
||||
ignore_list)
|
||||
|
||||
self.rebind = {}
|
||||
if opt_arg:
|
||||
self.rebind.update(opt_arg)
|
||||
if req_arg:
|
||||
self.rebind.update(req_arg)
|
||||
self.requires = frozenset(req_arg.values())
|
||||
self.optional = frozenset(opt_arg.values())
|
||||
if self.inject:
|
||||
inject_set = set(six.iterkeys(self.inject))
|
||||
self.requires -= inject_set
|
||||
self.optional -= inject_set
|
||||
|
||||
out_of_order = self.provides.intersection(self.requires)
|
||||
if out_of_order:
|
||||
raise exceptions.DependencyFailure(
|
||||
@ -185,16 +225,3 @@ class Atom(object):
|
||||
dependency to satisfy).
|
||||
"""
|
||||
return set(self.save_as)
|
||||
|
||||
@property
|
||||
def requires(self):
|
||||
"""Any inputs this atom requires to function (if applicable).
|
||||
|
||||
NOTE(harlowja): there can be no intersection between what this atom
|
||||
requires and what it produces (since this would be an impossible
|
||||
dependency to satisfy).
|
||||
"""
|
||||
requires = set(self.rebind.values())
|
||||
if self.inject:
|
||||
requires = requires - set(six.iterkeys(self.inject))
|
||||
return requires
|
||||
|
@ -54,9 +54,12 @@ class RetryAction(base.Action):
|
||||
|
||||
def _get_retry_args(self, retry, addons=None):
|
||||
scope_walker = self._walker_factory(retry)
|
||||
arguments = self._storage.fetch_mapped_args(retry.rebind,
|
||||
atom_name=retry.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
retry.rebind,
|
||||
atom_name=retry.name,
|
||||
scope_walker=scope_walker,
|
||||
optional_args=retry.optional
|
||||
)
|
||||
history = self._storage.get_retry_history(retry.name)
|
||||
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
|
||||
if addons:
|
||||
|
@ -101,9 +101,12 @@ class TaskAction(base.Action):
|
||||
def schedule_execution(self, task):
|
||||
self.change_state(task, states.RUNNING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker,
|
||||
optional_args=task.optional
|
||||
)
|
||||
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
|
||||
progress_callback = functools.partial(self._on_update_progress,
|
||||
task)
|
||||
@ -124,9 +127,12 @@ class TaskAction(base.Action):
|
||||
def schedule_reversion(self, task):
|
||||
self.change_state(task, states.REVERTING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker,
|
||||
optional_args=task.optional
|
||||
)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
task_result = self._storage.get(task.name)
|
||||
failures = self._storage.get_failures()
|
||||
|
93
taskflow/examples/optional_arguments.py
Normal file
93
taskflow/examples/optional_arguments.py
Normal file
@ -0,0 +1,93 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2015 Hewlett-Packard Development Company, L.P.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from taskflow import engines
|
||||
from taskflow.patterns import linear_flow
|
||||
from taskflow import task
|
||||
|
||||
|
||||
class TestTask(task.Task):
|
||||
def execute(self, a, b=5):
|
||||
result = a * b
|
||||
return result
|
||||
|
||||
flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result'))
|
||||
flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result',
|
||||
inject={'a': 10}))
|
||||
flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result',
|
||||
inject={'b': 1000}))
|
||||
|
||||
ASSERT = True
|
||||
|
||||
|
||||
class MyTest(unittest.TestCase):
|
||||
def test_my_test(self):
|
||||
print("Expected result = 15")
|
||||
result = engines.run(flow_no_inject, store={'a': 3})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(result,
|
||||
{'a': 3, 'result': 15}
|
||||
)
|
||||
|
||||
print("Expected result = 39")
|
||||
result = engines.run(flow_no_inject, store={'a': 3, 'b': 7})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'b': 7, 'result': 21}
|
||||
)
|
||||
|
||||
print("Expected result = 200")
|
||||
result = engines.run(flow_inject_a, store={'a': 3})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'result': 50}
|
||||
)
|
||||
|
||||
print("Expected result = 400")
|
||||
result = engines.run(flow_inject_a, store={'a': 3, 'b': 7})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'b': 7, 'result': 70}
|
||||
)
|
||||
|
||||
print("Expected result = 40")
|
||||
result = engines.run(flow_inject_b, store={'a': 3})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'result': 3000}
|
||||
)
|
||||
|
||||
print("Expected result = 40")
|
||||
result = engines.run(flow_inject_b, store={'a': 3, 'b': 7})
|
||||
print(result)
|
||||
if ASSERT:
|
||||
self.assertEqual(
|
||||
result,
|
||||
{'a': 3, 'b': 7, 'result': 3000}
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -629,7 +629,8 @@ class Storage(object):
|
||||
return results
|
||||
|
||||
def fetch_mapped_args(self, args_mapping,
|
||||
atom_name=None, scope_walker=None):
|
||||
atom_name=None, scope_walker=None,
|
||||
optional_args=None):
|
||||
"""Fetch arguments for an atom using an atoms argument mapping."""
|
||||
|
||||
def _get_results(looking_for, provider):
|
||||
@ -668,10 +669,14 @@ class Storage(object):
|
||||
return []
|
||||
|
||||
with self._lock.read_lock():
|
||||
if optional_args is None:
|
||||
optional_args = []
|
||||
|
||||
if atom_name and atom_name not in self._atom_name_to_uuid:
|
||||
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
|
||||
if not args_mapping:
|
||||
return {}
|
||||
|
||||
# The order of lookup is the following:
|
||||
#
|
||||
# 1. Injected atom specific arguments.
|
||||
@ -705,6 +710,8 @@ class Storage(object):
|
||||
try:
|
||||
possible_providers = self._reverse_mapping[name]
|
||||
except KeyError:
|
||||
if bound_name in optional_args:
|
||||
continue
|
||||
raise exceptions.NotFound("Name %r is not mapped as a"
|
||||
" produced output by any"
|
||||
" providers" % name)
|
||||
|
@ -351,6 +351,20 @@ class StorageTestMixin(object):
|
||||
self.assertRaises(exceptions.NotFound,
|
||||
s.fetch_mapped_args, {'viking': 'helmet'})
|
||||
|
||||
def test_fetch_optional_args_found(self):
|
||||
s = self._get_storage()
|
||||
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
||||
self.assertEqual(s.fetch_mapped_args({'viking': 'spam'},
|
||||
optional_args=set(['viking'])),
|
||||
{'viking': 'eggs'})
|
||||
|
||||
def test_fetch_optional_args_not_found(self):
|
||||
s = self._get_storage()
|
||||
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
||||
self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'},
|
||||
optional_args=set(['viking'])),
|
||||
{})
|
||||
|
||||
def test_set_and_get_task_state(self):
|
||||
s = self._get_storage()
|
||||
state = states.PENDING
|
||||
|
@ -120,10 +120,19 @@ class TaskTest(test.TestCase):
|
||||
def test_requires_ignores_optional(self):
|
||||
my_task = DefaultArgTask()
|
||||
self.assertEqual(my_task.requires, set(['spam']))
|
||||
self.assertEqual(my_task.optional, set(['eggs']))
|
||||
|
||||
def test_requires_allows_optional(self):
|
||||
my_task = DefaultArgTask(requires=('spam', 'eggs'))
|
||||
self.assertEqual(my_task.requires, set(['spam', 'eggs']))
|
||||
self.assertEqual(my_task.optional, set())
|
||||
|
||||
def test_rebind_includes_optional(self):
|
||||
my_task = DefaultArgTask()
|
||||
self.assertEqual(my_task.rebind, {
|
||||
'spam': 'spam',
|
||||
'eggs': 'eggs',
|
||||
})
|
||||
|
||||
def test_rebind_all_args(self):
|
||||
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})
|
||||
|
Loading…
x
Reference in New Issue
Block a user