Merge "Map optional arguments as well as required arguments"
This commit is contained in:
commit
761321dec7
@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern.
|
|||||||
|
|
||||||
Task/retry arguments
|
Task/retry arguments
|
||||||
Set of names of task/retry arguments available as the ``requires``
|
Set of names of task/retry arguments available as the ``requires``
|
||||||
property of the task/retry instance. When a task or retry object is
|
and/or ``optional`` property of the task/retry instance. When a task or
|
||||||
about to be executed values with these names are retrieved from storage
|
retry object is about to be executed values with these names are
|
||||||
and passed to the ``execute`` method of the task/retry.
|
retrieved from storage and passed to the ``execute`` method of the
|
||||||
|
task/retry. If any names in the ``requires`` property cannot be
|
||||||
|
found in storage, an exception will be thrown. Any names in the
|
||||||
|
``optional`` property that cannot be found are ignored.
|
||||||
|
|
||||||
Task/retry results
|
Task/retry results
|
||||||
Set of names of task/retry results (what task/retry provides) available
|
Set of names of task/retry results (what task/retry provides) available
|
||||||
@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object).
|
|||||||
.. doctest::
|
.. doctest::
|
||||||
|
|
||||||
>>> class MyTask(task.Task):
|
>>> class MyTask(task.Task):
|
||||||
... def execute(self, spam, eggs):
|
... def execute(self, spam, eggs, bacon=None):
|
||||||
... return spam + eggs
|
... return spam + eggs
|
||||||
...
|
...
|
||||||
>>> sorted(MyTask().requires)
|
>>> sorted(MyTask().requires)
|
||||||
['eggs', 'spam']
|
['eggs', 'spam']
|
||||||
|
>>> sorted(MyTask().optional)
|
||||||
|
['bacon']
|
||||||
|
|
||||||
Inference from the method signature is the ''simplest'' way to specify
|
Inference from the method signature is the ''simplest'' way to specify
|
||||||
arguments. Optional arguments (with default values), and special arguments like
|
arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are
|
||||||
``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these
|
ignored during inference (as these names have special meaning/usage in python).
|
||||||
names have special meaning/usage in python).
|
|
||||||
|
|
||||||
.. doctest::
|
.. doctest::
|
||||||
|
|
||||||
>>> class MyTask(task.Task):
|
|
||||||
... def execute(self, spam, eggs=()):
|
|
||||||
... return spam + eggs
|
|
||||||
...
|
|
||||||
>>> MyTask().requires
|
|
||||||
set(['spam'])
|
|
||||||
>>>
|
|
||||||
>>> class UniTask(task.Task):
|
>>> class UniTask(task.Task):
|
||||||
... def execute(self, *args, **kwargs):
|
... def execute(self, *args, **kwargs):
|
||||||
... pass
|
... pass
|
||||||
...
|
...
|
||||||
>>> UniTask().requires
|
>>> UniTask().requires
|
||||||
set([])
|
frozenset([])
|
||||||
|
|
||||||
.. make vim sphinx highlighter* happy**
|
.. make vim sphinx highlighter* happy**
|
||||||
|
|
||||||
|
@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
|
|||||||
well as verify that the final argument mapping does not have missing or
|
well as verify that the final argument mapping does not have missing or
|
||||||
extra arguments (where applicable).
|
extra arguments (where applicable).
|
||||||
"""
|
"""
|
||||||
atom_args = reflection.get_callable_args(function, required_only=True)
|
|
||||||
|
# build a list of required arguments based on function signature
|
||||||
|
req_args = reflection.get_callable_args(function, required_only=True)
|
||||||
|
all_args = reflection.get_callable_args(function, required_only=False)
|
||||||
|
|
||||||
|
# remove arguments that are part of ignore_list
|
||||||
if ignore_list:
|
if ignore_list:
|
||||||
for arg in ignore_list:
|
for arg in ignore_list:
|
||||||
if arg in atom_args:
|
if arg in req_args:
|
||||||
atom_args.remove(arg)
|
req_args.remove(arg)
|
||||||
|
else:
|
||||||
|
ignore_list = []
|
||||||
|
|
||||||
result = {}
|
required = {}
|
||||||
|
# add reqs to required mappings
|
||||||
if reqs:
|
if reqs:
|
||||||
result.update((a, a) for a in reqs)
|
required.update((a, a) for a in reqs)
|
||||||
|
|
||||||
|
# add req_args to required mappings if do_infer is set
|
||||||
if do_infer:
|
if do_infer:
|
||||||
result.update((a, a) for a in atom_args)
|
required.update((a, a) for a in req_args)
|
||||||
result.update(_build_rebind_dict(atom_args, rebind_args))
|
|
||||||
|
# update required mappings based on rebind_args
|
||||||
|
required.update(_build_rebind_dict(req_args, rebind_args))
|
||||||
|
|
||||||
|
if do_infer:
|
||||||
|
opt_args = set(all_args) - set(required) - set(ignore_list)
|
||||||
|
optional = dict((a, a) for a in opt_args)
|
||||||
|
else:
|
||||||
|
optional = {}
|
||||||
|
|
||||||
if not reflection.accepts_kwargs(function):
|
if not reflection.accepts_kwargs(function):
|
||||||
all_args = reflection.get_callable_args(function, required_only=False)
|
extra_args = set(required) - set(all_args)
|
||||||
extra_args = set(result) - set(all_args)
|
|
||||||
if extra_args:
|
if extra_args:
|
||||||
extra_args_str = ', '.join(sorted(extra_args))
|
extra_args_str = ', '.join(sorted(extra_args))
|
||||||
raise ValueError('Extra arguments given to atom %s: %s'
|
raise ValueError('Extra arguments given to atom %s: %s'
|
||||||
% (atom_name, extra_args_str))
|
% (atom_name, extra_args_str))
|
||||||
|
|
||||||
# NOTE(imelnikov): don't use set to preserve order in error message
|
# NOTE(imelnikov): don't use set to preserve order in error message
|
||||||
missing_args = [arg for arg in atom_args if arg not in result]
|
missing_args = [arg for arg in req_args if arg not in required]
|
||||||
if missing_args:
|
if missing_args:
|
||||||
raise ValueError('Missing arguments for atom %s: %s'
|
raise ValueError('Missing arguments for atom %s: %s'
|
||||||
% (atom_name, ' ,'.join(missing_args)))
|
% (atom_name, ' ,'.join(missing_args)))
|
||||||
return result
|
return required, optional
|
||||||
|
|
||||||
|
|
||||||
class Atom(object):
|
class Atom(object):
|
||||||
@ -146,6 +163,13 @@ class Atom(object):
|
|||||||
commences (this allows for providing atom *local* values that
|
commences (this allows for providing atom *local* values that
|
||||||
do not need to be provided by other atoms/dependents).
|
do not need to be provided by other atoms/dependents).
|
||||||
:ivar inject: See parameter ``inject``.
|
:ivar inject: See parameter ``inject``.
|
||||||
|
:ivar requires: Any inputs this atom requires to function (if applicable).
|
||||||
|
NOTE(harlowja): there can be no intersection between what
|
||||||
|
this atom requires and what it produces (since this would
|
||||||
|
be an impossible dependency to satisfy).
|
||||||
|
:ivar optional: Any inputs that are optional for this atom's execute
|
||||||
|
method.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name=None, provides=None, inject=None):
|
def __init__(self, name=None, provides=None, inject=None):
|
||||||
@ -153,11 +177,27 @@ class Atom(object):
|
|||||||
self.save_as = _save_as_to_mapping(provides)
|
self.save_as = _save_as_to_mapping(provides)
|
||||||
self.version = (1, 0)
|
self.version = (1, 0)
|
||||||
self.inject = inject
|
self.inject = inject
|
||||||
|
self.requires = frozenset()
|
||||||
|
self.optional = frozenset()
|
||||||
|
|
||||||
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
||||||
auto_extract=True, ignore_list=None):
|
auto_extract=True, ignore_list=None):
|
||||||
self.rebind = _build_arg_mapping(self.name, requires, rebind,
|
req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind,
|
||||||
executor, auto_extract, ignore_list)
|
executor, auto_extract,
|
||||||
|
ignore_list)
|
||||||
|
|
||||||
|
self.rebind = {}
|
||||||
|
if opt_arg:
|
||||||
|
self.rebind.update(opt_arg)
|
||||||
|
if req_arg:
|
||||||
|
self.rebind.update(req_arg)
|
||||||
|
self.requires = frozenset(req_arg.values())
|
||||||
|
self.optional = frozenset(opt_arg.values())
|
||||||
|
if self.inject:
|
||||||
|
inject_set = set(six.iterkeys(self.inject))
|
||||||
|
self.requires -= inject_set
|
||||||
|
self.optional -= inject_set
|
||||||
|
|
||||||
out_of_order = self.provides.intersection(self.requires)
|
out_of_order = self.provides.intersection(self.requires)
|
||||||
if out_of_order:
|
if out_of_order:
|
||||||
raise exceptions.DependencyFailure(
|
raise exceptions.DependencyFailure(
|
||||||
@ -185,16 +225,3 @@ class Atom(object):
|
|||||||
dependency to satisfy).
|
dependency to satisfy).
|
||||||
"""
|
"""
|
||||||
return set(self.save_as)
|
return set(self.save_as)
|
||||||
|
|
||||||
@property
|
|
||||||
def requires(self):
|
|
||||||
"""Any inputs this atom requires to function (if applicable).
|
|
||||||
|
|
||||||
NOTE(harlowja): there can be no intersection between what this atom
|
|
||||||
requires and what it produces (since this would be an impossible
|
|
||||||
dependency to satisfy).
|
|
||||||
"""
|
|
||||||
requires = set(self.rebind.values())
|
|
||||||
if self.inject:
|
|
||||||
requires = requires - set(six.iterkeys(self.inject))
|
|
||||||
return requires
|
|
||||||
|
@ -54,9 +54,12 @@ class RetryAction(base.Action):
|
|||||||
|
|
||||||
def _get_retry_args(self, retry, addons=None):
|
def _get_retry_args(self, retry, addons=None):
|
||||||
scope_walker = self._walker_factory(retry)
|
scope_walker = self._walker_factory(retry)
|
||||||
arguments = self._storage.fetch_mapped_args(retry.rebind,
|
arguments = self._storage.fetch_mapped_args(
|
||||||
atom_name=retry.name,
|
retry.rebind,
|
||||||
scope_walker=scope_walker)
|
atom_name=retry.name,
|
||||||
|
scope_walker=scope_walker,
|
||||||
|
optional_args=retry.optional
|
||||||
|
)
|
||||||
history = self._storage.get_retry_history(retry.name)
|
history = self._storage.get_retry_history(retry.name)
|
||||||
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
|
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
|
||||||
if addons:
|
if addons:
|
||||||
|
@ -101,9 +101,12 @@ class TaskAction(base.Action):
|
|||||||
def schedule_execution(self, task):
|
def schedule_execution(self, task):
|
||||||
self.change_state(task, states.RUNNING, progress=0.0)
|
self.change_state(task, states.RUNNING, progress=0.0)
|
||||||
scope_walker = self._walker_factory(task)
|
scope_walker = self._walker_factory(task)
|
||||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
arguments = self._storage.fetch_mapped_args(
|
||||||
atom_name=task.name,
|
task.rebind,
|
||||||
scope_walker=scope_walker)
|
atom_name=task.name,
|
||||||
|
scope_walker=scope_walker,
|
||||||
|
optional_args=task.optional
|
||||||
|
)
|
||||||
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
|
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
|
||||||
progress_callback = functools.partial(self._on_update_progress,
|
progress_callback = functools.partial(self._on_update_progress,
|
||||||
task)
|
task)
|
||||||
@ -124,9 +127,12 @@ class TaskAction(base.Action):
|
|||||||
def schedule_reversion(self, task):
|
def schedule_reversion(self, task):
|
||||||
self.change_state(task, states.REVERTING, progress=0.0)
|
self.change_state(task, states.REVERTING, progress=0.0)
|
||||||
scope_walker = self._walker_factory(task)
|
scope_walker = self._walker_factory(task)
|
||||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
arguments = self._storage.fetch_mapped_args(
|
||||||
atom_name=task.name,
|
task.rebind,
|
||||||
scope_walker=scope_walker)
|
atom_name=task.name,
|
||||||
|
scope_walker=scope_walker,
|
||||||
|
optional_args=task.optional
|
||||||
|
)
|
||||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||||
task_result = self._storage.get(task.name)
|
task_result = self._storage.get(task.name)
|
||||||
failures = self._storage.get_failures()
|
failures = self._storage.get_failures()
|
||||||
|
93
taskflow/examples/optional_arguments.py
Normal file
93
taskflow/examples/optional_arguments.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright (C) 2015 Hewlett-Packard Development Company, L.P.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from taskflow import engines
|
||||||
|
from taskflow.patterns import linear_flow
|
||||||
|
from taskflow import task
|
||||||
|
|
||||||
|
|
||||||
|
class TestTask(task.Task):
|
||||||
|
def execute(self, a, b=5):
|
||||||
|
result = a * b
|
||||||
|
return result
|
||||||
|
|
||||||
|
flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result'))
|
||||||
|
flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result',
|
||||||
|
inject={'a': 10}))
|
||||||
|
flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result',
|
||||||
|
inject={'b': 1000}))
|
||||||
|
|
||||||
|
ASSERT = True
|
||||||
|
|
||||||
|
|
||||||
|
class MyTest(unittest.TestCase):
|
||||||
|
def test_my_test(self):
|
||||||
|
print("Expected result = 15")
|
||||||
|
result = engines.run(flow_no_inject, store={'a': 3})
|
||||||
|
print(result)
|
||||||
|
if ASSERT:
|
||||||
|
self.assertEqual(result,
|
||||||
|
{'a': 3, 'result': 15}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Expected result = 39")
|
||||||
|
result = engines.run(flow_no_inject, store={'a': 3, 'b': 7})
|
||||||
|
print(result)
|
||||||
|
if ASSERT:
|
||||||
|
self.assertEqual(
|
||||||
|
result,
|
||||||
|
{'a': 3, 'b': 7, 'result': 21}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Expected result = 200")
|
||||||
|
result = engines.run(flow_inject_a, store={'a': 3})
|
||||||
|
print(result)
|
||||||
|
if ASSERT:
|
||||||
|
self.assertEqual(
|
||||||
|
result,
|
||||||
|
{'a': 3, 'result': 50}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Expected result = 400")
|
||||||
|
result = engines.run(flow_inject_a, store={'a': 3, 'b': 7})
|
||||||
|
print(result)
|
||||||
|
if ASSERT:
|
||||||
|
self.assertEqual(
|
||||||
|
result,
|
||||||
|
{'a': 3, 'b': 7, 'result': 70}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Expected result = 40")
|
||||||
|
result = engines.run(flow_inject_b, store={'a': 3})
|
||||||
|
print(result)
|
||||||
|
if ASSERT:
|
||||||
|
self.assertEqual(
|
||||||
|
result,
|
||||||
|
{'a': 3, 'result': 3000}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Expected result = 40")
|
||||||
|
result = engines.run(flow_inject_b, store={'a': 3, 'b': 7})
|
||||||
|
print(result)
|
||||||
|
if ASSERT:
|
||||||
|
self.assertEqual(
|
||||||
|
result,
|
||||||
|
{'a': 3, 'b': 7, 'result': 3000}
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -629,7 +629,8 @@ class Storage(object):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def fetch_mapped_args(self, args_mapping,
|
def fetch_mapped_args(self, args_mapping,
|
||||||
atom_name=None, scope_walker=None):
|
atom_name=None, scope_walker=None,
|
||||||
|
optional_args=None):
|
||||||
"""Fetch arguments for an atom using an atoms argument mapping."""
|
"""Fetch arguments for an atom using an atoms argument mapping."""
|
||||||
|
|
||||||
def _get_results(looking_for, provider):
|
def _get_results(looking_for, provider):
|
||||||
@ -668,10 +669,14 @@ class Storage(object):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
with self._lock.read_lock():
|
with self._lock.read_lock():
|
||||||
|
if optional_args is None:
|
||||||
|
optional_args = []
|
||||||
|
|
||||||
if atom_name and atom_name not in self._atom_name_to_uuid:
|
if atom_name and atom_name not in self._atom_name_to_uuid:
|
||||||
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
|
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
|
||||||
if not args_mapping:
|
if not args_mapping:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# The order of lookup is the following:
|
# The order of lookup is the following:
|
||||||
#
|
#
|
||||||
# 1. Injected atom specific arguments.
|
# 1. Injected atom specific arguments.
|
||||||
@ -705,6 +710,8 @@ class Storage(object):
|
|||||||
try:
|
try:
|
||||||
possible_providers = self._reverse_mapping[name]
|
possible_providers = self._reverse_mapping[name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
if bound_name in optional_args:
|
||||||
|
continue
|
||||||
raise exceptions.NotFound("Name %r is not mapped as a"
|
raise exceptions.NotFound("Name %r is not mapped as a"
|
||||||
" produced output by any"
|
" produced output by any"
|
||||||
" providers" % name)
|
" providers" % name)
|
||||||
|
@ -351,6 +351,20 @@ class StorageTestMixin(object):
|
|||||||
self.assertRaises(exceptions.NotFound,
|
self.assertRaises(exceptions.NotFound,
|
||||||
s.fetch_mapped_args, {'viking': 'helmet'})
|
s.fetch_mapped_args, {'viking': 'helmet'})
|
||||||
|
|
||||||
|
def test_fetch_optional_args_found(self):
|
||||||
|
s = self._get_storage()
|
||||||
|
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
||||||
|
self.assertEqual(s.fetch_mapped_args({'viking': 'spam'},
|
||||||
|
optional_args=set(['viking'])),
|
||||||
|
{'viking': 'eggs'})
|
||||||
|
|
||||||
|
def test_fetch_optional_args_not_found(self):
|
||||||
|
s = self._get_storage()
|
||||||
|
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
||||||
|
self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'},
|
||||||
|
optional_args=set(['viking'])),
|
||||||
|
{})
|
||||||
|
|
||||||
def test_set_and_get_task_state(self):
|
def test_set_and_get_task_state(self):
|
||||||
s = self._get_storage()
|
s = self._get_storage()
|
||||||
state = states.PENDING
|
state = states.PENDING
|
||||||
|
@ -120,10 +120,19 @@ class TaskTest(test.TestCase):
|
|||||||
def test_requires_ignores_optional(self):
|
def test_requires_ignores_optional(self):
|
||||||
my_task = DefaultArgTask()
|
my_task = DefaultArgTask()
|
||||||
self.assertEqual(my_task.requires, set(['spam']))
|
self.assertEqual(my_task.requires, set(['spam']))
|
||||||
|
self.assertEqual(my_task.optional, set(['eggs']))
|
||||||
|
|
||||||
def test_requires_allows_optional(self):
|
def test_requires_allows_optional(self):
|
||||||
my_task = DefaultArgTask(requires=('spam', 'eggs'))
|
my_task = DefaultArgTask(requires=('spam', 'eggs'))
|
||||||
self.assertEqual(my_task.requires, set(['spam', 'eggs']))
|
self.assertEqual(my_task.requires, set(['spam', 'eggs']))
|
||||||
|
self.assertEqual(my_task.optional, set())
|
||||||
|
|
||||||
|
def test_rebind_includes_optional(self):
|
||||||
|
my_task = DefaultArgTask()
|
||||||
|
self.assertEqual(my_task.rebind, {
|
||||||
|
'spam': 'spam',
|
||||||
|
'eggs': 'eggs',
|
||||||
|
})
|
||||||
|
|
||||||
def test_rebind_all_args(self):
|
def test_rebind_all_args(self):
|
||||||
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})
|
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})
|
||||||
|
Loading…
Reference in New Issue
Block a user