diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index 4944debc..c7421f7e 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -16,6 +16,9 @@ # License for the specific language governing permissions and limitations # under the License. +import contextlib +import logging + from taskflow.engines.action_engine import base_action as base from taskflow import exceptions from taskflow.openstack.common import excutils @@ -23,6 +26,22 @@ from taskflow.openstack.common import uuidutils from taskflow import states from taskflow.utils import misc +LOG = logging.getLogger(__name__) + +RESET_TASK_STATES = (states.PENDING,) +SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE) +ALREADY_FINISHED_STATES = (states.SUCCESS,) +NEVER_RAN_STATES = (states.PENDING,) + + +@contextlib.contextmanager +def _autobind(task, bind_name, bind_func, **kwargs): + try: + task.bind(bind_name, bind_func, **kwargs) + yield task + finally: + task.unbind(bind_name, bind_func) + class TaskAction(base.Action): @@ -49,65 +68,66 @@ class TaskAction(base.Action): def uuid(self): return self._id - def _change_state(self, engine, state): - """Check and update state of task.""" - if state in (states.RUNNING, states.REVERTING, states.PENDING): - self._task.update_progress(0.0) - elif state in (states.SUCCESS, states.REVERTED): - self._task.update_progress(1.0) - engine.storage.set_task_state(self.uuid, state) - engine.on_task_state_change(self, state) - - def _update_result(self, engine, state, result=None): + def _change_state(self, engine, state, result=None, progress=None): """Update result and change state.""" - if state == states.PENDING: + if state in RESET_TASK_STATES: engine.storage.reset(self.uuid) - else: + if state in SAVE_RESULT_STATES: engine.storage.save(self.uuid, result, state) - engine.on_task_state_change(self, state, result) + else: + engine.storage.set_task_state(self.uuid, state) + if progress is not None: + engine.storage.set_task_progress(self.uuid, progress) + engine.on_task_state_change(self, state, result=result) - def _update_progress(self, task, event_data, progress, **kwargs): + def _on_update_progress(self, task, event_data, progress, **kwargs): """Update task progress value that stored in engine.""" - engine = event_data['engine'] - engine.storage.set_task_progress(self.uuid, progress, **kwargs) + try: + engine = event_data['engine'] + engine.storage.set_task_progress(self.uuid, progress, **kwargs) + except Exception: + # Update progress callbacks should never fail, so capture and log + # the emitted exception instead of raising it. + LOG.exception("Failed setting task progress for %s (%s) to %0.3f", + task, self.uuid, progress) + + def _force_state(self, engine, state, progress, result=None): + self._change_state(engine, state, result=result, progress=progress) + self._task.update_progress(progress) def execute(self, engine): - if engine.storage.get_task_state(self.uuid) == states.SUCCESS: + if engine.storage.get_task_state(self.uuid) in ALREADY_FINISHED_STATES: + # Skip tasks that already finished. return - self._task.bind('update_progress', self._update_progress, - engine=engine) - try: - kwargs = engine.storage.fetch_mapped_args(self._args_mapping) - self._change_state(engine, states.RUNNING) - result = self._task.execute(**kwargs) - except Exception: - failure = misc.Failure() - self._update_result(engine, states.FAILURE, failure) - failure.reraise() - else: - self._update_result(engine, states.SUCCESS, result) - finally: - self._task.unbind('update_progress', self._update_progress) + self._force_state(engine, states.RUNNING, 0.0) + with _autobind(self._task, + 'update_progress', self._on_update_progress, + engine=engine): + try: + kwargs = engine.storage.fetch_mapped_args(self._args_mapping) + result = self._task.execute(**kwargs) + except Exception: + failure = misc.Failure() + self._change_state(engine, states.FAILURE, result=failure) + failure.reraise() + self._force_state(engine, states.SUCCESS, 1.0, result=result) def revert(self, engine): - if engine.storage.get_task_state(self.uuid) == states.PENDING: + if engine.storage.get_task_state(self.uuid) in NEVER_RAN_STATES: # NOTE(imelnikov): in all the other states, the task - # execution was at least attempted, so we should give - # task a chance for cleanup + # execution was at least attempted, so we should give + # task a chance for cleanup return - self._task.bind('update_progress', self._update_progress, - engine=engine) - kwargs = engine.storage.fetch_mapped_args(self._args_mapping) - - self._change_state(engine, states.REVERTING) - try: - self._task.revert(result=engine.storage.get(self._id), - **kwargs) - self._change_state(engine, states.REVERTED) - except Exception: - with excutils.save_and_reraise_exception(): - self._change_state(engine, states.FAILURE) - else: - self._update_result(engine, states.PENDING) - finally: - self._task.unbind('update_progress', self._update_progress) + self._force_state(engine, states.REVERTING, 0.0) + with _autobind(self._task, + 'update_progress', self._on_update_progress, + engine=engine): + kwargs = engine.storage.fetch_mapped_args(self._args_mapping) + try: + self._task.revert(result=engine.storage.get(self._id), + **kwargs) + except Exception: + with excutils.save_and_reraise_exception(): + self._change_state(engine, states.FAILURE) + self._force_state(engine, states.REVERTED, 1.0) + self._force_state(engine, states.PENDING, 0.0) diff --git a/taskflow/task.py b/taskflow/task.py index 1796ba9b..d20bf121 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -18,11 +18,15 @@ # under the License. import abc +import logging + import six from taskflow.utils import misc from taskflow.utils import reflection +LOG = logging.getLogger(__name__) + def _save_as_to_mapping(save_as): """Convert save_as to mapping name => index @@ -118,8 +122,7 @@ class BaseTask(object): # can be useful in resuming older versions of tasks. Standard # major, minor version semantics apply. self.version = (1, 0) - # List of callback functions to invoke when progress updated. - self._on_update_progress_notify = [] + # Map of events => callback functions to invoke on task events. self._events_listeners = {} @property @@ -159,6 +162,12 @@ class BaseTask(object): :param progress: task progress float value between 0 and 1 :param kwargs: task specific progress information """ + if progress > 1.0: + LOG.warn("Progress must be <= 1.0, clamping to upper bound") + progress = 1.0 + if progress < 0.0: + LOG.warn("Progress must be >= 0.0, clamping to lower bound") + progress = 0.0 self._trigger('update_progress', progress, **kwargs) def _trigger(self, event, *args, **kwargs): @@ -166,7 +175,11 @@ class BaseTask(object): if event in self._events_listeners: for handler in self._events_listeners[event]: event_data = self._events_listeners[event][handler] - handler(self, event_data, *args, **kwargs) + try: + handler(self, event_data, *args, **kwargs) + except Exception: + LOG.exception("Failed calling `%s` on event '%s'", + reflection.get_callable_name(handler), event) def bind(self, event, handler, **kwargs): """Attach a handler to an event for the task. diff --git a/taskflow/tests/unit/test_action_engine.py b/taskflow/tests/unit/test_action_engine.py index 0ee91448..900df5f0 100644 --- a/taskflow/tests/unit/test_action_engine.py +++ b/taskflow/tests/unit/test_action_engine.py @@ -120,7 +120,7 @@ class AutoSuspendingTask(TestTask): engine.suspend() return result - def revert(self, egnine, result): + def revert(self, engine, result): super(AutoSuspendingTask, self).revert(**{'result': result}) diff --git a/taskflow/tests/unit/test_progress.py b/taskflow/tests/unit/test_progress.py new file mode 100644 index 00000000..56d479d8 --- /dev/null +++ b/taskflow/tests/unit/test_progress.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib + +from taskflow import task +from taskflow import test + +from taskflow.engines.action_engine import engine as eng +from taskflow.patterns import linear_flow as lf +from taskflow.persistence.backends import impl_memory +from taskflow.utils import persistence_utils as p_utils + + +class ProgressTask(task.Task): + def __init__(self, name, segments): + super(ProgressTask, self).__init__(name=name) + self._segments = segments + + def execute(self): + if self._segments <= 0: + return + for i in range(1, self._segments): + progress = float(i) / self._segments + self.update_progress(progress) + + +class TestProgress(test.TestCase): + def _make_engine(self, flow): + e = eng.SingleThreadedActionEngine(flow) + e.compile() + return e + + def tearDown(self): + super(TestProgress, self).tearDown() + with contextlib.closing(impl_memory.MemoryBackend({})) as be: + with contextlib.closing(be.get_connection()) as conn: + conn.clear_all() + + def test_sanity_progress(self): + fired_events = [] + + def notify_me(task, event_data, progress): + fired_events.append(progress) + + ev_count = 5 + t = ProgressTask("test", ev_count) + t.bind('update_progress', notify_me) + flo = lf.Flow("test") + flo.add(t) + e = self._make_engine(flo) + e.run() + self.assertEquals(ev_count + 1, len(fired_events)) + self.assertEquals(1.0, fired_events[-1]) + self.assertEquals(0.0, fired_events[0]) + + def test_no_segments_progress(self): + fired_events = [] + + def notify_me(task, event_data, progress): + fired_events.append(progress) + + t = ProgressTask("test", 0) + t.bind('update_progress', notify_me) + flo = lf.Flow("test") + flo.add(t) + e = self._make_engine(flo) + e.run() + # 0.0 and 1.0 should be automatically fired + self.assertEquals(2, len(fired_events)) + self.assertEquals(1.0, fired_events[-1]) + self.assertEquals(0.0, fired_events[0]) + + def test_storage_progress(self): + with contextlib.closing(impl_memory.MemoryBackend({})) as be: + flo = lf.Flow("test") + flo.add(ProgressTask("test", 3)) + b, fd = p_utils.temporary_flow_detail(be) + e = eng.SingleThreadedActionEngine(flo, + book=b, flow_detail=fd, + backend=be) + e.run() + t_uuid = e.storage.get_uuid_by_name("test") + end_progress = e.storage.get_task_progress(t_uuid) + self.assertEquals(1.0, end_progress) + td = fd.find(t_uuid) + self.assertEquals({'progress': 1.0}, td.meta) + + def test_dual_storage_progress(self): + fired_events = [] + + def notify_me(task, event_data, progress): + fired_events.append(progress) + + with contextlib.closing(impl_memory.MemoryBackend({})) as be: + t = ProgressTask("test", 5) + t.bind('update_progress', notify_me) + flo = lf.Flow("test") + flo.add(t) + b, fd = p_utils.temporary_flow_detail(be) + e = eng.SingleThreadedActionEngine(flo, + book=b, flow_detail=fd, + backend=be) + e.run() + + t_uuid = e.storage.get_uuid_by_name("test") + end_progress = e.storage.get_task_progress(t_uuid) + self.assertEquals(1.0, end_progress) + td = fd.find(t_uuid) + self.assertEquals({'progress': 1.0}, td.meta) + self.assertEquals(6, len(fired_events))