Merge "Map optional arguments as well as required arguments"

This commit is contained in:
Jenkins 2015-02-11 08:39:50 +00:00 committed by Gerrit Code Review
commit 761321dec7
8 changed files with 207 additions and 51 deletions

View File

@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern.
Task/retry arguments
Set of names of task/retry arguments available as the ``requires``
property of the task/retry instance. When a task or retry object is
about to be executed values with these names are retrieved from storage
and passed to the ``execute`` method of the task/retry.
and/or ``optional`` property of the task/retry instance. When a task or
retry object is about to be executed values with these names are
retrieved from storage and passed to the ``execute`` method of the
task/retry. If any names in the ``requires`` property cannot be
found in storage, an exception will be thrown. Any names in the
``optional`` property that cannot be found are ignored.
Task/retry results
Set of names of task/retry results (what task/retry provides) available
@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object).
.. doctest::
>>> class MyTask(task.Task):
... def execute(self, spam, eggs):
... def execute(self, spam, eggs, bacon=None):
... return spam + eggs
...
>>> sorted(MyTask().requires)
['eggs', 'spam']
>>> sorted(MyTask().optional)
['bacon']
Inference from the method signature is the ''simplest'' way to specify
arguments. Optional arguments (with default values), and special arguments like
``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these
names have special meaning/usage in python).
arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are
ignored during inference (as these names have special meaning/usage in python).
.. doctest::
>>> class MyTask(task.Task):
... def execute(self, spam, eggs=()):
... return spam + eggs
...
>>> MyTask().requires
set(['spam'])
>>>
>>> class UniTask(task.Task):
... def execute(self, *args, **kwargs):
... pass
...
>>> UniTask().requires
set([])
frozenset([])
.. make vim sphinx highlighter* happy**

View File

@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
well as verify that the final argument mapping does not have missing or
extra arguments (where applicable).
"""
atom_args = reflection.get_callable_args(function, required_only=True)
# build a list of required arguments based on function signature
req_args = reflection.get_callable_args(function, required_only=True)
all_args = reflection.get_callable_args(function, required_only=False)
# remove arguments that are part of ignore_list
if ignore_list:
for arg in ignore_list:
if arg in atom_args:
atom_args.remove(arg)
if arg in req_args:
req_args.remove(arg)
else:
ignore_list = []
result = {}
required = {}
# add reqs to required mappings
if reqs:
result.update((a, a) for a in reqs)
required.update((a, a) for a in reqs)
# add req_args to required mappings if do_infer is set
if do_infer:
result.update((a, a) for a in atom_args)
result.update(_build_rebind_dict(atom_args, rebind_args))
required.update((a, a) for a in req_args)
# update required mappings based on rebind_args
required.update(_build_rebind_dict(req_args, rebind_args))
if do_infer:
opt_args = set(all_args) - set(required) - set(ignore_list)
optional = dict((a, a) for a in opt_args)
else:
optional = {}
if not reflection.accepts_kwargs(function):
all_args = reflection.get_callable_args(function, required_only=False)
extra_args = set(result) - set(all_args)
extra_args = set(required) - set(all_args)
if extra_args:
extra_args_str = ', '.join(sorted(extra_args))
raise ValueError('Extra arguments given to atom %s: %s'
% (atom_name, extra_args_str))
# NOTE(imelnikov): don't use set to preserve order in error message
missing_args = [arg for arg in atom_args if arg not in result]
missing_args = [arg for arg in req_args if arg not in required]
if missing_args:
raise ValueError('Missing arguments for atom %s: %s'
% (atom_name, ' ,'.join(missing_args)))
return result
return required, optional
class Atom(object):
@ -146,6 +163,13 @@ class Atom(object):
commences (this allows for providing atom *local* values that
do not need to be provided by other atoms/dependents).
:ivar inject: See parameter ``inject``.
:ivar requires: Any inputs this atom requires to function (if applicable).
NOTE(harlowja): there can be no intersection between what
this atom requires and what it produces (since this would
be an impossible dependency to satisfy).
:ivar optional: Any inputs that are optional for this atom's execute
method.
"""
def __init__(self, name=None, provides=None, inject=None):
@ -153,11 +177,27 @@ class Atom(object):
self.save_as = _save_as_to_mapping(provides)
self.version = (1, 0)
self.inject = inject
self.requires = frozenset()
self.optional = frozenset()
def _build_arg_mapping(self, executor, requires=None, rebind=None,
auto_extract=True, ignore_list=None):
self.rebind = _build_arg_mapping(self.name, requires, rebind,
executor, auto_extract, ignore_list)
req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind,
executor, auto_extract,
ignore_list)
self.rebind = {}
if opt_arg:
self.rebind.update(opt_arg)
if req_arg:
self.rebind.update(req_arg)
self.requires = frozenset(req_arg.values())
self.optional = frozenset(opt_arg.values())
if self.inject:
inject_set = set(six.iterkeys(self.inject))
self.requires -= inject_set
self.optional -= inject_set
out_of_order = self.provides.intersection(self.requires)
if out_of_order:
raise exceptions.DependencyFailure(
@ -185,16 +225,3 @@ class Atom(object):
dependency to satisfy).
"""
return set(self.save_as)
@property
def requires(self):
"""Any inputs this atom requires to function (if applicable).
NOTE(harlowja): there can be no intersection between what this atom
requires and what it produces (since this would be an impossible
dependency to satisfy).
"""
requires = set(self.rebind.values())
if self.inject:
requires = requires - set(six.iterkeys(self.inject))
return requires

View File

@ -54,9 +54,12 @@ class RetryAction(base.Action):
def _get_retry_args(self, retry, addons=None):
scope_walker = self._walker_factory(retry)
arguments = self._storage.fetch_mapped_args(retry.rebind,
atom_name=retry.name,
scope_walker=scope_walker)
arguments = self._storage.fetch_mapped_args(
retry.rebind,
atom_name=retry.name,
scope_walker=scope_walker,
optional_args=retry.optional
)
history = self._storage.get_retry_history(retry.name)
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
if addons:

View File

@ -101,9 +101,12 @@ class TaskAction(base.Action):
def schedule_execution(self, task):
self.change_state(task, states.RUNNING, progress=0.0)
scope_walker = self._walker_factory(task)
arguments = self._storage.fetch_mapped_args(task.rebind,
atom_name=task.name,
scope_walker=scope_walker)
arguments = self._storage.fetch_mapped_args(
task.rebind,
atom_name=task.name,
scope_walker=scope_walker,
optional_args=task.optional
)
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
progress_callback = functools.partial(self._on_update_progress,
task)
@ -124,9 +127,12 @@ class TaskAction(base.Action):
def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0)
scope_walker = self._walker_factory(task)
arguments = self._storage.fetch_mapped_args(task.rebind,
atom_name=task.name,
scope_walker=scope_walker)
arguments = self._storage.fetch_mapped_args(
task.rebind,
atom_name=task.name,
scope_walker=scope_walker,
optional_args=task.optional
)
task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name)
failures = self._storage.get_failures()

View File

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2015 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
from taskflow import engines
from taskflow.patterns import linear_flow
from taskflow import task
class TestTask(task.Task):
def execute(self, a, b=5):
result = a * b
return result
flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result'))
flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result',
inject={'a': 10}))
flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result',
inject={'b': 1000}))
ASSERT = True
class MyTest(unittest.TestCase):
def test_my_test(self):
print("Expected result = 15")
result = engines.run(flow_no_inject, store={'a': 3})
print(result)
if ASSERT:
self.assertEqual(result,
{'a': 3, 'result': 15}
)
print("Expected result = 39")
result = engines.run(flow_no_inject, store={'a': 3, 'b': 7})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'b': 7, 'result': 21}
)
print("Expected result = 200")
result = engines.run(flow_inject_a, store={'a': 3})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'result': 50}
)
print("Expected result = 400")
result = engines.run(flow_inject_a, store={'a': 3, 'b': 7})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'b': 7, 'result': 70}
)
print("Expected result = 40")
result = engines.run(flow_inject_b, store={'a': 3})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'result': 3000}
)
print("Expected result = 40")
result = engines.run(flow_inject_b, store={'a': 3, 'b': 7})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'b': 7, 'result': 3000}
)
if __name__ == '__main__':
unittest.main()

View File

@ -629,7 +629,8 @@ class Storage(object):
return results
def fetch_mapped_args(self, args_mapping,
atom_name=None, scope_walker=None):
atom_name=None, scope_walker=None,
optional_args=None):
"""Fetch arguments for an atom using an atoms argument mapping."""
def _get_results(looking_for, provider):
@ -668,10 +669,14 @@ class Storage(object):
return []
with self._lock.read_lock():
if optional_args is None:
optional_args = []
if atom_name and atom_name not in self._atom_name_to_uuid:
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
if not args_mapping:
return {}
# The order of lookup is the following:
#
# 1. Injected atom specific arguments.
@ -705,6 +710,8 @@ class Storage(object):
try:
possible_providers = self._reverse_mapping[name]
except KeyError:
if bound_name in optional_args:
continue
raise exceptions.NotFound("Name %r is not mapped as a"
" produced output by any"
" providers" % name)

View File

@ -351,6 +351,20 @@ class StorageTestMixin(object):
self.assertRaises(exceptions.NotFound,
s.fetch_mapped_args, {'viking': 'helmet'})
def test_fetch_optional_args_found(self):
s = self._get_storage()
s.inject({'foo': 'bar', 'spam': 'eggs'})
self.assertEqual(s.fetch_mapped_args({'viking': 'spam'},
optional_args=set(['viking'])),
{'viking': 'eggs'})
def test_fetch_optional_args_not_found(self):
s = self._get_storage()
s.inject({'foo': 'bar', 'spam': 'eggs'})
self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'},
optional_args=set(['viking'])),
{})
def test_set_and_get_task_state(self):
s = self._get_storage()
state = states.PENDING

View File

@ -120,10 +120,19 @@ class TaskTest(test.TestCase):
def test_requires_ignores_optional(self):
my_task = DefaultArgTask()
self.assertEqual(my_task.requires, set(['spam']))
self.assertEqual(my_task.optional, set(['eggs']))
def test_requires_allows_optional(self):
my_task = DefaultArgTask(requires=('spam', 'eggs'))
self.assertEqual(my_task.requires, set(['spam', 'eggs']))
self.assertEqual(my_task.optional, set())
def test_rebind_includes_optional(self):
my_task = DefaultArgTask()
self.assertEqual(my_task.rebind, {
'spam': 'spam',
'eggs': 'eggs',
})
def test_rebind_all_args(self):
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})