Use ordered[set/dict] to retain ordering

Instead of using always using a set/dict which do not retain
use a ordered set and a ordered dict for requires, optional,
and provides and rebind mappings types so that the ordering
of these containers is maintained later when they are used.

These ordering can be useful depending on the atom type (such
as in a map and reduce tasks).

Partial-Bug: 1357117

Change-Id: I365d11bbba4aa221bc36ca15441acecf199b4d56
This commit is contained in:
Joshua Harlow
2015-03-02 15:00:41 -08:00
parent 847d87db6a
commit 67f0f51464
6 changed files with 206 additions and 143 deletions

View File

@@ -74,8 +74,8 @@ ignored during inference (as these names have special meaning/usage in python).
... def execute(self, *args, **kwargs): ... def execute(self, *args, **kwargs):
... pass ... pass
... ...
>>> UniTask().requires >>> sorted(UniTask().requires)
frozenset([]) []
.. make vim sphinx highlighter* happy** .. make vim sphinx highlighter* happy**
@@ -214,8 +214,8 @@ name of the value.
... def execute(self): ... def execute(self):
... return 42 ... return 42
... ...
>>> TheAnswerReturningTask(provides='the_answer').provides >>> sorted(TheAnswerReturningTask(provides='the_answer').provides)
frozenset(['the_answer']) ['the_answer']
Returning a tuple Returning a tuple
+++++++++++++++++ +++++++++++++++++

View File

@@ -16,14 +16,28 @@
# under the License. # under the License.
import abc import abc
import collections
import itertools
try:
from collections import OrderedDict # noqa
except ImportError:
from ordereddict import OrderedDict # noqa
from oslo_utils import reflection from oslo_utils import reflection
import six import six
from six.moves import zip as compat_zip
from taskflow import exceptions from taskflow import exceptions
from taskflow.types import sets
from taskflow.utils import misc from taskflow.utils import misc
# Helper types tuples...
_sequence_types = (list, tuple, collections.Sequence)
_set_types = (set, collections.Set)
def _save_as_to_mapping(save_as): def _save_as_to_mapping(save_as):
"""Convert save_as to mapping name => index. """Convert save_as to mapping name => index.
@@ -33,23 +47,24 @@ def _save_as_to_mapping(save_as):
# outside of code so that it's more easily understandable, since what an # outside of code so that it's more easily understandable, since what an
# atom returns is pretty crucial for other later operations. # atom returns is pretty crucial for other later operations.
if save_as is None: if save_as is None:
return {} return OrderedDict()
if isinstance(save_as, six.string_types): if isinstance(save_as, six.string_types):
# NOTE(harlowja): this means that your atom will only return one item # NOTE(harlowja): this means that your atom will only return one item
# instead of a dictionary-like object or a indexable object (like a # instead of a dictionary-like object or a indexable object (like a
# list or tuple). # list or tuple).
return {save_as: None} return OrderedDict([(save_as, None)])
elif isinstance(save_as, (tuple, list)): elif isinstance(save_as, _sequence_types):
# NOTE(harlowja): this means that your atom will return a indexable # NOTE(harlowja): this means that your atom will return a indexable
# object, like a list or tuple and the results can be mapped by index # object, like a list or tuple and the results can be mapped by index
# to that tuple/list that is returned for others to use. # to that tuple/list that is returned for others to use.
return dict((key, num) for num, key in enumerate(save_as)) return OrderedDict((key, num) for num, key in enumerate(save_as))
elif isinstance(save_as, set): elif isinstance(save_as, _set_types):
# NOTE(harlowja): in the case where a set is given we will not be # NOTE(harlowja): in the case where a set is given we will not be
# able to determine the numeric ordering in a reliable way (since it is # able to determine the numeric ordering in a reliable way (since it
# a unordered set) so the only way for us to easily map the result of # may be an unordered set) so the only way for us to easily map the
# the atom will be via the key itself. # result of the atom will be via the key itself.
return dict((key, key) for key in save_as) return OrderedDict((key, key) for key in save_as)
else:
raise TypeError('Atom provides parameter ' raise TypeError('Atom provides parameter '
'should be str, set or tuple/list, not %r' % save_as) 'should be str, set or tuple/list, not %r' % save_as)
@@ -62,9 +77,9 @@ def _build_rebind_dict(args, rebind_args):
new name onto the required name). new name onto the required name).
""" """
if rebind_args is None: if rebind_args is None:
return {} return OrderedDict()
elif isinstance(rebind_args, (list, tuple)): elif isinstance(rebind_args, (list, tuple)):
rebind = dict(zip(args, rebind_args)) rebind = OrderedDict(compat_zip(args, rebind_args))
if len(args) < len(rebind_args): if len(args) < len(rebind_args):
rebind.update((a, a) for a in rebind_args[len(args):]) rebind.update((a, a) for a in rebind_args[len(args):])
return rebind return rebind
@@ -85,11 +100,11 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
extra arguments (where applicable). extra arguments (where applicable).
""" """
# build a list of required arguments based on function signature # Build a list of required arguments based on function signature.
req_args = reflection.get_callable_args(function, required_only=True) req_args = reflection.get_callable_args(function, required_only=True)
all_args = reflection.get_callable_args(function, required_only=False) all_args = reflection.get_callable_args(function, required_only=False)
# remove arguments that are part of ignore_list # Remove arguments that are part of ignore list.
if ignore_list: if ignore_list:
for arg in ignore_list: for arg in ignore_list:
if arg in req_args: if arg in req_args:
@@ -97,39 +112,45 @@ def _build_arg_mapping(atom_name, reqs, rebind_args, function, do_infer,
else: else:
ignore_list = [] ignore_list = []
required = {} # Build the required names.
# add reqs to required mappings required = OrderedDict()
# Add required arguments to required mappings if inference is enabled.
if do_infer:
required.update((a, a) for a in req_args)
# Add additional manually provided requirements to required mappings.
if reqs: if reqs:
if isinstance(reqs, six.string_types): if isinstance(reqs, six.string_types):
required.update({reqs: reqs}) required.update({reqs: reqs})
else: else:
required.update((a, a) for a in reqs) required.update((a, a) for a in reqs)
# add req_args to required mappings if do_infer is set # Update required mappings values based on rebinding of arguments names.
if do_infer:
required.update((a, a) for a in req_args)
# update required mappings based on rebind_args
required.update(_build_rebind_dict(req_args, rebind_args)) required.update(_build_rebind_dict(req_args, rebind_args))
# Determine if there are optional arguments that we may or may not take.
if do_infer: if do_infer:
opt_args = set(all_args) - set(required) - set(ignore_list) opt_args = sets.OrderedSet(all_args)
optional = dict((a, a) for a in opt_args) opt_args = opt_args - set(itertools.chain(six.iterkeys(required),
iter(ignore_list)))
optional = OrderedDict((a, a) for a in opt_args)
else: else:
optional = {} optional = OrderedDict()
# Check if we are given some extra arguments that we aren't able to accept.
if not reflection.accepts_kwargs(function): if not reflection.accepts_kwargs(function):
extra_args = set(required) - set(all_args) extra_args = sets.OrderedSet(six.iterkeys(required))
extra_args -= all_args
if extra_args: if extra_args:
extra_args_str = ', '.join(sorted(extra_args))
raise ValueError('Extra arguments given to atom %s: %s' raise ValueError('Extra arguments given to atom %s: %s'
% (atom_name, extra_args_str)) % (atom_name, list(extra_args)))
# NOTE(imelnikov): don't use set to preserve order in error message # NOTE(imelnikov): don't use set to preserve order in error message
missing_args = [arg for arg in req_args if arg not in required] missing_args = [arg for arg in req_args if arg not in required]
if missing_args: if missing_args:
raise ValueError('Missing arguments for atom %s: %s' raise ValueError('Missing arguments for atom %s: %s'
% (atom_name, ' ,'.join(missing_args))) % (atom_name, missing_args))
return required, optional return required, optional
@@ -161,52 +182,53 @@ class Atom(object):
with this atom. It can be useful in resuming older versions with this atom. It can be useful in resuming older versions
of atoms. Standard major, minor versioning concepts of atoms. Standard major, minor versioning concepts
should apply. should apply.
:ivar save_as: An *immutable* output ``resource`` name dictionary this atom :ivar save_as: An *immutable* output ``resource`` name
produces that other atoms may depend on this atom providing. :py:class:`.OrderedDict` this atom produces that other
The format is output index (or key when a dictionary atoms may depend on this atom providing. The format is
is returned from the execute method) to stored argument output index (or key when a dictionary is returned from
name. the execute method) to stored argument name.
:ivar rebind: An *immutable* input ``resource`` mapping dictionary that :ivar rebind: An *immutable* input ``resource`` :py:class:`.OrderedDict`
can be used to alter the inputs given to this atom. It is that can be used to alter the inputs given to this atom. It
typically used for mapping a prior atoms output into is typically used for mapping a prior atoms output into
the names that this atom expects (in a way this is like the names that this atom expects (in a way this is like
remapping a namespace of another atom into the namespace remapping a namespace of another atom into the namespace
of this atom). of this atom).
:ivar inject: See parameter ``inject``. :ivar inject: See parameter ``inject``.
:ivar name: See parameter ``name``. :ivar name: See parameter ``name``.
:ivar requires: An *immutable* set of inputs this atom requires to :ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
function. this atom requires to function.
:ivar optional: An *immutable* set of inputs that are optional for this :ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
atom to function. that are optional for this atom to function.
:ivar provides: An *immutable* set of outputs this atom produces. :ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs
this atom produces.
""" """
def __init__(self, name=None, provides=None, inject=None): def __init__(self, name=None, provides=None, inject=None):
self.name = name self.name = name
self.save_as = _save_as_to_mapping(provides)
self.version = (1, 0) self.version = (1, 0)
self.inject = inject self.inject = inject
self.requires = frozenset() self.save_as = _save_as_to_mapping(provides)
self.optional = frozenset() self.requires = sets.OrderedSet()
self.provides = frozenset(self.save_as) self.optional = sets.OrderedSet()
self.rebind = {} self.provides = sets.OrderedSet(self.save_as)
self.rebind = OrderedDict()
def _build_arg_mapping(self, executor, requires=None, rebind=None, def _build_arg_mapping(self, executor, requires=None, rebind=None,
auto_extract=True, ignore_list=None): auto_extract=True, ignore_list=None):
req_arg, opt_arg = _build_arg_mapping(self.name, requires, rebind, required, optional = _build_arg_mapping(self.name, requires, rebind,
executor, auto_extract, executor, auto_extract,
ignore_list) ignore_list=ignore_list)
self.rebind.clear() rebind = OrderedDict()
if opt_arg: for (arg_name, bound_name) in itertools.chain(six.iteritems(required),
self.rebind.update(opt_arg) six.iteritems(optional)):
if req_arg: rebind.setdefault(arg_name, bound_name)
self.rebind.update(req_arg) self.rebind = rebind
self.requires = frozenset(req_arg.values()) self.requires = sets.OrderedSet(six.itervalues(required))
self.optional = frozenset(opt_arg.values()) self.optional = sets.OrderedSet(six.itervalues(optional))
if self.inject: if self.inject:
inject_set = set(six.iterkeys(self.inject)) inject_keys = frozenset(six.iterkeys(self.inject))
self.requires -= inject_set self.requires -= inject_keys
self.optional -= inject_set self.optional -= inject_keys
out_of_order = self.provides.intersection(self.requires) out_of_order = self.provides.intersection(self.requires)
if out_of_order: if out_of_order:
raise exceptions.DependencyFailure( raise exceptions.DependencyFailure(

View File

@@ -158,13 +158,22 @@ class Linker(object):
" decomposed into an empty graph" % (v, u, u)) " decomposed into an empty graph" % (v, u, u))
for u in u_g.nodes_iter(): for u in u_g.nodes_iter():
for v in v_g.nodes_iter(): for v in v_g.nodes_iter():
depends_on = u.provides & v.requires # This is using the intersection() method vs the &
# operator since the latter doesn't work with frozen
# sets (when used in combination with ordered sets).
#
# If this is not done the following happens...
#
# TypeError: unsupported operand type(s)
# for &: 'frozenset' and 'OrderedSet'
depends_on = u.provides.intersection(v.requires)
if depends_on: if depends_on:
edge_attrs = {
_EDGE_REASONS: frozenset(depends_on),
}
_add_update_edges(graph, _add_update_edges(graph,
[u], [v], [u], [v],
attr_dict={ attr_dict=edge_attrs)
_EDGE_REASONS: depends_on,
})
else: else:
# Connect nodes with no predecessors in v to nodes with no # Connect nodes with no predecessors in v to nodes with no
# successors in the *first* non-empty predecessor of v (thus # successors in the *first* non-empty predecessor of v (thus

View File

@@ -27,11 +27,20 @@ def _unsatisfied_requires(node, graph, *additional_provided):
if not requires: if not requires:
return requires return requires
for provided in additional_provided: for provided in additional_provided:
requires = requires - provided # This is using the difference() method vs the -
# operator since the latter doesn't work with frozen
# or regular sets (when used in combination with ordered
# sets).
#
# If this is not done the following happens...
#
# TypeError: unsupported operand type(s)
# for -: 'set' and 'OrderedSet'
requires = requires.difference(provided)
if not requires: if not requires:
return requires return requires
for pred in graph.bfs_predecessors_iter(node): for pred in graph.bfs_predecessors_iter(node):
requires = requires - pred.provides requires = requires.difference(pred.provides)
if not requires: if not requires:
return requires return requires
return requires return requires

View File

@@ -27,40 +27,40 @@ class FlowDependenciesTest(test.TestCase):
def test_task_without_dependencies(self): def test_task_without_dependencies(self):
flow = utils.TaskNoRequiresNoReturns() flow = utils.TaskNoRequiresNoReturns()
self.assertEqual(flow.requires, set()) self.assertEqual(set(), flow.requires)
self.assertEqual(flow.provides, set()) self.assertEqual(set(), flow.provides)
def test_task_requires_default_values(self): def test_task_requires_default_values(self):
flow = utils.TaskMultiArg() flow = utils.TaskMultiArg()
self.assertEqual(flow.requires, set(['x', 'y', 'z'])) self.assertEqual(set(['x', 'y', 'z']), flow.requires)
self.assertEqual(flow.provides, set()) self.assertEqual(set(), flow.provides, )
def test_task_requires_rebinded_mapped(self): def test_task_requires_rebinded_mapped(self):
flow = utils.TaskMultiArg(rebind={'x': 'a', 'y': 'b', 'z': 'c'}) flow = utils.TaskMultiArg(rebind={'x': 'a', 'y': 'b', 'z': 'c'})
self.assertEqual(flow.requires, set(['a', 'b', 'c'])) self.assertEqual(set(['a', 'b', 'c']), flow.requires)
self.assertEqual(flow.provides, set()) self.assertEqual(set(), flow.provides)
def test_task_requires_additional_values(self): def test_task_requires_additional_values(self):
flow = utils.TaskMultiArg(requires=['a', 'b']) flow = utils.TaskMultiArg(requires=['a', 'b'])
self.assertEqual(flow.requires, set(['a', 'b', 'x', 'y', 'z'])) self.assertEqual(set(['a', 'b', 'x', 'y', 'z']), flow.requires)
self.assertEqual(flow.provides, set()) self.assertEqual(set(), flow.provides)
def test_task_provides_values(self): def test_task_provides_values(self):
flow = utils.TaskMultiReturn(provides=['a', 'b', 'c']) flow = utils.TaskMultiReturn(provides=['a', 'b', 'c'])
self.assertEqual(flow.requires, set()) self.assertEqual(set(), flow.requires)
self.assertEqual(flow.provides, set(['a', 'b', 'c'])) self.assertEqual(set(['a', 'b', 'c']), flow.provides)
def test_task_provides_and_requires_values(self): def test_task_provides_and_requires_values(self):
flow = utils.TaskMultiArgMultiReturn(provides=['a', 'b', 'c']) flow = utils.TaskMultiArgMultiReturn(provides=['a', 'b', 'c'])
self.assertEqual(flow.requires, set(['x', 'y', 'z'])) self.assertEqual(set(['x', 'y', 'z']), flow.requires)
self.assertEqual(flow.provides, set(['a', 'b', 'c'])) self.assertEqual(set(['a', 'b', 'c']), flow.provides)
def test_linear_flow_without_dependencies(self): def test_linear_flow_without_dependencies(self):
flow = lf.Flow('lf').add( flow = lf.Flow('lf').add(
utils.TaskNoRequiresNoReturns('task1'), utils.TaskNoRequiresNoReturns('task1'),
utils.TaskNoRequiresNoReturns('task2')) utils.TaskNoRequiresNoReturns('task2'))
self.assertEqual(flow.requires, set()) self.assertEqual(set(), flow.requires)
self.assertEqual(flow.provides, set()) self.assertEqual(set(), flow.provides)
def test_linear_flow_requires_values(self): def test_linear_flow_requires_values(self):
flow = lf.Flow('lf').add( flow = lf.Flow('lf').add(

View File

@@ -52,36 +52,36 @@ class TaskTest(test.TestCase):
def test_passed_name(self): def test_passed_name(self):
my_task = MyTask(name='my name') my_task = MyTask(name='my name')
self.assertEqual(my_task.name, 'my name') self.assertEqual('my name', my_task.name)
def test_generated_name(self): def test_generated_name(self):
my_task = MyTask() my_task = MyTask()
self.assertEqual(my_task.name, self.assertEqual('%s.%s' % (__name__, 'MyTask'),
'%s.%s' % (__name__, 'MyTask')) my_task.name)
def test_task_str(self): def test_task_str(self):
my_task = MyTask(name='my') my_task = MyTask(name='my')
self.assertEqual(str(my_task), 'my==1.0') self.assertEqual('my==1.0', str(my_task))
def test_task_repr(self): def test_task_repr(self):
my_task = MyTask(name='my') my_task = MyTask(name='my')
self.assertEqual(repr(my_task), '<%s.MyTask my==1.0>' % __name__) self.assertEqual('<%s.MyTask my==1.0>' % __name__, repr(my_task))
def test_no_provides(self): def test_no_provides(self):
my_task = MyTask() my_task = MyTask()
self.assertEqual(my_task.save_as, {}) self.assertEqual({}, my_task.save_as)
def test_provides(self): def test_provides(self):
my_task = MyTask(provides='food') my_task = MyTask(provides='food')
self.assertEqual(my_task.save_as, {'food': None}) self.assertEqual({'food': None}, my_task.save_as)
def test_multi_provides(self): def test_multi_provides(self):
my_task = MyTask(provides=('food', 'water')) my_task = MyTask(provides=('food', 'water'))
self.assertEqual(my_task.save_as, {'food': 0, 'water': 1}) self.assertEqual({'food': 0, 'water': 1}, my_task.save_as)
def test_unpack(self): def test_unpack(self):
my_task = MyTask(provides=('food',)) my_task = MyTask(provides=('food',))
self.assertEqual(my_task.save_as, {'food': 0}) self.assertEqual({'food': 0}, my_task.save_as)
def test_bad_provides(self): def test_bad_provides(self):
self.assertRaisesRegexp(TypeError, '^Atom provides', self.assertRaisesRegexp(TypeError, '^Atom provides',
@@ -89,28 +89,34 @@ class TaskTest(test.TestCase):
def test_requires_by_default(self): def test_requires_by_default(self):
my_task = MyTask() my_task = MyTask()
self.assertEqual(my_task.rebind, { expected = {
'spam': 'spam', 'spam': 'spam',
'eggs': 'eggs', 'eggs': 'eggs',
'context': 'context' 'context': 'context'
}) }
self.assertEqual(expected,
my_task.rebind)
self.assertEqual(set(['spam', 'eggs', 'context']),
my_task.requires)
def test_requires_amended(self): def test_requires_amended(self):
my_task = MyTask(requires=('spam', 'eggs')) my_task = MyTask(requires=('spam', 'eggs'))
self.assertEqual(my_task.rebind, { expected = {
'spam': 'spam', 'spam': 'spam',
'eggs': 'eggs', 'eggs': 'eggs',
'context': 'context' 'context': 'context'
}) }
self.assertEqual(expected, my_task.rebind)
def test_requires_explicit(self): def test_requires_explicit(self):
my_task = MyTask(auto_extract=False, my_task = MyTask(auto_extract=False,
requires=('spam', 'eggs', 'context')) requires=('spam', 'eggs', 'context'))
self.assertEqual(my_task.rebind, { expected = {
'spam': 'spam', 'spam': 'spam',
'eggs': 'eggs', 'eggs': 'eggs',
'context': 'context' 'context': 'context'
}) }
self.assertEqual(expected, my_task.rebind)
def test_requires_explicit_not_enough(self): def test_requires_explicit_not_enough(self):
self.assertRaisesRegexp(ValueError, '^Missing arguments', self.assertRaisesRegexp(ValueError, '^Missing arguments',
@@ -119,36 +125,43 @@ class TaskTest(test.TestCase):
def test_requires_ignores_optional(self): def test_requires_ignores_optional(self):
my_task = DefaultArgTask() my_task = DefaultArgTask()
self.assertEqual(my_task.requires, set(['spam'])) self.assertEqual(set(['spam']), my_task.requires)
self.assertEqual(my_task.optional, set(['eggs'])) self.assertEqual(set(['eggs']), my_task.optional)
def test_requires_allows_optional(self): def test_requires_allows_optional(self):
my_task = DefaultArgTask(requires=('spam', 'eggs')) my_task = DefaultArgTask(requires=('spam', 'eggs'))
self.assertEqual(my_task.requires, set(['spam', 'eggs'])) self.assertEqual(set(['spam', 'eggs']), my_task.requires)
self.assertEqual(my_task.optional, set()) self.assertEqual(set(), my_task.optional)
def test_rebind_includes_optional(self): def test_rebind_includes_optional(self):
my_task = DefaultArgTask() my_task = DefaultArgTask()
self.assertEqual(my_task.rebind, { expected = {
'spam': 'spam', 'spam': 'spam',
'eggs': 'eggs', 'eggs': 'eggs',
}) }
self.assertEqual(expected, my_task.rebind)
def test_rebind_all_args(self): def test_rebind_all_args(self):
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'}) my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})
self.assertEqual(my_task.rebind, { expected = {
'spam': 'a', 'spam': 'a',
'eggs': 'b', 'eggs': 'b',
'context': 'c' 'context': 'c'
}) }
self.assertEqual(expected, my_task.rebind)
self.assertEqual(set(['a', 'b', 'c']),
my_task.requires)
def test_rebind_partial(self): def test_rebind_partial(self):
my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b'}) my_task = MyTask(rebind={'spam': 'a', 'eggs': 'b'})
self.assertEqual(my_task.rebind, { expected = {
'spam': 'a', 'spam': 'a',
'eggs': 'b', 'eggs': 'b',
'context': 'context' 'context': 'context'
}) }
self.assertEqual(expected, my_task.rebind)
self.assertEqual(set(['a', 'b', 'context']),
my_task.requires)
def test_rebind_unknown(self): def test_rebind_unknown(self):
self.assertRaisesRegexp(ValueError, '^Extra arguments', self.assertRaisesRegexp(ValueError, '^Extra arguments',
@@ -156,26 +169,33 @@ class TaskTest(test.TestCase):
def test_rebind_unknown_kwargs(self): def test_rebind_unknown_kwargs(self):
task = KwargsTask(rebind={'foo': 'bar'}) task = KwargsTask(rebind={'foo': 'bar'})
self.assertEqual(task.rebind, { expected = {
'foo': 'bar', 'foo': 'bar',
'spam': 'spam' 'spam': 'spam'
}) }
self.assertEqual(expected, task.rebind)
def test_rebind_list_all(self): def test_rebind_list_all(self):
my_task = MyTask(rebind=('a', 'b', 'c')) my_task = MyTask(rebind=('a', 'b', 'c'))
self.assertEqual(my_task.rebind, { expected = {
'context': 'a', 'context': 'a',
'spam': 'b', 'spam': 'b',
'eggs': 'c' 'eggs': 'c'
}) }
self.assertEqual(expected, my_task.rebind)
self.assertEqual(set(['a', 'b', 'c']),
my_task.requires)
def test_rebind_list_partial(self): def test_rebind_list_partial(self):
my_task = MyTask(rebind=('a', 'b')) my_task = MyTask(rebind=('a', 'b'))
self.assertEqual(my_task.rebind, { expected = {
'context': 'a', 'context': 'a',
'spam': 'b', 'spam': 'b',
'eggs': 'eggs' 'eggs': 'eggs'
}) }
self.assertEqual(expected, my_task.rebind)
self.assertEqual(set(['a', 'b', 'eggs']),
my_task.requires)
def test_rebind_list_more(self): def test_rebind_list_more(self):
self.assertRaisesRegexp(ValueError, '^Extra arguments', self.assertRaisesRegexp(ValueError, '^Extra arguments',
@@ -183,11 +203,14 @@ class TaskTest(test.TestCase):
def test_rebind_list_more_kwargs(self): def test_rebind_list_more_kwargs(self):
task = KwargsTask(rebind=('a', 'b', 'c')) task = KwargsTask(rebind=('a', 'b', 'c'))
self.assertEqual(task.rebind, { expected = {
'spam': 'a', 'spam': 'a',
'b': 'b', 'b': 'b',
'c': 'c' 'c': 'c'
}) }
self.assertEqual(expected, task.rebind)
self.assertEqual(set(['a', 'b', 'c']),
task.requires)
def test_rebind_list_bad_value(self): def test_rebind_list_bad_value(self):
self.assertRaisesRegexp(TypeError, '^Invalid rebind value', self.assertRaisesRegexp(TypeError, '^Invalid rebind value',
@@ -195,13 +218,13 @@ class TaskTest(test.TestCase):
def test_default_provides(self): def test_default_provides(self):
task = DefaultProvidesTask() task = DefaultProvidesTask()
self.assertEqual(task.provides, set(['def'])) self.assertEqual(set(['def']), task.provides)
self.assertEqual(task.save_as, {'def': None}) self.assertEqual({'def': None}, task.save_as)
def test_default_provides_can_be_overridden(self): def test_default_provides_can_be_overridden(self):
task = DefaultProvidesTask(provides=('spam', 'eggs')) task = DefaultProvidesTask(provides=('spam', 'eggs'))
self.assertEqual(task.provides, set(['spam', 'eggs'])) self.assertEqual(set(['spam', 'eggs']), task.provides)
self.assertEqual(task.save_as, {'spam': 0, 'eggs': 1}) self.assertEqual({'spam': 0, 'eggs': 1}, task.save_as)
def test_update_progress_within_bounds(self): def test_update_progress_within_bounds(self):
values = [0.0, 0.5, 1.0] values = [0.0, 0.5, 1.0]
@@ -213,7 +236,7 @@ class TaskTest(test.TestCase):
a_task = ProgressTask() a_task = ProgressTask()
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
a_task.execute(values) a_task.execute(values)
self.assertEqual(result, values) self.assertEqual(values, result)
@mock.patch.object(task.LOG, 'warn') @mock.patch.object(task.LOG, 'warn')
def test_update_progress_lower_bound(self, mocked_warn): def test_update_progress_lower_bound(self, mocked_warn):
@@ -225,8 +248,8 @@ class TaskTest(test.TestCase):
a_task = ProgressTask() a_task = ProgressTask()
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
a_task.execute([-1.0, -0.5, 0.0]) a_task.execute([-1.0, -0.5, 0.0])
self.assertEqual(result, [0.0, 0.0, 0.0]) self.assertEqual([0.0, 0.0, 0.0], result)
self.assertEqual(mocked_warn.call_count, 2) self.assertEqual(2, mocked_warn.call_count)
@mock.patch.object(task.LOG, 'warn') @mock.patch.object(task.LOG, 'warn')
def test_update_progress_upper_bound(self, mocked_warn): def test_update_progress_upper_bound(self, mocked_warn):
@@ -238,8 +261,8 @@ class TaskTest(test.TestCase):
a_task = ProgressTask() a_task = ProgressTask()
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
a_task.execute([1.0, 1.5, 2.0]) a_task.execute([1.0, 1.5, 2.0])
self.assertEqual(result, [1.0, 1.0, 1.0]) self.assertEqual([1.0, 1.0, 1.0], result)
self.assertEqual(mocked_warn.call_count, 2) self.assertEqual(2, mocked_warn.call_count)
@mock.patch.object(notifier.LOG, 'warn') @mock.patch.object(notifier.LOG, 'warn')
def test_update_progress_handler_failure(self, mocked_warn): def test_update_progress_handler_failure(self, mocked_warn):
@@ -256,34 +279,34 @@ class TaskTest(test.TestCase):
a_task = MyTask() a_task = MyTask()
self.assertRaises(ValueError, a_task.notifier.register, self.assertRaises(ValueError, a_task.notifier.register,
task.EVENT_UPDATE_PROGRESS, None) task.EVENT_UPDATE_PROGRESS, None)
self.assertEqual(len(a_task.notifier), 0) self.assertEqual(0, len(a_task.notifier))
def test_deregister_any_handler(self): def test_deregister_any_handler(self):
a_task = MyTask() a_task = MyTask()
self.assertEqual(len(a_task.notifier), 0) self.assertEqual(0, len(a_task.notifier))
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, a_task.notifier.register(task.EVENT_UPDATE_PROGRESS,
lambda event_type, details: None) lambda event_type, details: None)
self.assertEqual(len(a_task.notifier), 1) self.assertEqual(1, len(a_task.notifier))
a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS) a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS)
self.assertEqual(len(a_task.notifier), 0) self.assertEqual(0, len(a_task.notifier))
def test_deregister_any_handler_empty_listeners(self): def test_deregister_any_handler_empty_listeners(self):
a_task = MyTask() a_task = MyTask()
self.assertEqual(len(a_task.notifier), 0) self.assertEqual(0, len(a_task.notifier))
self.assertFalse(a_task.notifier.deregister_event( self.assertFalse(a_task.notifier.deregister_event(
task.EVENT_UPDATE_PROGRESS)) task.EVENT_UPDATE_PROGRESS))
self.assertEqual(len(a_task.notifier), 0) self.assertEqual(0, len(a_task.notifier))
def test_deregister_non_existent_listener(self): def test_deregister_non_existent_listener(self):
handler1 = lambda event_type, details: None handler1 = lambda event_type, details: None
handler2 = lambda event_type, details: None handler2 = lambda event_type, details: None
a_task = MyTask() a_task = MyTask()
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1) self.assertEqual(1, len(list(a_task.notifier.listeners_iter())))
a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler2) a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler2)
self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1) self.assertEqual(1, len(list(a_task.notifier.listeners_iter())))
a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler1) a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler1)
self.assertEqual(len(list(a_task.notifier.listeners_iter())), 0) self.assertEqual(0, len(list(a_task.notifier.listeners_iter())))
def test_bind_not_callable(self): def test_bind_not_callable(self):
a_task = MyTask() a_task = MyTask()
@@ -295,8 +318,8 @@ class TaskTest(test.TestCase):
a_task = MyTask() a_task = MyTask()
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
b_task = a_task.copy(retain_listeners=False) b_task = a_task.copy(retain_listeners=False)
self.assertEqual(len(a_task.notifier), 1) self.assertEqual(1, len(a_task.notifier))
self.assertEqual(len(b_task.notifier), 0) self.assertEqual(0, len(b_task.notifier))
def test_copy_listeners(self): def test_copy_listeners(self):
handler1 = lambda event_type, details: None handler1 = lambda event_type, details: None
@@ -304,15 +327,15 @@ class TaskTest(test.TestCase):
a_task = MyTask() a_task = MyTask()
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
b_task = a_task.copy() b_task = a_task.copy()
self.assertEqual(len(b_task.notifier), 1) self.assertEqual(1, len(b_task.notifier))
self.assertTrue(a_task.notifier.deregister_event( self.assertTrue(a_task.notifier.deregister_event(
task.EVENT_UPDATE_PROGRESS)) task.EVENT_UPDATE_PROGRESS))
self.assertEqual(len(a_task.notifier), 0) self.assertEqual(0, len(a_task.notifier))
self.assertEqual(len(b_task.notifier), 1) self.assertEqual(1, len(b_task.notifier))
b_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler2) b_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler2)
listeners = dict(list(b_task.notifier.listeners_iter())) listeners = dict(list(b_task.notifier.listeners_iter()))
self.assertEqual(len(listeners[task.EVENT_UPDATE_PROGRESS]), 2) self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS]))
self.assertEqual(len(a_task.notifier), 0) self.assertEqual(0, len(a_task.notifier))
class FunctorTaskTest(test.TestCase): class FunctorTaskTest(test.TestCase):
@@ -320,7 +343,7 @@ class FunctorTaskTest(test.TestCase):
def test_creation_with_version(self): def test_creation_with_version(self):
version = (2, 0) version = (2, 0)
f_task = task.FunctorTask(lambda: None, version=version) f_task = task.FunctorTask(lambda: None, version=version)
self.assertEqual(f_task.version, version) self.assertEqual(version, f_task.version)
def test_execute_not_callable(self): def test_execute_not_callable(self):
self.assertRaises(ValueError, task.FunctorTask, 2) self.assertRaises(ValueError, task.FunctorTask, 2)