diff --git a/mistral/engine/actions.py b/mistral/engine/actions.py index b154ea1d..e36b24e1 100644 --- a/mistral/engine/actions.py +++ b/mistral/engine/actions.py @@ -75,7 +75,7 @@ class Action(object): self.action_ex.output = {'result': msg} @abc.abstractmethod - def schedule(self, input_dict, target, index=0, desc=''): + def schedule(self, input_dict, target, index=0, desc='', safe_rerun=False): """Schedule action run. This method is needed to schedule action run so its result can @@ -88,11 +88,14 @@ class Action(object): :param target: Target (group of action executors). :param index: Action execution index. Makes sense for some types. :param desc: Action execution description. + :param safe_rerun: If true, action would be re-run if executor dies + during execution. """ raise NotImplementedError @abc.abstractmethod - def run(self, input_dict, target, index=0, desc='', save=True): + def run(self, input_dict, target, index=0, desc='', save=True, + safe_rerun=False): """Immediately run action. This method runs method w/o scheduling its run for a later time. @@ -104,6 +107,8 @@ class Action(object): :param index: Action execution index. Makes sense for some types. :param desc: Action execution description. :param save: True if action execution object needs to be saved. + :param safe_rerun: If true, action would be re-run if executor dies + during execution. :return: Action output. """ raise NotImplementedError @@ -196,7 +201,7 @@ class PythonAction(Action): self._log_result(prev_state, result) @profiler.trace('action-schedule') - def schedule(self, input_dict, target, index=0, desc=''): + def schedule(self, input_dict, target, index=0, desc='', safe_rerun=False): assert not self.action_ex # Assign the action execution ID here to minimize database calls. @@ -209,7 +214,7 @@ class PythonAction(Action): self._create_action_execution( self._prepare_input(input_dict), - self._prepare_runtime_context(index), + self._prepare_runtime_context(index, safe_rerun), desc=desc, action_ex_id=action_ex_id ) @@ -223,11 +228,12 @@ class PythonAction(Action): ) @profiler.trace('action-run') - def run(self, input_dict, target, index=0, desc='', save=True): + def run(self, input_dict, target, index=0, desc='', save=True, + safe_rerun=False): assert not self.action_ex input_dict = self._prepare_input(input_dict) - runtime_ctx = self._prepare_runtime_context(index) + runtime_ctx = self._prepare_runtime_context(index, safe_rerun) # Assign the action execution ID here to minimize database calls. # Otherwise, the input property of the action execution DB object needs @@ -251,7 +257,8 @@ class PythonAction(Action): self.action_def.attributes or {}, input_dict, target, - async=False + async=False, + safe_rerun=safe_rerun ) return self._prepare_output(result) @@ -287,12 +294,13 @@ class PythonAction(Action): """ return _get_action_output(result) if result else None - def _prepare_runtime_context(self, index): + def _prepare_runtime_context(self, index, safe_rerun): """Template method to prepare action runtime context. - Python action inserts index into runtime context. + Python action inserts index into runtime context and information if + given action is safe_rerun. """ - return {'index': index} + return {'index': index, 'safe_rerun': safe_rerun} def _insert_action_context(self, action_ex_id, input_dict, save=True): """Template method to prepare action context. @@ -378,8 +386,11 @@ class AdHocAction(PythonAction): return _get_action_output(result) if result else None - def _prepare_runtime_context(self, index): - ctx = super(AdHocAction, self)._prepare_runtime_context(index) + def _prepare_runtime_context(self, index, safe_rerun): + ctx = super(AdHocAction, self)._prepare_runtime_context( + index, + safe_rerun + ) # Insert special field into runtime context so that we track # a relationship between python action and adhoc action. @@ -432,7 +443,7 @@ class WorkflowAction(Action): pass @profiler.trace('action-schedule') - def schedule(self, input_dict, target, index=0, desc=''): + def schedule(self, input_dict, target, index=0, desc='', safe_rerun=False): assert not self.action_ex parent_wf_ex = self.task_ex.workflow_execution @@ -471,7 +482,8 @@ class WorkflowAction(Action): ) @profiler.trace('action-run') - def run(self, input_dict, target, index=0, desc='', save=True): + def run(self, input_dict, target, index=0, desc='', save=True, + safe_rerun=True): raise NotImplemented('Does not apply to this WorkflowAction.') def is_sync(self, input_dict): @@ -492,7 +504,8 @@ def _run_existing_action(action_ex_id, target): action_def.action_class, action_def.attributes or {}, action_ex.input, - target + target, + safe_rerun=action_ex.runtime_context.get('safe_rerun', False) ) return _get_action_output(result) if result else None diff --git a/mistral/engine/rpc_backend/rpc.py b/mistral/engine/rpc_backend/rpc.py index 79c5e50b..bcfb1242 100644 --- a/mistral/engine/rpc_backend/rpc.py +++ b/mistral/engine/rpc_backend/rpc.py @@ -518,7 +518,7 @@ class ExecutorClient(base.Executor): self._client = get_rpc_client_driver()(rpc_conf_dict) def run_action(self, action_ex_id, action_class_str, attributes, - action_params, target=None, async=True): + action_params, target=None, async=True, safe_rerun=False): """Sends a request to run action to executor. :param action_ex_id: Action execution id. @@ -528,6 +528,8 @@ class ExecutorClient(base.Executor): :param target: Target (group of action executors). :param async: If True, run action in asynchronous mode (w/o waiting for completion). + :param safe_rerun: If true, action would be re-run if executor dies + during execution. :return: Action result. """ @@ -535,7 +537,7 @@ class ExecutorClient(base.Executor): 'action_ex_id': action_ex_id, 'action_class_str': action_class_str, 'attributes': attributes, - 'params': action_params + 'params': action_params, } rpc_client_method = (self._client.async_call diff --git a/mistral/engine/tasks.py b/mistral/engine/tasks.py index bcd54a48..76d8e079 100644 --- a/mistral/engine/tasks.py +++ b/mistral/engine/tasks.py @@ -311,7 +311,11 @@ class RegularTask(Task): action.validate_input(input_dict) - action.schedule(input_dict, target) + action.schedule( + input_dict, + target, + safe_rerun=self.task_spec.get_safe_rerun() + ) def _get_target(self, input_dict): return expr.evaluate_recursively( @@ -399,7 +403,12 @@ class WithItemsTask(RegularTask): action = self._build_action() - action.schedule(input_dict, target, index=idx) + action.schedule( + input_dict, + target, + index=idx, + safe_rerun=self.task_spec.get_safe_rerun() + ) def _get_with_items_input(self): """Calculate input array for separating each action input. diff --git a/mistral/tests/unit/engine/test_environment.py b/mistral/tests/unit/engine/test_environment.py index 04529a1c..c160cc11 100644 --- a/mistral/tests/unit/engine/test_environment.py +++ b/mistral/tests/unit/engine/test_environment.py @@ -78,7 +78,7 @@ workflows: def _run_at_target(action_ex_id, action_class_str, attributes, - action_params, target=None, async=True): + action_params, target=None, async=True, safe_rerun=False): # We'll just call executor directly for testing purposes. executor = default_executor.DefaultExecutor(rpc.get_engine_client()) @@ -172,7 +172,8 @@ class EnvironmentTest(base.EngineTestCase): 'mistral.actions.std_actions.EchoAction', {}, a_ex.input, - TARGET + TARGET, + safe_rerun=False ) def test_subworkflow_env_task_input(self):