From f555a35f3081ba492db15d7bda11fbe50f2a8349 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Fri, 11 Dec 2015 08:40:23 -0800 Subject: [PATCH] Fix wrong usage of iter_utils.unique_seen This wasn't actually using the right data to derive uniqueness, so fix it to actually use the right entry to determine what to skip or what not to skip. Closes-Bug: #1525379 Change-Id: Ic23b6d03877f7714f6a3fb74adac0ba58cd97f0d --- taskflow/engines/action_engine/analyzer.py | 6 ++++-- taskflow/engines/action_engine/builder.py | 4 ++-- taskflow/tests/unit/test_utils_iter_utils.py | 20 ++++++++++++++++++-- taskflow/utils/iter_utils.py | 17 +++++++++-------- 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/taskflow/engines/action_engine/analyzer.py b/taskflow/engines/action_engine/analyzer.py index 1c0f4545..c0b625b6 100644 --- a/taskflow/engines/action_engine/analyzer.py +++ b/taskflow/engines/action_engine/analyzer.py @@ -16,6 +16,7 @@ import abc import itertools +import operator import weakref import six @@ -133,8 +134,9 @@ class Analyzer(object): def iter_next_atoms(self, atom=None): """Iterate next atoms to run (originating from atom or all atoms).""" if atom is None: - return iter_utils.unique_seen(self.browse_atoms_for_execute(), - self.browse_atoms_for_revert()) + return iter_utils.unique_seen((self.browse_atoms_for_execute(), + self.browse_atoms_for_revert()), + seen_selector=operator.itemgetter(0)) state = self._storage.get_atom_state(atom.name) intention = self._storage.get_atom_intention(atom.name) if state == st.SUCCESS: diff --git a/taskflow/engines/action_engine/builder.py b/taskflow/engines/action_engine/builder.py index 67b333df..72aa81c5 100644 --- a/taskflow/engines/action_engine/builder.py +++ b/taskflow/engines/action_engine/builder.py @@ -137,8 +137,8 @@ class MachineBuilder(object): # attempt, which may be empty if never ran before) and any nodes # that are now ready to be ran. memory.next_up.update( - iter_utils.unique_seen(self._completer.resume(), - iter_next_atoms())) + iter_utils.unique_seen((self._completer.resume(), + iter_next_atoms()))) return SCHEDULE def game_over(old_state, new_state, event): diff --git a/taskflow/tests/unit/test_utils_iter_utils.py b/taskflow/tests/unit/test_utils_iter_utils.py index ad522a4e..4a5ff4b9 100644 --- a/taskflow/tests/unit/test_utils_iter_utils.py +++ b/taskflow/tests/unit/test_utils_iter_utils.py @@ -39,9 +39,10 @@ class IterUtilsTest(test.TestCase): ['a', 'b'], 2, None, + object(), ] self.assertRaises(ValueError, - iter_utils.unique_seen, *iters) + iter_utils.unique_seen, iters) def test_generate_delays(self): it = iter_utils.generate_delays(1, 60) @@ -77,7 +78,22 @@ class IterUtilsTest(test.TestCase): ['f', 'm', 'n'], ] self.assertEqual(['a', 'b', 'c', 'd', 'e', 'f', 'm', 'n'], - list(iter_utils.unique_seen(*iters))) + list(iter_utils.unique_seen(iters))) + + def test_unique_seen_empty(self): + iters = [] + self.assertEqual([], list(iter_utils.unique_seen(iters))) + + def test_unique_seen_selector(self): + iters = [ + [(1, 'a'), (1, 'a')], + [(2, 'b')], + [(3, 'c')], + [(1, 'a'), (3, 'c')], + ] + it = iter_utils.unique_seen(iters, + seen_selector=lambda value: value[0]) + self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], list(it)) def test_bad_fill(self): self.assertRaises(ValueError, iter_utils.fill, 2, 2) diff --git a/taskflow/utils/iter_utils.py b/taskflow/utils/iter_utils.py index 413b36f4..959f6660 100644 --- a/taskflow/utils/iter_utils.py +++ b/taskflow/utils/iter_utils.py @@ -91,7 +91,7 @@ def generate_delays(delay, max_delay, multiplier=2): return _gen_it() -def unique_seen(it, *its): +def unique_seen(its, seen_selector=None): """Yields unique values from iterator(s) (and retains order).""" def _gen_it(all_its): @@ -99,16 +99,17 @@ def unique_seen(it, *its): # can happen before generation/iteration... (instead of # during generation/iteration) seen = set() - while all_its: - it = all_its.popleft() + for it in all_its: for value in it: - if value not in seen: + if seen_selector is not None: + maybe_seen_value = seen_selector(value) + else: + maybe_seen_value = value + if maybe_seen_value not in seen: yield value - seen.add(value) + seen.add(maybe_seen_value) - all_its = collections.deque([it]) - if its: - all_its.extend(its) + all_its = list(its) for it in all_its: if not isinstance(it, collections.Iterable): raise ValueError("Iterable expected, but '%s' is"