Rewrite assertion for same elements in sequences
New version is able to compare sequences whith elements that are not hashable and cannot be compared (so that sorted() does not work). For that, a utility function that caluclates difference between two sequences was added. This function was also used in retry.ForEachBase. The assertion was also renamed to assertItemsEqual, which matches same assertion that was added in python 2.7. Change-Id: I2b1b811190e9dc51718e4ca17ffc5c9015c34dc4
This commit is contained in:
@@ -22,6 +22,7 @@ import six
|
||||
|
||||
from taskflow import atom
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow.utils import misc
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,17 +133,12 @@ class ForEachBase(Retry):
|
||||
"""Base class for retries that iterate given collection."""
|
||||
|
||||
def _get_next_value(self, values, history):
|
||||
values = list(values) # copy it
|
||||
for (item, failures) in history:
|
||||
try:
|
||||
values.remove(item) # remove exactly one element from item
|
||||
except ValueError:
|
||||
# one of the results is not in our list now -- who cares?
|
||||
pass
|
||||
if not values:
|
||||
items = (item for item, _failures in history)
|
||||
remaining = misc.sequence_minus(values, items)
|
||||
if not remaining:
|
||||
raise exc.NotFound("No elements left in collection of iterable "
|
||||
"retry controller %s" % self.name)
|
||||
return values[0]
|
||||
return remaining[0]
|
||||
|
||||
def _on_failure(self, values, history):
|
||||
try:
|
||||
|
@@ -25,6 +25,7 @@ import six
|
||||
|
||||
from taskflow import exceptions
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import misc
|
||||
|
||||
|
||||
class GreaterThanEqual(object):
|
||||
@@ -57,6 +58,30 @@ class FailureRegexpMatcher(object):
|
||||
(failure, self.exc_class))
|
||||
|
||||
|
||||
class ItemsEqual(object):
|
||||
"""Matches the sequence that has same elements as reference
|
||||
object, regardless of the order.
|
||||
"""
|
||||
|
||||
def __init__(self, seq):
|
||||
self._seq = seq
|
||||
self._list = list(seq)
|
||||
|
||||
def match(self, other):
|
||||
other_list = list(other)
|
||||
extra = misc.sequence_minus(other_list, self._list)
|
||||
missing = misc.sequence_minus(self._list, other_list)
|
||||
if extra or missing:
|
||||
msg = ("Sequences %s and %s do not have same items."
|
||||
% (self._seq, other))
|
||||
if missing:
|
||||
msg += " Extra items in first sequence: %s." % missing
|
||||
if extra:
|
||||
msg += " Extra items in second sequence: %s." % extra
|
||||
return matchers.Mismatch(msg)
|
||||
return None
|
||||
|
||||
|
||||
class TestCase(testcase.TestCase):
|
||||
"""Test case base class for all taskflow unit tests."""
|
||||
|
||||
@@ -151,12 +176,9 @@ class TestCase(testcase.TestCase):
|
||||
except exceptions.WrappedFailure as e:
|
||||
self.assertThat(e, FailureRegexpMatcher(exc_class, pattern))
|
||||
|
||||
def assertIsContainsSameElements(self, seq1, seq2, msg=None):
|
||||
if sorted(seq1) != sorted(seq2):
|
||||
if msg is None:
|
||||
msg = ("%r doesn't contain same elements as %r."
|
||||
% (seq1, seq2))
|
||||
self.fail(msg)
|
||||
def assertItemsEqual(self, seq1, seq2, msg=None):
|
||||
matcher = ItemsEqual(seq1)
|
||||
self.assertThat(seq2, matcher)
|
||||
|
||||
|
||||
class MockTestCase(TestCase):
|
||||
|
@@ -283,7 +283,7 @@ class RetryTest(utils.EngineTestBase):
|
||||
'task1 reverted(5)',
|
||||
'task2',
|
||||
'task1']
|
||||
self.assertIsContainsSameElements(self.values, expected)
|
||||
self.assertItemsEqual(self.values, expected)
|
||||
|
||||
def test_nested_flow_reverts_parent_retries(self):
|
||||
retry1 = retry.Times(3, 'r1', provides='x')
|
||||
@@ -502,7 +502,7 @@ class RetryTest(utils.EngineTestBase):
|
||||
expected = [u't1 reverted(Failure: RuntimeError: Woot with 3)',
|
||||
u't1 reverted(Failure: RuntimeError: Woot with 2)',
|
||||
u't1 reverted(Failure: RuntimeError: Woot with 5)']
|
||||
self.assertIsContainsSameElements(self.values, expected)
|
||||
self.assertItemsEqual(self.values, expected)
|
||||
|
||||
def test_for_each_empty_collection(self):
|
||||
values = []
|
||||
@@ -536,7 +536,7 @@ class RetryTest(utils.EngineTestBase):
|
||||
expected = [u't1 reverted(Failure: RuntimeError: Woot with 3)',
|
||||
u't1 reverted(Failure: RuntimeError: Woot with 2)',
|
||||
u't1 reverted(Failure: RuntimeError: Woot with 5)']
|
||||
self.assertIsContainsSameElements(self.values, expected)
|
||||
self.assertItemsEqual(self.values, expected)
|
||||
|
||||
def test_parameterized_for_each_empty_collection(self):
|
||||
values = []
|
||||
@@ -606,7 +606,7 @@ class RetryParallelExecutionTest(utils.EngineTestBase):
|
||||
'task1 reverted(5)',
|
||||
'task2',
|
||||
'task1']
|
||||
self.assertIsContainsSameElements(self.values, expected)
|
||||
self.assertItemsEqual(self.values, expected)
|
||||
|
||||
def test_when_subflow_fails_revert_success_tasks(self):
|
||||
waiting_task = utils.WaitForOneFromTask('task2', 'task1',
|
||||
@@ -631,7 +631,7 @@ class RetryParallelExecutionTest(utils.EngineTestBase):
|
||||
'task1',
|
||||
'task2',
|
||||
'task3']
|
||||
self.assertIsContainsSameElements(self.values, expected)
|
||||
self.assertItemsEqual(self.values, expected)
|
||||
|
||||
|
||||
class SingleThreadedEngineTest(RetryTest,
|
||||
|
@@ -495,3 +495,22 @@ class ExcInfoUtilsTest(test.TestCase):
|
||||
exc_info = self._make_ex_info()
|
||||
copied = misc.copy_exc_info(exc_info)
|
||||
self.assertTrue(misc.are_equal_exc_info_tuples(exc_info, copied))
|
||||
|
||||
|
||||
class TestSequenceMinus(test.TestCase):
|
||||
|
||||
def test_simple_case(self):
|
||||
result = misc.sequence_minus([1, 2, 3, 4], [2, 3])
|
||||
self.assertEqual(result, [1, 4])
|
||||
|
||||
def test_subtrahend_has_extra_elements(self):
|
||||
result = misc.sequence_minus([1, 2, 3, 4], [2, 3, 5, 7, 13])
|
||||
self.assertEqual(result, [1, 4])
|
||||
|
||||
def test_some_items_are_equal(self):
|
||||
result = misc.sequence_minus([1, 1, 1, 1], [1, 1, 3])
|
||||
self.assertEqual(result, [1, 1])
|
||||
|
||||
def test_equal_items_not_continious(self):
|
||||
result = misc.sequence_minus([1, 2, 3, 1], [1, 3])
|
||||
self.assertEqual(result, [2, 1])
|
||||
|
@@ -121,6 +121,22 @@ def get_version_string(obj):
|
||||
return obj_version
|
||||
|
||||
|
||||
def sequence_minus(seq1, seq2):
|
||||
"""Calculate difference of two sequences.
|
||||
|
||||
Result contains the elements from first sequence that are not
|
||||
present in second sequence, in original order. Works even
|
||||
if sequence elements are not hashable.
|
||||
"""
|
||||
result = list(seq1)
|
||||
for item in seq2:
|
||||
try:
|
||||
result.remove(item)
|
||||
except ValueError:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def item_from(container, index, name=None):
|
||||
"""Attempts to fetch a index/key from a given container."""
|
||||
if index is None:
|
||||
|
Reference in New Issue
Block a user