diff --git a/taskflow/engines/action_engine/analyzer.py b/taskflow/engines/action_engine/analyzer.py index 1c0f4545..c0b625b6 100644 --- a/taskflow/engines/action_engine/analyzer.py +++ b/taskflow/engines/action_engine/analyzer.py @@ -16,6 +16,7 @@ import abc import itertools +import operator import weakref import six @@ -133,8 +134,9 @@ class Analyzer(object): def iter_next_atoms(self, atom=None): """Iterate next atoms to run (originating from atom or all atoms).""" if atom is None: - return iter_utils.unique_seen(self.browse_atoms_for_execute(), - self.browse_atoms_for_revert()) + return iter_utils.unique_seen((self.browse_atoms_for_execute(), + self.browse_atoms_for_revert()), + seen_selector=operator.itemgetter(0)) state = self._storage.get_atom_state(atom.name) intention = self._storage.get_atom_intention(atom.name) if state == st.SUCCESS: diff --git a/taskflow/engines/action_engine/builder.py b/taskflow/engines/action_engine/builder.py index 67b333df..72aa81c5 100644 --- a/taskflow/engines/action_engine/builder.py +++ b/taskflow/engines/action_engine/builder.py @@ -137,8 +137,8 @@ class MachineBuilder(object): # attempt, which may be empty if never ran before) and any nodes # that are now ready to be ran. memory.next_up.update( - iter_utils.unique_seen(self._completer.resume(), - iter_next_atoms())) + iter_utils.unique_seen((self._completer.resume(), + iter_next_atoms()))) return SCHEDULE def game_over(old_state, new_state, event): diff --git a/taskflow/tests/unit/test_utils_iter_utils.py b/taskflow/tests/unit/test_utils_iter_utils.py index ad522a4e..4a5ff4b9 100644 --- a/taskflow/tests/unit/test_utils_iter_utils.py +++ b/taskflow/tests/unit/test_utils_iter_utils.py @@ -39,9 +39,10 @@ class IterUtilsTest(test.TestCase): ['a', 'b'], 2, None, + object(), ] self.assertRaises(ValueError, - iter_utils.unique_seen, *iters) + iter_utils.unique_seen, iters) def test_generate_delays(self): it = iter_utils.generate_delays(1, 60) @@ -77,7 +78,22 @@ class IterUtilsTest(test.TestCase): ['f', 'm', 'n'], ] self.assertEqual(['a', 'b', 'c', 'd', 'e', 'f', 'm', 'n'], - list(iter_utils.unique_seen(*iters))) + list(iter_utils.unique_seen(iters))) + + def test_unique_seen_empty(self): + iters = [] + self.assertEqual([], list(iter_utils.unique_seen(iters))) + + def test_unique_seen_selector(self): + iters = [ + [(1, 'a'), (1, 'a')], + [(2, 'b')], + [(3, 'c')], + [(1, 'a'), (3, 'c')], + ] + it = iter_utils.unique_seen(iters, + seen_selector=lambda value: value[0]) + self.assertEqual([(1, 'a'), (2, 'b'), (3, 'c')], list(it)) def test_bad_fill(self): self.assertRaises(ValueError, iter_utils.fill, 2, 2) diff --git a/taskflow/utils/iter_utils.py b/taskflow/utils/iter_utils.py index 413b36f4..959f6660 100644 --- a/taskflow/utils/iter_utils.py +++ b/taskflow/utils/iter_utils.py @@ -91,7 +91,7 @@ def generate_delays(delay, max_delay, multiplier=2): return _gen_it() -def unique_seen(it, *its): +def unique_seen(its, seen_selector=None): """Yields unique values from iterator(s) (and retains order).""" def _gen_it(all_its): @@ -99,16 +99,17 @@ def unique_seen(it, *its): # can happen before generation/iteration... (instead of # during generation/iteration) seen = set() - while all_its: - it = all_its.popleft() + for it in all_its: for value in it: - if value not in seen: + if seen_selector is not None: + maybe_seen_value = seen_selector(value) + else: + maybe_seen_value = value + if maybe_seen_value not in seen: yield value - seen.add(value) + seen.add(maybe_seen_value) - all_its = collections.deque([it]) - if its: - all_its.extend(its) + all_its = list(its) for it in all_its: if not isinstance(it, collections.Iterable): raise ValueError("Iterable expected, but '%s' is"