diff --git a/taskflow/retry.py b/taskflow/retry.py index bb79c7f4..b1a6ff48 100644 --- a/taskflow/retry.py +++ b/taskflow/retry.py @@ -22,6 +22,7 @@ import six from taskflow import atom from taskflow import exceptions as exc +from taskflow.utils import misc LOG = logging.getLogger(__name__) @@ -132,17 +133,12 @@ class ForEachBase(Retry): """Base class for retries that iterate given collection.""" def _get_next_value(self, values, history): - values = list(values) # copy it - for (item, failures) in history: - try: - values.remove(item) # remove exactly one element from item - except ValueError: - # one of the results is not in our list now -- who cares? - pass - if not values: + items = (item for item, _failures in history) + remaining = misc.sequence_minus(values, items) + if not remaining: raise exc.NotFound("No elements left in collection of iterable " "retry controller %s" % self.name) - return values[0] + return remaining[0] def _on_failure(self, values, history): try: diff --git a/taskflow/test.py b/taskflow/test.py index 2557a132..ce99a373 100644 --- a/taskflow/test.py +++ b/taskflow/test.py @@ -25,6 +25,7 @@ import six from taskflow import exceptions from taskflow.tests import utils +from taskflow.utils import misc class GreaterThanEqual(object): @@ -57,6 +58,30 @@ class FailureRegexpMatcher(object): (failure, self.exc_class)) +class ItemsEqual(object): + """Matches the sequence that has same elements as reference + object, regardless of the order. + """ + + def __init__(self, seq): + self._seq = seq + self._list = list(seq) + + def match(self, other): + other_list = list(other) + extra = misc.sequence_minus(other_list, self._list) + missing = misc.sequence_minus(self._list, other_list) + if extra or missing: + msg = ("Sequences %s and %s do not have same items." + % (self._seq, other)) + if missing: + msg += " Extra items in first sequence: %s." % missing + if extra: + msg += " Extra items in second sequence: %s." % extra + return matchers.Mismatch(msg) + return None + + class TestCase(testcase.TestCase): """Test case base class for all taskflow unit tests.""" @@ -151,12 +176,9 @@ class TestCase(testcase.TestCase): except exceptions.WrappedFailure as e: self.assertThat(e, FailureRegexpMatcher(exc_class, pattern)) - def assertIsContainsSameElements(self, seq1, seq2, msg=None): - if sorted(seq1) != sorted(seq2): - if msg is None: - msg = ("%r doesn't contain same elements as %r." - % (seq1, seq2)) - self.fail(msg) + def assertItemsEqual(self, seq1, seq2, msg=None): + matcher = ItemsEqual(seq1) + self.assertThat(seq2, matcher) class MockTestCase(TestCase): diff --git a/taskflow/tests/unit/test_retries.py b/taskflow/tests/unit/test_retries.py index 5d475029..a55b307a 100644 --- a/taskflow/tests/unit/test_retries.py +++ b/taskflow/tests/unit/test_retries.py @@ -283,7 +283,7 @@ class RetryTest(utils.EngineTestBase): 'task1 reverted(5)', 'task2', 'task1'] - self.assertIsContainsSameElements(self.values, expected) + self.assertItemsEqual(self.values, expected) def test_nested_flow_reverts_parent_retries(self): retry1 = retry.Times(3, 'r1', provides='x') @@ -502,7 +502,7 @@ class RetryTest(utils.EngineTestBase): expected = [u't1 reverted(Failure: RuntimeError: Woot with 3)', u't1 reverted(Failure: RuntimeError: Woot with 2)', u't1 reverted(Failure: RuntimeError: Woot with 5)'] - self.assertIsContainsSameElements(self.values, expected) + self.assertItemsEqual(self.values, expected) def test_for_each_empty_collection(self): values = [] @@ -536,7 +536,7 @@ class RetryTest(utils.EngineTestBase): expected = [u't1 reverted(Failure: RuntimeError: Woot with 3)', u't1 reverted(Failure: RuntimeError: Woot with 2)', u't1 reverted(Failure: RuntimeError: Woot with 5)'] - self.assertIsContainsSameElements(self.values, expected) + self.assertItemsEqual(self.values, expected) def test_parameterized_for_each_empty_collection(self): values = [] @@ -606,7 +606,7 @@ class RetryParallelExecutionTest(utils.EngineTestBase): 'task1 reverted(5)', 'task2', 'task1'] - self.assertIsContainsSameElements(self.values, expected) + self.assertItemsEqual(self.values, expected) def test_when_subflow_fails_revert_success_tasks(self): waiting_task = utils.WaitForOneFromTask('task2', 'task1', @@ -631,7 +631,7 @@ class RetryParallelExecutionTest(utils.EngineTestBase): 'task1', 'task2', 'task3'] - self.assertIsContainsSameElements(self.values, expected) + self.assertItemsEqual(self.values, expected) class SingleThreadedEngineTest(RetryTest, diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index 4be4646d..1c5c197b 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -495,3 +495,22 @@ class ExcInfoUtilsTest(test.TestCase): exc_info = self._make_ex_info() copied = misc.copy_exc_info(exc_info) self.assertTrue(misc.are_equal_exc_info_tuples(exc_info, copied)) + + +class TestSequenceMinus(test.TestCase): + + def test_simple_case(self): + result = misc.sequence_minus([1, 2, 3, 4], [2, 3]) + self.assertEqual(result, [1, 4]) + + def test_subtrahend_has_extra_elements(self): + result = misc.sequence_minus([1, 2, 3, 4], [2, 3, 5, 7, 13]) + self.assertEqual(result, [1, 4]) + + def test_some_items_are_equal(self): + result = misc.sequence_minus([1, 1, 1, 1], [1, 1, 3]) + self.assertEqual(result, [1, 1]) + + def test_equal_items_not_continious(self): + result = misc.sequence_minus([1, 2, 3, 1], [1, 3]) + self.assertEqual(result, [2, 1]) diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index 1f9a310c..ee8d7e2b 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -122,6 +122,22 @@ def get_version_string(obj): return obj_version +def sequence_minus(seq1, seq2): + """Calculate difference of two sequences. + + Result contains the elements from first sequence that are not + present in second sequence, in original order. Works even + if sequence elements are not hashable. + """ + result = list(seq1) + for item in seq2: + try: + result.remove(item) + except ValueError: + pass + return result + + def item_from(container, index, name=None): """Attempts to fetch a index/key from a given container.""" if index is None: