diff --git a/taskflow/storage.py b/taskflow/storage.py index e96874ac..bcc4b382 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -657,15 +657,17 @@ class Storage(object): scope_iter = iter(scope_walker) else: scope_iter = iter([]) - for atom_names in scope_iter: - if not atom_names: - continue - providers = [] - for p in possible_providers: - if p.name in atom_names: - providers.append((p, _get_results(looking_for, p))) + extractor = lambda p: p.name + for names in scope_iter: + # *Always* retain the scope ordering (if any matches + # happen); instead of retaining the possible provider match + # order (which isn't that important and may be different from + # the scope requested ordering). + providers = misc.look_for(names, possible_providers, + extractor=extractor) if providers: - return providers + return [(p, _get_results(looking_for, p)) + for p in providers] return [] with self._lock.read_lock(): diff --git a/taskflow/tests/unit/test_engines.py b/taskflow/tests/unit/test_engines.py index 9176b9d6..04b8fa3a 100644 --- a/taskflow/tests/unit/test_engines.py +++ b/taskflow/tests/unit/test_engines.py @@ -156,6 +156,40 @@ class EngineLinearFlowTest(utils.EngineTestBase): engine = self._make_engine(flow) self.assertRaises(exc.Empty, engine.run) + def test_overlap_parent_sibling_expected_result(self): + flow = lf.Flow('flow-1') + flow.add(utils.ProgressingTask(provides='source')) + flow.add(utils.TaskOneReturn(provides='source')) + subflow = lf.Flow('flow-2') + subflow.add(utils.AddOne()) + flow.add(subflow) + engine = self._make_engine(flow) + engine.run() + results = engine.storage.fetch_all() + self.assertEqual(2, results['result']) + + def test_overlap_parent_expected_result(self): + flow = lf.Flow('flow-1') + flow.add(utils.ProgressingTask(provides='source')) + subflow = lf.Flow('flow-2') + subflow.add(utils.TaskOneReturn(provides='source')) + subflow.add(utils.AddOne()) + flow.add(subflow) + engine = self._make_engine(flow) + engine.run() + results = engine.storage.fetch_all() + self.assertEqual(2, results['result']) + + def test_overlap_sibling_expected_result(self): + flow = lf.Flow('flow-1') + flow.add(utils.ProgressingTask(provides='source')) + flow.add(utils.TaskOneReturn(provides='source')) + flow.add(utils.AddOne()) + engine = self._make_engine(flow) + engine.run() + results = engine.storage.fetch_all() + self.assertEqual(2, results['result']) + def test_sequential_flow_one_task(self): flow = lf.Flow('flow-1').add( utils.ProgressingTask(name='task1') diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index 1477fe5a..df229fac 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -244,6 +244,24 @@ class TestCountdownIter(test.TestCase): self.assertRaises(ValueError, six.next, it) +class TestLookFor(test.TestCase): + def test_no_matches(self): + hay = [9, 10, 11] + self.assertEqual([], misc.look_for(hay, [1, 2, 3])) + + def test_match_order(self): + hay = [6, 5, 4, 3, 2, 1] + priors = [] + for i in range(0, 6): + priors.append(i + 1) + matches = misc.look_for(hay, priors) + self.assertGreater(0, len(matches)) + self.assertIsSuperAndSubsequence(hay, matches) + hay = [10, 1, 15, 3, 5, 8, 44] + self.assertEqual([1, 15], misc.look_for(hay, [15, 1])) + self.assertEqual([10, 44], misc.look_for(hay, [44, 10])) + + class TestClamping(test.TestCase): def test_simple_clamp(self): result = misc.clamp(1.0, 2.0, 3.0) diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index 597a64a4..07095efa 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -34,7 +34,7 @@ class TestWorker(test.MockTestCase): self.exchange = 'test-exchange' self.topic = 'test-topic' self.threads_count = 5 - self.endpoint_count = 23 + self.endpoint_count = 24 # patch classes self.executor_mock, self.executor_inst_mock = self.patchClass( diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index 9e6e2a34..031dd706 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -89,6 +89,13 @@ class DummyTask(task.Task): pass +class AddOne(task.Task): + default_provides = 'result' + + def execute(self, source): + return source + 1 + + class FakeTask(object): def execute(self, **kwargs): diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index 39708b5c..6c5c9de7 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -196,6 +196,37 @@ def parse_uri(uri): return netutils.urlsplit(uri) +def look_for(haystack, needles, extractor=None): + """Find items in haystack and returns matches found (in haystack order). + + Given a list of items (the haystack) and a list of items to look for (the + needles) this will look for the needles in the haystack and returns + the found needles (if any). The ordering of the returned needles is in the + order they are located in the haystack. + + Example input and output: + + >>> from taskflow.utils import misc + >>> hay = [3, 2, 1] + >>> misc.look_for(hay, [1, 2]) + [2, 1] + """ + if not haystack: + return [] + if extractor is None: + extractor = lambda v: v + matches = [] + for i, v in enumerate(needles): + try: + matches.append((haystack.index(extractor(v)), i)) + except ValueError: + pass + if not matches: + return [] + else: + return [needles[i] for (_hay_i, i) in sorted(matches)] + + def clamp(value, minimum, maximum, on_clamped=None): """Clamps a value to ensure its >= minimum and <= maximum.""" if minimum > maximum: