diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 4f878545..ac430956 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -20,6 +20,7 @@ import threading from concurrent import futures +from taskflow.engines.action_engine import executor from taskflow.engines.action_engine import graph_action from taskflow.engines.action_engine import task_action from taskflow.engines import base @@ -51,6 +52,7 @@ class ActionEngine(base.EngineBase): """ _graph_action_cls = None _task_action_cls = task_action.TaskAction + _task_executor_cls = executor.SerialTaskExecutor def __init__(self, flow, flow_detail, backend, conf): super(ActionEngine, self).__init__(flow, flow_detail, backend, conf) @@ -59,7 +61,9 @@ class ActionEngine(base.EngineBase): self._state_lock = threading.RLock() self.notifier = misc.TransitionNotifier() self.task_notifier = misc.TransitionNotifier() + self._task_executor = self._task_executor_cls() self.task_action = self._task_action_cls(self.storage, + self._task_executor, self.task_notifier) def _revert(self, current_failure=None): @@ -106,10 +110,14 @@ class ActionEngine(base.EngineBase): missing = self._flow.requires - external_provides if missing: raise exc.MissingDependencies(self._flow, sorted(missing)) - if self.storage.has_failures(): - self._revert() - else: - self._run() + self._task_executor.start() + try: + if self.storage.has_failures(): + self._revert() + else: + self._run() + finally: + self._task_executor.stop() def _run(self): self._change_state(states.RUNNING) diff --git a/taskflow/engines/action_engine/executor.py b/taskflow/engines/action_engine/executor.py new file mode 100644 index 00000000..3526bfc0 --- /dev/null +++ b/taskflow/engines/action_engine/executor.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2013 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import abc +import contextlib +import six + +from concurrent import futures + +from taskflow.utils import misc +from taskflow.utils import threading_utils + + +@contextlib.contextmanager +def _autobind(task, bind_name, bind_func, **kwargs): + try: + task.bind(bind_name, bind_func, **kwargs) + yield task + finally: + task.unbind(bind_name, bind_func) + + +def _noop(*args, **kwargs): + pass + + +def _execute_task(task, arguments, progress_callback): + with _autobind(task, 'update_progress', progress_callback): + try: + result = task.execute(**arguments) + except Exception: + # NOTE(imelnikov): wrap current exception with Failure + # object and return it + result = misc.Failure() + return task, 'executed', result + + +def _revert_task(task, arguments, result, failures, progress_callback): + kwargs = arguments.copy() + kwargs['result'] = result + kwargs['flow_failures'] = failures + with _autobind(task, 'update_progress', progress_callback): + try: + result = task.revert(**kwargs) + except Exception: + # NOTE(imelnikov): wrap current exception with Failure + # object and return it + result = misc.Failure() + return task, 'reverted', result + + +@six.add_metaclass(abc.ABCMeta) +class TaskExecutorBase(object): + """Executes and reverts tasks. + + This class takes task and its arguments and executes or reverts it. + It encapsulates knowledge on how task should be executed or reverted: + right now, on separate thread, on another machine, etc. + """ + + @abc.abstractmethod + def execute_task(self, task, arguments, progress_callback=_noop): + """Schedules task execution.""" + + @abc.abstractmethod + def revert_task(self, task, arguments, result, failures, + progress_callback=_noop): + """Schedules task reversion""" + + @abc.abstractmethod + def wait_for_any(self, fs): + """Wait for futures returned by this executor to complete""" + + def start(self): + """Prepare to execute tasks""" + pass + + def stop(self): + """Finalize task executor""" + pass + + +class SerialTaskExecutor(TaskExecutorBase): + """Execute task one after another.""" + + @staticmethod + def _completed_future(result): + future = futures.Future() + future.set_result(result) + return future + + def execute_task(self, task, arguments, progress_callback=_noop): + return self._completed_future( + _execute_task(task, arguments, progress_callback)) + + def revert_task(self, task, arguments, result, failures, + progress_callback=_noop): + return self._completed_future( + _revert_task(task, arguments, result, + failures, progress_callback)) + + def wait_for_any(self, fs): + # NOTE(imelnikov): this executor returns only done futures + return fs, [] + + +class ParallelTaskExecutor(TaskExecutorBase): + """Executes tasks in parallel. + + Submits tasks to executor which should provide interface similar + to concurrent.Futures.Executor. + """ + + def __init__(self, executor=None): + self._executor = executor + self._own_executor = executor is None + + def execute_task(self, task, arguments, progress_callback=_noop): + return self._executor.submit( + _execute_task, task, arguments, progress_callback) + + def revert_task(self, task, arguments, result, failures, + progress_callback=_noop): + return self._executor.submit( + _revert_task, task, + arguments, result, failures, progress_callback) + + def wait_for_any(self, fs): + return futures.wait(fs, return_when=futures.FIRST_COMPLETED) + + def start(self): + if self._own_executor: + thread_count = threading_utils.get_optimal_thread_count() + self._executor = futures.ThreadPoolExecutor(thread_count) + + def stop(self): + if self._own_executor: + self._executor.shutdown(wait=True) + self._executor = None diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index 2ff14341..030bff81 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -16,10 +16,8 @@ # License for the specific language governing permissions and limitations # under the License. -import contextlib import logging -from taskflow.openstack.common import excutils from taskflow import states from taskflow.utils import misc @@ -28,18 +26,10 @@ LOG = logging.getLogger(__name__) SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE) -@contextlib.contextmanager -def _autobind(task, bind_name, bind_func, **kwargs): - try: - task.bind(bind_name, bind_func, **kwargs) - yield task - finally: - task.unbind(bind_name, bind_func) - - class TaskAction(object): - def __init__(self, storage, notifier): + def __init__(self, storage, task_executor, notifier): self._storage = storage + self._task_executor = task_executor self._notifier = notifier def _change_state(self, task, state, result=None, progress=None): @@ -75,27 +65,29 @@ class TaskAction(object): def execute(self, task): if not self._change_state(task, states.RUNNING, progress=0.0): return - with _autobind(task, 'update_progress', self._on_update_progress): - try: - kwargs = self._storage.fetch_mapped_args(task.rebind) - result = task.execute(**kwargs) - except Exception: - failure = misc.Failure() - self._change_state(task, states.FAILURE, result=failure) - failure.reraise() + kwargs = self._storage.fetch_mapped_args(task.rebind) + future = self._task_executor.execute_task(task, kwargs, + self._on_update_progress) + self._task_executor.wait_for_any(future) + _task, _event, result = future.result() + if isinstance(result, misc.Failure): + self._change_state(task, states.FAILURE, result=result) + result.reraise() self._change_state(task, states.SUCCESS, result=result, progress=1.0) def revert(self, task): if not self._change_state(task, states.REVERTING, progress=0.0): return - with _autobind(task, 'update_progress', self._on_update_progress): - kwargs = self._storage.fetch_mapped_args(task.rebind) - kwargs['result'] = self._storage.get(task.name) - kwargs['flow_failures'] = self._storage.get_failures() - try: - task.revert(**kwargs) - except Exception: - with excutils.save_and_reraise_exception(): - self._change_state(task, states.FAILURE) + kwargs = self._storage.fetch_mapped_args(task.rebind) + task_result = self._storage.get(task.name) + failures = self._storage.get_failures() + future = self._task_executor.revert_task(task, kwargs, + task_result, failures, + self._on_update_progress) + self._task_executor.wait_for_any(future) + _task, _event, result = future.result() + if isinstance(result, misc.Failure): + self._change_state(task, states.FAILURE) + result.reraise() self._change_state(task, states.REVERTED, progress=1.0)