diff --git a/mistral/engine1/policies.py b/mistral/engine1/policies.py index af0257db..c290831a 100644 --- a/mistral/engine1/policies.py +++ b/mistral/engine1/policies.py @@ -24,6 +24,7 @@ from mistral.workflow import utils _ENGINE_CLIENT_PATH = 'mistral.engine1.rpc.get_engine_client' +_RUN_TASK_EXECUTION_PATH = 'mistral.engine1.task_handler.run_task_execution' def _log_task_delay(task_ex, delay_sec): @@ -156,10 +157,6 @@ class WaitBeforePolicy(base.TaskPolicy): def before_task_start(self, task_ex, task_spec): super(WaitBeforePolicy, self).before_task_start(task_ex, task_spec) - # TODO(rakhmerov): This policy needs to be fixed. - if True: - return - context_key = 'wait_before_policy' runtime_context = _ensure_context_has_key( @@ -190,10 +187,10 @@ class WaitBeforePolicy(base.TaskPolicy): task_ex.state = states.DELAYED scheduler.schedule_call( - _ENGINE_CLIENT_PATH, - 'run_task', + None, + _RUN_TASK_EXECUTION_PATH, self.delay, - task_id=task_ex.id + task_ex_id=task_ex.id, ) diff --git a/mistral/engine1/task_handler.py b/mistral/engine1/task_handler.py index a415753d..2fca4c27 100644 --- a/mistral/engine1/task_handler.py +++ b/mistral/engine1/task_handler.py @@ -36,6 +36,31 @@ from mistral.workflow import with_items LOG = logging.getLogger(__name__) +def run_task_execution(task_ex_id): + """This function runs existent task execution. + + It is needed mostly by scheduler. + """ + task_ex = db_api.get_task_execution(task_ex_id) + task_spec = spec_parser.get_task_spec(task_ex.spec) + wf_spec = spec_parser.get_workflow_spec( + db_api.get_workflow_definition(task_ex.workflow_name).spec + ) + + # Explicitly change task state to RUNNING. + task_ex.state = states.RUNNING + + _run_task_execution(task_ex, task_spec, wf_spec) + + +def _run_task_execution(task_ex, task_spec, wf_spec): + input_dicts = _get_input_dictionaries( + wf_spec, task_ex, task_spec, task_ex.in_context + ) + for input_d in input_dicts: + _run_action_or_workflow(task_ex, task_spec, input_d) + + def run_task(wf_cmd): """Runs a task.""" ctx = wf_cmd.ctx @@ -59,8 +84,7 @@ def run_task(wf_cmd): if task_ex.state != states.RUNNING: return - for input_d in _get_input_dictionaries(wf_spec, task_ex, task_spec, ctx): - _run_action_or_workflow(task_ex, task_spec, input_d) + _run_task_execution(task_ex, task_spec, wf_spec) def on_action_complete(action_ex, result): diff --git a/mistral/tests/unit/engine1/test_policies.py b/mistral/tests/unit/engine1/test_policies.py index 850da283..e660549b 100644 --- a/mistral/tests/unit/engine1/test_policies.py +++ b/mistral/tests/unit/engine1/test_policies.py @@ -253,7 +253,6 @@ class PoliciesTest(base.EngineTestCase): thread_group = scheduler.setup() self.addCleanup(thread_group.stop) - @testtools.skip("Fix policies.") def test_build_policies(self): arr = policies.build_policies( self.task_spec.get_policies(), @@ -329,7 +328,6 @@ class PoliciesTest(base.EngineTestCase): self.assertIsInstance(p, policies.TimeoutPolicy) - @testtools.skip("Fix 'wait-before' policy.") def test_wait_before_policy(self): wb_service.create_workbook_v2(WAIT_BEFORE_WB) @@ -348,7 +346,6 @@ class PoliciesTest(base.EngineTestCase): self._await(lambda: self.is_execution_success(wf_ex.id)) - @testtools.skip("Fix 'wait-before' policy.") def test_wait_before_policy_from_var(self): wb_service.create_workbook_v2(WAIT_BEFORE_FROM_VAR) diff --git a/mistral/utils/serializer.py b/mistral/utils/serializers.py similarity index 100% rename from mistral/utils/serializer.py rename to mistral/utils/serializers.py diff --git a/mistral/workflow/utils.py b/mistral/workflow/utils.py index d8d5ca80..0b10e353 100644 --- a/mistral/workflow/utils.py +++ b/mistral/workflow/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mistral.utils import serializer +from mistral.utils import serializers from mistral.workbook.v2 import tasks as v2_tasks_spec from mistral.workflow import states @@ -38,7 +38,7 @@ class Result(object): return self.data == other.data and self.error == other.error -class ResultSerializer(serializer.Serializer): +class ResultSerializer(serializers.Serializer): def serialize(self, entity): return {'data': entity.data, 'error': entity.error}