diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py index 5afd2751..3262a79f 100644 --- a/taskflow/engines/action_engine/actions/retry.py +++ b/taskflow/engines/action_engine/actions/retry.py @@ -76,7 +76,7 @@ class RetryAction(object): result = retry.execute(**kwargs) except Exception: result = failure.Failure() - return (retry, ex.EXECUTED, result) + return (ex.EXECUTED, result) def _on_done_callback(fut): result = fut.result()[-1] @@ -89,6 +89,7 @@ class RetryAction(object): fut = self._executor.submit(_execute_retry, self._get_retry_args(retry)) fut.add_done_callback(_on_done_callback) + fut.atom = retry return fut def revert(self, retry): @@ -98,7 +99,7 @@ class RetryAction(object): result = retry.revert(**kwargs) except Exception: result = failure.Failure() - return (retry, ex.REVERTED, result) + return (ex.REVERTED, result) def _on_done_callback(fut): result = fut.result()[-1] @@ -115,6 +116,7 @@ class RetryAction(object): self._get_retry_args(retry, addons=arg_addons)) fut.add_done_callback(_on_done_callback) + fut.atom = retry return fut def on_failure(self, retry, atom, last_failure): diff --git a/taskflow/engines/action_engine/executor.py b/taskflow/engines/action_engine/executor.py index 4c5c091a..9224adb1 100644 --- a/taskflow/engines/action_engine/executor.py +++ b/taskflow/engines/action_engine/executor.py @@ -40,7 +40,7 @@ def _execute_task(task, arguments, progress_callback): result = failure.Failure() finally: task.post_execute() - return (task, EXECUTED, result) + return (EXECUTED, result) def _revert_task(task, arguments, result, failures, progress_callback): @@ -57,7 +57,7 @@ def _revert_task(task, arguments, result, failures, progress_callback): result = failure.Failure() finally: task.post_revert() - return (task, REVERTED, result) + return (REVERTED, result) @six.add_metaclass(abc.ABCMeta) @@ -98,13 +98,17 @@ class SerialTaskExecutor(TaskExecutorBase): self._executor = futures.SynchronousExecutor() def execute_task(self, task, task_uuid, arguments, progress_callback=None): - return self._executor.submit(_execute_task, task, arguments, - progress_callback) + fut = self._executor.submit(_execute_task, task, arguments, + progress_callback) + fut.atom = task + return fut def revert_task(self, task, task_uuid, arguments, result, failures, progress_callback=None): - return self._executor.submit(_revert_task, task, arguments, result, - failures, progress_callback) + fut = self._executor.submit(_revert_task, task, arguments, result, + failures, progress_callback) + fut.atom = task + return fut def wait_for_any(self, fs, timeout=None): return async_utils.wait_for_any(fs, timeout) @@ -123,14 +127,17 @@ class ParallelTaskExecutor(TaskExecutorBase): self._create_executor = executor is None def execute_task(self, task, task_uuid, arguments, progress_callback=None): - return self._executor.submit( - _execute_task, task, arguments, progress_callback) + fut = self._executor.submit(_execute_task, task, + arguments, progress_callback) + fut.atom = task + return fut def revert_task(self, task, task_uuid, arguments, result, failures, progress_callback=None): - return self._executor.submit( - _revert_task, task, - arguments, result, failures, progress_callback) + fut = self._executor.submit(_revert_task, task, arguments, + result, failures, progress_callback) + fut.atom = task + return fut def wait_for_any(self, fs, timeout=None): return async_utils.wait_for_any(fs, timeout) diff --git a/taskflow/engines/action_engine/runner.py b/taskflow/engines/action_engine/runner.py index 79ebc657..502bad18 100644 --- a/taskflow/engines/action_engine/runner.py +++ b/taskflow/engines/action_engine/runner.py @@ -129,8 +129,9 @@ class _MachineBuilder(object): next_nodes = set() while memory.done: fut = memory.done.pop() + node = fut.atom try: - node, event, result = fut.result() + event, result = fut.result() retain = self._completer.complete(node, event, result) if retain and isinstance(result, failure.Failure): memory.failures.append(result) diff --git a/taskflow/engines/worker_based/endpoint.py b/taskflow/engines/worker_based/endpoint.py index 3a16266d..276e93c6 100644 --- a/taskflow/engines/worker_based/endpoint.py +++ b/taskflow/engines/worker_based/endpoint.py @@ -40,11 +40,11 @@ class Endpoint(object): return self._task_cls(name=name) def execute(self, task_name, **kwargs): - task, event, result = self._executor.execute_task( - self._get_task(task_name), **kwargs).result() + task = self._get_task(task_name) + event, result = self._executor.execute_task(task, **kwargs).result() return result def revert(self, task_name, **kwargs): - task, event, result = self._executor.revert_task( - self._get_task(task_name), **kwargs).result() + task = self._get_task(task_name) + event, result = self._executor.revert_task(task, **kwargs).result() return result diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index 3cd7e178..3fd87216 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -234,6 +234,7 @@ class Request(Message): self._lock = threading.Lock() self._created_on = timeutils.utcnow() self.result = futures.Future() + self.result.atom = task @property def uuid(self): @@ -290,7 +291,7 @@ class Request(Message): return request def set_result(self, result): - self.result.set_result((self._task, self._event, result)) + self.result.set_result((self._event, result)) def on_progress(self, event_data, progress): self._progress_callback(self._task, event_data, progress) diff --git a/taskflow/tests/unit/worker_based/test_pipeline.py b/taskflow/tests/unit/worker_based/test_pipeline.py index ed3e2662..2822a852 100644 --- a/taskflow/tests/unit/worker_based/test_pipeline.py +++ b/taskflow/tests/unit/worker_based/test_pipeline.py @@ -16,6 +16,7 @@ from concurrent import futures +from taskflow.engines.action_engine import executor as base_executor from taskflow.engines.worker_based import endpoint from taskflow.engines.worker_based import executor as worker_executor from taskflow.engines.worker_based import server as worker_server @@ -73,13 +74,14 @@ class TestPipeline(test.TestCase): self.assertEqual(0, executor.wait_for_workers(timeout=WAIT_TIMEOUT)) t = test_utils.TaskOneReturn() - f = executor.execute_task(t, uuidutils.generate_uuid(), {}) + progress_callback = lambda *args, **kwargs: None + f = executor.execute_task(t, uuidutils.generate_uuid(), {}, + progress_callback=progress_callback) executor.wait_for_any([f]) - t2, _action, result = f.result() - + event, result = f.result() self.assertEqual(1, result) - self.assertEqual(t, t2) + self.assertEqual(base_executor.EXECUTED, event) def test_execution_failure_pipeline(self): task_classes = [ @@ -88,9 +90,12 @@ class TestPipeline(test.TestCase): executor, server = self._start_components(task_classes) t = test_utils.TaskWithFailure() - f = executor.execute_task(t, uuidutils.generate_uuid(), {}) + progress_callback = lambda *args, **kwargs: None + f = executor.execute_task(t, uuidutils.generate_uuid(), {}, + progress_callback=progress_callback) executor.wait_for_any([f]) - _t2, _action, result = f.result() + action, result = f.result() self.assertIsInstance(result, failure.Failure) self.assertEqual(RuntimeError, result.check(RuntimeError)) + self.assertEqual(base_executor.EXECUTED, action) diff --git a/taskflow/tests/unit/worker_based/test_protocol.py b/taskflow/tests/unit/worker_based/test_protocol.py index 4c9c4b77..2f5fd927 100644 --- a/taskflow/tests/unit/worker_based/test_protocol.py +++ b/taskflow/tests/unit/worker_based/test_protocol.py @@ -17,6 +17,7 @@ from concurrent import futures from oslo.utils import timeutils +from taskflow.engines.action_engine import executor from taskflow.engines.worker_based import protocol as pr from taskflow import exceptions as excp from taskflow.openstack.common import uuidutils @@ -182,7 +183,7 @@ class TestProtocol(test.TestCase): request = self.request() request.set_result(111) result = request.result.result() - self.assertEqual(result, (self.task, 'executed', 111)) + self.assertEqual(result, (executor.EXECUTED, 111)) def test_on_progress(self): progress_callback = mock.MagicMock(name='progress_callback')