diff --git a/taskflow/engines/action_engine/retry_action.py b/taskflow/engines/action_engine/retry_action.py index a1ca3abb..afdfb456 100644 --- a/taskflow/engines/action_engine/retry_action.py +++ b/taskflow/engines/action_engine/retry_action.py @@ -17,7 +17,6 @@ import logging from taskflow.engines.action_engine import executor as ex -from taskflow import exceptions from taskflow import states from taskflow.utils import async_utils from taskflow.utils import misc @@ -39,27 +38,25 @@ class RetryAction(object): return kwargs def change_state(self, retry, state, result=None): - old_state = self._storage.get_atom_state(retry.name) - if old_state == state: - return state != states.PENDING if state in SAVE_RESULT_STATES: self._storage.save(retry.name, result, state) elif state == states.REVERTED: self._storage.cleanup_retry_history(retry.name, state) else: + old_state = self._storage.get_atom_state(retry.name) + if state == old_state: + # NOTE(imelnikov): nothing really changed, so we should not + # write anything to storage and run notifications + return self._storage.set_atom_state(retry.name, state) retry_uuid = self._storage.get_atom_uuid(retry.name) details = dict(retry_name=retry.name, retry_uuid=retry_uuid, result=result) self._notifier.notify(state, details) - return True def execute(self, retry): - if not self.change_state(retry, states.RUNNING): - raise exceptions.InvalidState("Retry controller %s is in invalid " - "state and can't be executed" % - retry.name) + self.change_state(retry, states.RUNNING) kwargs = self._get_retry_args(retry) try: result = retry.execute(**kwargs) @@ -71,10 +68,7 @@ class RetryAction(object): return async_utils.make_completed_future((retry, ex.EXECUTED, result)) def revert(self, retry): - if not self.change_state(retry, states.REVERTING): - raise exceptions.InvalidState("Retry controller %s is in invalid " - "state and can't be reverted" % - retry.name) + self.change_state(retry, states.REVERTING) kwargs = self._get_retry_args(retry) kwargs['flow_failures'] = self._storage.get_failures() try: diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index c0d1daa5..a07ded79 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -16,7 +16,6 @@ import logging -from taskflow import exceptions from taskflow import states from taskflow.utils import misc @@ -32,10 +31,30 @@ class TaskAction(object): self._task_executor = task_executor self._notifier = notifier - def change_state(self, task, state, result=None, progress=None): + def _is_identity_transition(self, state, task, progress): + if state in SAVE_RESULT_STATES: + # saving result is never identity transition + return False old_state = self._storage.get_atom_state(task.name) - if old_state == state: - return state != states.PENDING + if state != old_state: + # changing state is not identity transition by definition + return False + # NOTE(imelnikov): last thing to check is that the progress has + # changed, which means progress is not None and is different from + # what is stored in the database. + if progress is None: + return False + old_progress = self._storage.get_task_progress(task.name) + if old_progress != progress: + return False + return True + + def change_state(self, task, state, result=None, progress=None): + if self._is_identity_transition(state, task, progress): + # NOTE(imelnikov): ignore identity transitions in order + # to avoid extra write to storage backend and, what's + # more important, extra notifications + return if state in SAVE_RESULT_STATES: self._storage.save(task.name, result, state) else: @@ -49,7 +68,6 @@ class TaskAction(object): self._notifier.notify(state, details) if progress is not None: task.update_progress(progress) - return True def _on_update_progress(self, task, event_data, progress, **kwargs): """Should be called when task updates its progress.""" @@ -62,9 +80,7 @@ class TaskAction(object): task, progress) def schedule_execution(self, task): - if not self.change_state(task, states.RUNNING, progress=0.0): - raise exceptions.InvalidState("Task %s is in invalid state and" - " can't be executed" % task.name) + self.change_state(task, states.RUNNING, progress=0.0) kwargs = self._storage.fetch_mapped_args(task.rebind, atom_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name) @@ -79,9 +95,7 @@ class TaskAction(object): result=result, progress=1.0) def schedule_reversion(self, task): - if not self.change_state(task, states.REVERTING, progress=0.0): - raise exceptions.InvalidState("Task %s is in invalid state and" - " can't be reverted" % task.name) + self.change_state(task, states.REVERTING, progress=0.0) kwargs = self._storage.fetch_mapped_args(task.rebind, atom_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name)