Merge "Map optional arguments as well as required arguments"

This commit is contained in:
Jenkins 2015-02-11 08:39:50 +00:00 committed by Gerrit Code Review
commit 761321dec7
8 changed files with 207 additions and 51 deletions

View File

@ -22,9 +22,12 @@ are and how to use those ways to accomplish your desired usage pattern.
Task/retry arguments Task/retry arguments
Set of names of task/retry arguments available as the ``requires`` Set of names of task/retry arguments available as the ``requires``
property of the task/retry instance. When a task or retry object is and/or ``optional`` property of the task/retry instance. When a task or
about to be executed values with these names are retrieved from storage retry object is about to be executed values with these names are
and passed to the ``execute`` method of the task/retry. retrieved from storage and passed to the ``execute`` method of the
task/retry. If any names in the ``requires`` property cannot be
found in storage, an exception will be thrown. Any names in the
``optional`` property that cannot be found are ignored.
Task/retry results Task/retry results
Set of names of task/retry results (what task/retry provides) available Set of names of task/retry results (what task/retry provides) available
@ -53,32 +56,26 @@ method of a task (or the |retry.execute| of a retry object).
.. doctest:: .. doctest::
>>> class MyTask(task.Task): >>> class MyTask(task.Task):
... def execute(self, spam, eggs): ... def execute(self, spam, eggs, bacon=None):
... return spam + eggs ... return spam + eggs
... ...
>>> sorted(MyTask().requires) >>> sorted(MyTask().requires)
['eggs', 'spam'] ['eggs', 'spam']
>>> sorted(MyTask().optional)
['bacon']
Inference from the method signature is the ''simplest'' way to specify Inference from the method signature is the ''simplest'' way to specify
arguments. Optional arguments (with default values), and special arguments like arguments. Special arguments like ``self``, ``*args`` and ``**kwargs`` are
``self``, ``*args`` and ``**kwargs`` are ignored during inference (as these ignored during inference (as these names have special meaning/usage in python).
names have special meaning/usage in python).
.. doctest:: .. doctest::
>>> class MyTask(task.Task):
... def execute(self, spam, eggs=()):
... return spam + eggs
...
>>> MyTask().requires
set(['spam'])
>>>
>>> class UniTask(task.Task): >>> class UniTask(task.Task):
... def execute(self, *args, **kwargs): ... def execute(self, *args, **kwargs):
... pass ... pass
... ...
>>> UniTask().requires >>> UniTask().requires
set([]) frozenset([])
.. make vim sphinx highlighter* happy** .. make vim sphinx highlighter* happy**

View File

@ -82,33 +82,50 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
well as verify that the final argument mapping does not have missing or well as verify that the final argument mapping does not have missing or
extra arguments (where applicable). extra arguments (where applicable).
""" """
atom_args = reflection.get_callable_args(function, required_only=True)
# build a list of required arguments based on function signature
req_args = reflection.get_callable_args(function, required_only=True)
all_args = reflection.get_callable_args(function, required_only=False)
# remove arguments that are part of ignore_list
if ignore_list: if ignore_list:
for arg in ignore_list: for arg in ignore_list:
if arg in atom_args: if arg in req_args:
atom_args.remove(arg) req_args.remove(arg)
else:
ignore_list = []
result = {} required = {}
# add reqs to required mappings
if reqs: if reqs:
result.update((a, a) for a in reqs) required.update((a, a) for a in reqs)
# add req_args to required mappings if do_infer is set
if do_infer: if do_infer:
result.update((a, a) for a in atom_args) required.update((a, a) for a in req_args)
result.update(_build_rebind_dict(atom_args, rebind_args))
# update required mappings based on rebind_args
required.update(_build_rebind_dict(req_args, rebind_args))
if do_infer:
opt_args = set(all_args) - set(required) - set(ignore_list)
optional = dict((a, a) for a in opt_args)
else:
optional = {}
if not reflection.accepts_kwargs(function): if not reflection.accepts_kwargs(function):
all_args = reflection.get_callable_args(function, required_only=False) extra_args = set(required) - set(all_args)
extra_args = set(result) - set(all_args)
if extra_args: if extra_args:
extra_args_str = ', '.join(sorted(extra_args)) extra_args_str = ', '.join(sorted(extra_args))
raise ValueError('Extra arguments given to atom %s: %s' raise ValueError('Extra arguments given to atom %s: %s'
% (atom_name, extra_args_str)) % (atom_name, extra_args_str))
# NOTE(imelnikov): don't use set to preserve order in error message # NOTE(imelnikov): don't use set to preserve order in error message
missing_args = [arg for arg in atom_args if arg not in result] missing_args = [arg for arg in req_args if arg not in required]
if missing_args: if missing_args:
raise ValueError('Missing arguments for atom %s: %s' raise ValueError('Missing arguments for atom %s: %s'
% (atom_name, ' ,'.join(missing_args))) % (atom_name, ' ,'.join(missing_args)))
return result return required, optional
class Atom(object): class Atom(object):
@ -146,6 +163,13 @@ class Atom(object):
commences (this allows for providing atom *local* values that commences (this allows for providing atom *local* values that
do not need to be provided by other atoms/dependents). do not need to be provided by other atoms/dependents).
:ivar inject: See parameter ``inject``. :ivar inject: See parameter ``inject``.
:ivar requires: Any inputs this atom requires to function (if applicable).
NOTE(harlowja): there can be no intersection between what
this atom requires and what it produces (since this would
be an impossible dependency to satisfy).
:ivar optional: Any inputs that are optional for this atom's execute
method.
""" """
def __init__(self, name=None, provides=None, inject=None): def __init__(self, name=None, provides=None, inject=None):
@ -153,11 +177,27 @@ class Atom(object):
self.save_as = _save_as_to_mapping(provides) self.save_as = _save_as_to_mapping(provides)
self.version = (1, 0) self.version = (1, 0)
self.inject = inject self.inject = inject
self.requires = frozenset()
self.optional = frozenset()
def _build_arg_mapping(self, executor, requires=None, rebind=None, def _build_arg_mapping(self, executor, requires=None, rebind=None,
auto_extract=True, ignore_list=None): auto_extract=True, ignore_list=None):
self.rebind = _build_arg_mapping(self.name, requires, rebind, req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind,
executor, auto_extract, ignore_list) executor, auto_extract,
ignore_list)
self.rebind = {}
if opt_arg:
self.rebind.update(opt_arg)
if req_arg:
self.rebind.update(req_arg)
self.requires = frozenset(req_arg.values())
self.optional = frozenset(opt_arg.values())
if self.inject:
inject_set = set(six.iterkeys(self.inject))
self.requires -= inject_set
self.optional -= inject_set
out_of_order = self.provides.intersection(self.requires) out_of_order = self.provides.intersection(self.requires)
if out_of_order: if out_of_order:
raise exceptions.DependencyFailure( raise exceptions.DependencyFailure(
@ -185,16 +225,3 @@ class Atom(object):
dependency to satisfy). dependency to satisfy).
""" """
return set(self.save_as) return set(self.save_as)
@property
def requires(self):
"""Any inputs this atom requires to function (if applicable).
NOTE(harlowja): there can be no intersection between what this atom
requires and what it produces (since this would be an impossible
dependency to satisfy).
"""
requires = set(self.rebind.values())
if self.inject:
requires = requires - set(six.iterkeys(self.inject))
return requires

View File

@ -54,9 +54,12 @@ class RetryAction(base.Action):
def _get_retry_args(self, retry, addons=None): def _get_retry_args(self, retry, addons=None):
scope_walker = self._walker_factory(retry) scope_walker = self._walker_factory(retry)
arguments = self._storage.fetch_mapped_args(retry.rebind, arguments = self._storage.fetch_mapped_args(
atom_name=retry.name, retry.rebind,
scope_walker=scope_walker) atom_name=retry.name,
scope_walker=scope_walker,
optional_args=retry.optional
)
history = self._storage.get_retry_history(retry.name) history = self._storage.get_retry_history(retry.name)
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
if addons: if addons:

View File

@ -101,9 +101,12 @@ class TaskAction(base.Action):
def schedule_execution(self, task): def schedule_execution(self, task):
self.change_state(task, states.RUNNING, progress=0.0) self.change_state(task, states.RUNNING, progress=0.0)
scope_walker = self._walker_factory(task) scope_walker = self._walker_factory(task)
arguments = self._storage.fetch_mapped_args(task.rebind, arguments = self._storage.fetch_mapped_args(
atom_name=task.name, task.rebind,
scope_walker=scope_walker) atom_name=task.name,
scope_walker=scope_walker,
optional_args=task.optional
)
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
progress_callback = functools.partial(self._on_update_progress, progress_callback = functools.partial(self._on_update_progress,
task) task)
@ -124,9 +127,12 @@ class TaskAction(base.Action):
def schedule_reversion(self, task): def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0) self.change_state(task, states.REVERTING, progress=0.0)
scope_walker = self._walker_factory(task) scope_walker = self._walker_factory(task)
arguments = self._storage.fetch_mapped_args(task.rebind, arguments = self._storage.fetch_mapped_args(
atom_name=task.name, task.rebind,
scope_walker=scope_walker) atom_name=task.name,
scope_walker=scope_walker,
optional_args=task.optional
)
task_uuid = self._storage.get_atom_uuid(task.name) task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name) task_result = self._storage.get(task.name)
failures = self._storage.get_failures() failures = self._storage.get_failures()

View File

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2015 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
from taskflow import engines
from taskflow.patterns import linear_flow
from taskflow import task
class TestTask(task.Task):
def execute(self, a, b=5):
result = a * b
return result
flow_no_inject = linear_flow.Flow("flow").add(TestTask(provides='result'))
flow_inject_a = linear_flow.Flow("flow").add(TestTask(provides='result',
inject={'a': 10}))
flow_inject_b = linear_flow.Flow("flow").add(TestTask(provides='result',
inject={'b': 1000}))
ASSERT = True
class MyTest(unittest.TestCase):
def test_my_test(self):
print("Expected result = 15")
result = engines.run(flow_no_inject, store={'a': 3})
print(result)
if ASSERT:
self.assertEqual(result,
{'a': 3, 'result': 15}
)
print("Expected result = 39")
result = engines.run(flow_no_inject, store={'a': 3, 'b': 7})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'b': 7, 'result': 21}
)
print("Expected result = 200")
result = engines.run(flow_inject_a, store={'a': 3})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'result': 50}
)
print("Expected result = 400")
result = engines.run(flow_inject_a, store={'a': 3, 'b': 7})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'b': 7, 'result': 70}
)
print("Expected result = 40")
result = engines.run(flow_inject_b, store={'a': 3})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'result': 3000}
)
print("Expected result = 40")
result = engines.run(flow_inject_b, store={'a': 3, 'b': 7})
print(result)
if ASSERT:
self.assertEqual(
result,
{'a': 3, 'b': 7, 'result': 3000}
)
if __name__ == '__main__':
unittest.main()

View File

@ -629,7 +629,8 @@ class Storage(object):
return results return results
def fetch_mapped_args(self, args_mapping, def fetch_mapped_args(self, args_mapping,
atom_name=None, scope_walker=None): atom_name=None, scope_walker=None,
optional_args=None):
"""Fetch arguments for an atom using an atoms argument mapping.""" """Fetch arguments for an atom using an atoms argument mapping."""
def _get_results(looking_for, provider): def _get_results(looking_for, provider):
@ -668,10 +669,14 @@ class Storage(object):
return [] return []
with self._lock.read_lock(): with self._lock.read_lock():
if optional_args is None:
optional_args = []
if atom_name and atom_name not in self._atom_name_to_uuid: if atom_name and atom_name not in self._atom_name_to_uuid:
raise exceptions.NotFound("Unknown atom name: %s" % atom_name) raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
if not args_mapping: if not args_mapping:
return {} return {}
# The order of lookup is the following: # The order of lookup is the following:
# #
# 1. Injected atom specific arguments. # 1. Injected atom specific arguments.
@ -705,6 +710,8 @@ class Storage(object):
try: try:
possible_providers = self._reverse_mapping[name] possible_providers = self._reverse_mapping[name]
except KeyError: except KeyError:
if bound_name in optional_args:
continue
raise exceptions.NotFound("Name %r is not mapped as a" raise exceptions.NotFound("Name %r is not mapped as a"
" produced output by any" " produced output by any"
" providers" % name) " providers" % name)

View File

@ -351,6 +351,20 @@ class StorageTestMixin(object):
self.assertRaises(exceptions.NotFound, self.assertRaises(exceptions.NotFound,
s.fetch_mapped_args, {'viking': 'helmet'}) s.fetch_mapped_args, {'viking': 'helmet'})
def test_fetch_optional_args_found(self):
s = self._get_storage()
s.inject({'foo': 'bar', 'spam': 'eggs'})
self.assertEqual(s.fetch_mapped_args({'viking': 'spam'},
optional_args=set(['viking'])),
{'viking': 'eggs'})
def test_fetch_optional_args_not_found(self):
s = self._get_storage()
s.inject({'foo': 'bar', 'spam': 'eggs'})
self.assertEqual(s.fetch_mapped_args({'viking': 'helmet'},
optional_args=set(['viking'])),
{})
def test_set_and_get_task_state(self): def test_set_and_get_task_state(self):
s = self._get_storage() s = self._get_storage()
state = states.PENDING state = states.PENDING

View File

@ -120,10 +120,19 @@ class TaskTest(test.TestCase):
def test_requires_ignores_optional(self): def test_requires_ignores_optional(self):
my_task = DefaultArgTask() my_task = DefaultArgTask()
self.assertEqual(my_task.requires, set(['spam'])) self.assertEqual(my_task.requires, set(['spam']))
self.assertEqual(my_task.optional, set(['eggs']))
def test_requires_allows_optional(self): def test_requires_allows_optional(self):
my_task = DefaultArgTask(requires=('spam', 'eggs')) my_task = DefaultArgTask(requires=('spam', 'eggs'))
self.assertEqual(my_task.requires, set(['spam', 'eggs'])) self.assertEqual(my_task.requires, set(['spam', 'eggs']))
self.assertEqual(my_task.optional, set())
def test_rebind_includes_optional(self):
my_task = DefaultArgTask()
self.assertEqual(my_task.rebind, {
'spam': 'spam',
'eggs': 'eggs',
})
def test_rebind_all_args(self): def test_rebind_all_args(self):
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'}) my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})