Merge "Rewrite assertion for same elements in sequences"

This commit is contained in:
Jenkins
2014-03-24 11:03:07 +00:00
committed by Gerrit Code Review
5 changed files with 73 additions and 20 deletions

View File

@@ -22,6 +22,7 @@ import six
from taskflow import atom
from taskflow import exceptions as exc
from taskflow.utils import misc
LOG = logging.getLogger(__name__)
@@ -132,17 +133,12 @@ class ForEachBase(Retry):
"""Base class for retries that iterate given collection."""
def _get_next_value(self, values, history):
values = list(values) # copy it
for (item, failures) in history:
try:
values.remove(item) # remove exactly one element from item
except ValueError:
# one of the results is not in our list now -- who cares?
pass
if not values:
items = (item for item, _failures in history)
remaining = misc.sequence_minus(values, items)
if not remaining:
raise exc.NotFound("No elements left in collection of iterable "
"retry controller %s" % self.name)
return values[0]
return remaining[0]
def _on_failure(self, values, history):
try:

View File

@@ -25,6 +25,7 @@ import six
from taskflow import exceptions
from taskflow.tests import utils
from taskflow.utils import misc
class GreaterThanEqual(object):
@@ -57,6 +58,30 @@ class FailureRegexpMatcher(object):
(failure, self.exc_class))
class ItemsEqual(object):
"""Matches the sequence that has same elements as reference
object, regardless of the order.
"""
def __init__(self, seq):
self._seq = seq
self._list = list(seq)
def match(self, other):
other_list = list(other)
extra = misc.sequence_minus(other_list, self._list)
missing = misc.sequence_minus(self._list, other_list)
if extra or missing:
msg = ("Sequences %s and %s do not have same items."
% (self._seq, other))
if missing:
msg += " Extra items in first sequence: %s." % missing
if extra:
msg += " Extra items in second sequence: %s." % extra
return matchers.Mismatch(msg)
return None
class TestCase(testcase.TestCase):
"""Test case base class for all taskflow unit tests."""
@@ -151,12 +176,9 @@ class TestCase(testcase.TestCase):
except exceptions.WrappedFailure as e:
self.assertThat(e, FailureRegexpMatcher(exc_class, pattern))
def assertIsContainsSameElements(self, seq1, seq2, msg=None):
if sorted(seq1) != sorted(seq2):
if msg is None:
msg = ("%r doesn't contain same elements as %r."
% (seq1, seq2))
self.fail(msg)
def assertItemsEqual(self, seq1, seq2, msg=None):
matcher = ItemsEqual(seq1)
self.assertThat(seq2, matcher)
class MockTestCase(TestCase):

View File

@@ -283,7 +283,7 @@ class RetryTest(utils.EngineTestBase):
'task1 reverted(5)',
'task2',
'task1']
self.assertIsContainsSameElements(self.values, expected)
self.assertItemsEqual(self.values, expected)
def test_nested_flow_reverts_parent_retries(self):
retry1 = retry.Times(3, 'r1', provides='x')
@@ -502,7 +502,7 @@ class RetryTest(utils.EngineTestBase):
expected = [u't1 reverted(Failure: RuntimeError: Woot with 3)',
u't1 reverted(Failure: RuntimeError: Woot with 2)',
u't1 reverted(Failure: RuntimeError: Woot with 5)']
self.assertIsContainsSameElements(self.values, expected)
self.assertItemsEqual(self.values, expected)
def test_for_each_empty_collection(self):
values = []
@@ -536,7 +536,7 @@ class RetryTest(utils.EngineTestBase):
expected = [u't1 reverted(Failure: RuntimeError: Woot with 3)',
u't1 reverted(Failure: RuntimeError: Woot with 2)',
u't1 reverted(Failure: RuntimeError: Woot with 5)']
self.assertIsContainsSameElements(self.values, expected)
self.assertItemsEqual(self.values, expected)
def test_parameterized_for_each_empty_collection(self):
values = []
@@ -606,7 +606,7 @@ class RetryParallelExecutionTest(utils.EngineTestBase):
'task1 reverted(5)',
'task2',
'task1']
self.assertIsContainsSameElements(self.values, expected)
self.assertItemsEqual(self.values, expected)
def test_when_subflow_fails_revert_success_tasks(self):
waiting_task = utils.WaitForOneFromTask('task2', 'task1',
@@ -631,7 +631,7 @@ class RetryParallelExecutionTest(utils.EngineTestBase):
'task1',
'task2',
'task3']
self.assertIsContainsSameElements(self.values, expected)
self.assertItemsEqual(self.values, expected)
class SingleThreadedEngineTest(RetryTest,

View File

@@ -495,3 +495,22 @@ class ExcInfoUtilsTest(test.TestCase):
exc_info = self._make_ex_info()
copied = misc.copy_exc_info(exc_info)
self.assertTrue(misc.are_equal_exc_info_tuples(exc_info, copied))
class TestSequenceMinus(test.TestCase):
def test_simple_case(self):
result = misc.sequence_minus([1, 2, 3, 4], [2, 3])
self.assertEqual(result, [1, 4])
def test_subtrahend_has_extra_elements(self):
result = misc.sequence_minus([1, 2, 3, 4], [2, 3, 5, 7, 13])
self.assertEqual(result, [1, 4])
def test_some_items_are_equal(self):
result = misc.sequence_minus([1, 1, 1, 1], [1, 1, 3])
self.assertEqual(result, [1, 1])
def test_equal_items_not_continious(self):
result = misc.sequence_minus([1, 2, 3, 1], [1, 3])
self.assertEqual(result, [2, 1])

View File

@@ -122,6 +122,22 @@ def get_version_string(obj):
return obj_version
def sequence_minus(seq1, seq2):
"""Calculate difference of two sequences.
Result contains the elements from first sequence that are not
present in second sequence, in original order. Works even
if sequence elements are not hashable.
"""
result = list(seq1)
for item in seq2:
try:
result.remove(item)
except ValueError:
pass
return result
def item_from(container, index, name=None):
"""Attempts to fetch a index/key from a given container."""
if index is None: