From 9da521b841430ba10f22c03386273e562a21d13a Mon Sep 17 00:00:00 2001 From: Renat Akhmerov Date: Wed, 26 Nov 2014 15:43:41 +0600 Subject: [PATCH] Redesigning engine to move all remote calls from transactions * Fixes engine race condition between start_workflow and on_task_result methods * Engine commands now have local and remote parts (in fact, "in tx" and "non tx") Closes-Bug: #1395679 Change-Id: Icd4aa1a546893b815c01bea23880cde139df2d1b --- mistral/engine1/commands.py | 83 +++++-- mistral/engine1/default_engine.py | 44 +++- mistral/services/action_manager.py | 27 ++- mistral/tests/unit/engine1/test_dataflow.py | 26 +++ .../unit/engine1/test_direct_workflow.py | 2 + .../tests/unit/engine1/test_race_condition.py | 208 ++++++++++++++++++ mistral/utils/inspect_utils.py | 1 + mistral/workflow/base.py | 6 +- mistral/workflow/direct_workflow.py | 2 +- 9 files changed, 364 insertions(+), 35 deletions(-) create mode 100644 mistral/tests/unit/engine1/test_race_condition.py diff --git a/mistral/engine1/commands.py b/mistral/engine1/commands.py index 7b6baf5d..fe516e45 100644 --- a/mistral/engine1/commands.py +++ b/mistral/engine1/commands.py @@ -44,9 +44,13 @@ def _log_execution_state_change(name, from_state, to_state): class EngineCommand(object): """Engine command interface.""" - @abc.abstractmethod - def run(self, exec_db, wf_handler, cause_task_db=None): - """Runs the command. + def run_local(self, exec_db, wf_handler, cause_task_db=None): + """Runs local part of the command. + + "Local" means that the code can be performed within a scope + of an opened DB transaction. For example, for all commands + that simply change a state of execution (e.g. depending on + some conditions) it's enough to implement only this method. :param exec_db: Workflow execution DB object. :param wf_handler: Workflow handler currently being used. @@ -54,12 +58,32 @@ class EngineCommand(object): :return False if engine should stop further command processing, True otherwise. """ - raise NotImplementedError + return True + + def run_remote(self, exec_db, wf_handler, cause_task_db=None): + """Runs remote part of the command. + + "Remote" means that the code cannot be performed within a scope + of an opened DB transaction. All commands that deal with remote + invocations should implement this method. However, they may also + need to implement "run_local" if they need to do something with + DB state of execution and/or tasks. + + :param exec_db: Workflow execution DB object. + :param wf_handler: Workflow handler currently being used. + :param cause_task_db: Task that caused the command to run. + :return False if engine should stop further command processing, + True otherwise. 
+ """ + return True class Noop(EngineCommand): """No-op command.""" - def run(self, exec_db, wf_handler, cause_task_db=None): + def run_local(self, exec_db, wf_handler, cause_task_db=None): + pass + + def run_remote(self, exec_db, wf_handler, cause_task_db=None): pass @@ -68,12 +92,14 @@ class RunTask(EngineCommand): self.task_spec = task_spec self.task_db = task_db - def run(self, exec_db, wf_handler, cause_task_db=None): + if task_db: + self.exec_db = task_db.execution + + def run_local(self, exec_db, wf_handler, cause_task_db=None): LOG.debug('Running workflow task: %s' % self.task_spec) self._prepare_task(exec_db, wf_handler, cause_task_db) self._before_task_start(wf_handler.wf_spec) - self._run_task() return True @@ -82,6 +108,7 @@ class RunTask(EngineCommand): return self.task_db = self._create_db_task(exec_db) + self.exec_db = self.task_db.execution # Evaluate Data Flow properties ('input', 'in_context'). data_flow.prepare_db_task( @@ -111,6 +138,11 @@ class RunTask(EngineCommand): 'project_id': exec_db.project_id }) + def run_remote(self, exec_db, wf_handler, cause_task_db=None): + self._run_task() + + return True + def _run_task(self): # Policies could possibly change task state. if self.task_db.state != states.RUNNING: @@ -130,7 +162,7 @@ class RunTask(EngineCommand): self._run_workflow() def _run_action(self): - exec_db = self.task_db.execution + exec_db = self.exec_db wf_spec = spec_parser.get_workflow_spec(exec_db.wf_spec) action_spec_name = self.task_spec.get_action_name() @@ -178,8 +210,10 @@ class RunTask(EngineCommand): action_input.update(a_m.get_action_context(self.task_db)) for_each = self.task_spec.get_for_each() + if for_each: action_input_collection = self._calc_for_each_input(action_input) + for a_input in action_input_collection: rpc.get_executor_client().run_action( self.task_db.id, @@ -204,11 +238,13 @@ class RunTask(EngineCommand): targets ) - def _calc_for_each_input(self, action_input): + @staticmethod + def _calc_for_each_input(action_input): # In case of for-each iterate over action_input and send # each part of data to executor. # Calculate action input collection for separating input. 
action_input_collection = [] + for key, value in action_input.items(): for index, item in enumerate(value): iter_context = {key: item} @@ -221,7 +257,7 @@ class RunTask(EngineCommand): return action_input_collection def _run_workflow(self): - parent_exec_db = self.task_db.execution + parent_exec_db = self.exec_db parent_wf_spec = spec_parser.get_workflow_spec(parent_exec_db.wf_spec) wf_spec_name = self.task_spec.get_workflow_name() @@ -251,7 +287,7 @@ class RunTask(EngineCommand): class FailWorkflow(EngineCommand): - def run(self, exec_db, wf_handler, cause_task_db=None): + def run_local(self, exec_db, wf_handler, cause_task_db=None): _log_execution_state_change( exec_db.wf_name, exec_db.state, @@ -262,9 +298,12 @@ class FailWorkflow(EngineCommand): return False + def run_remote(self, exec_db, wf_handler, cause_task_db=None): + return False + class SucceedWorkflow(EngineCommand): - def run(self, exec_db, wf_handler, cause_task_db=None): + def run_local(self, exec_db, wf_handler, cause_task_db=None): _log_execution_state_change( exec_db.wf_name, exec_db.state, @@ -275,18 +314,32 @@ class SucceedWorkflow(EngineCommand): return False + def run_remote(self, exec_db, wf_handler, cause_task_db=None): + return False + class PauseWorkflow(EngineCommand): - def run(self, exec_db, wf_handler, cause_task_db=None): + def run_local(self, exec_db, wf_handler, cause_task_db=None): + _log_execution_state_change( + exec_db.wf_name, + exec_db.state, + states.PAUSED + ) wf_handler.pause_workflow() return False + def run_remote(self, exec_db, wf_handler, cause_task_db=None): + return False + class RollbackWorkflow(EngineCommand): - def run(self, exec_db, wf_handler, cause_task_db=None): - pass + def run_local(self, exec_db, wf_handler, cause_task_db=None): + return True + + def run_remote(self, exec_db, wf_handler, cause_task_db=None): + return True RESERVED_COMMANDS = { diff --git a/mistral/engine1/default_engine.py b/mistral/engine1/default_engine.py index 0d0d1f8f..0391c65e 100644 --- a/mistral/engine1/default_engine.py +++ b/mistral/engine1/default_engine.py @@ -69,7 +69,9 @@ class DefaultEngine(base.Engine): # Calculate commands to process next. cmds = wf_handler.start_workflow(**params) - self._run_commands(cmds, exec_db, wf_handler) + self._run_local_commands(cmds, exec_db, wf_handler) + + self._run_remote_commands(cmds, exec_db, wf_handler) return exec_db @@ -95,9 +97,16 @@ class DefaultEngine(base.Engine): # Calculate commands to process next. cmds = wf_handler.on_task_result(task_db, raw_result) - self._run_commands(cmds, exec_db, wf_handler, task_db) + self._run_local_commands( + cmds, + exec_db, + wf_handler, + task_db + ) - self._check_subworkflow_completion(exec_db) + self._run_remote_commands(cmds, exec_db, wf_handler) + + self._check_subworkflow_completion(exec_db) return task_db @@ -122,7 +131,9 @@ class DefaultEngine(base.Engine): # Calculate commands to process next. 
cmds = wf_handler.resume_workflow() - self._run_commands(cmds, exec_db, wf_handler) + self._run_local_commands(cmds, exec_db, wf_handler) + + self._run_remote_commands(cmds, exec_db, wf_handler) return exec_db @@ -131,13 +142,26 @@ class DefaultEngine(base.Engine): raise NotImplementedError @staticmethod - def _run_commands(cmds, exec_db, wf_handler, cause_task_db=None): + def _run_local_commands(cmds, exec_db, wf_handler, cause_task_db=None): if not cmds: return for cmd in cmds: - if not cmd.run(exec_db, wf_handler, cause_task_db): - break + if not cmd.run_local(exec_db, wf_handler, cause_task_db): + return False + + return True + + @staticmethod + def _run_remote_commands(cmds, exec_db, wf_handler, cause_task_db=None): + if not cmds: + return + + for cmd in cmds: + if not cmd.run_remote(exec_db, wf_handler, cause_task_db): + return False + + return True @staticmethod def _create_db_execution(wf_db, wf_spec, wf_input, params): @@ -180,7 +204,11 @@ class DefaultEngine(base.Engine): wf_handler = wfh_factory.create_workflow_handler(exec_db) - commands.RunTask(task_spec, task_db).run(exec_db, wf_handler) + cmd = commands.RunTask(task_spec, task_db)\ + + cmd.run_local(exec_db, wf_handler) + + cmd.run_remote(exec_db, wf_handler) def _check_subworkflow_completion(self, exec_db): if not exec_db.parent_task_id: diff --git a/mistral/services/action_manager.py b/mistral/services/action_manager.py index d296f81f..24394ff3 100644 --- a/mistral/services/action_manager.py +++ b/mistral/services/action_manager.py @@ -28,6 +28,8 @@ from mistral import utils from mistral.utils import inspect_utils as i_utils +# TODO(rakhmerov): Make methods more consistent and granular. + LOG = logging.getLogger(__name__) ACTIONS_PATH = '../resources/actions' @@ -48,11 +50,11 @@ def get_registered_actions(**kwargs): return db_api.get_actions(**kwargs) -def _register_action_in_db(name, action_class, attributes, - description=None, input_str=None): +def register_action_class(name, action_class_str, attributes, + description=None, input_str=None): values = { 'name': name, - 'action_class': action_class, + 'action_class': action_class_str, 'attributes': attributes, 'description': description, 'input': input_str, @@ -74,6 +76,7 @@ def _clear_system_action_db(): def sync_db(): _clear_system_action_db() + register_action_classes() register_standard_actions() @@ -90,7 +93,7 @@ def _register_dynamic_action_classes(): for action in actions: attrs = i_utils.get_public_fields(action['class']) - _register_action_in_db( + register_action_class( action['name'], action_class_str, attrs, @@ -114,9 +117,13 @@ def register_action_classes(): attrs = i_utils.get_public_fields(mgr[name].plugin) - _register_action_in_db(name, action_class_str, attrs, - description=description, - input_str=input_str) + register_action_class( + name, + action_class_str, + attrs, + description=description, + input_str=input_str + ) _register_dynamic_action_classes() @@ -134,8 +141,10 @@ def get_action_class(action_full_name): action_db = get_action_db(action_full_name) if action_db: - return action_factory.construct_action_class(action_db.action_class, - action_db.attributes) + return action_factory.construct_action_class( + action_db.action_class, + action_db.attributes + ) def get_action_context(task_db): diff --git a/mistral/tests/unit/engine1/test_dataflow.py b/mistral/tests/unit/engine1/test_dataflow.py index e93e11f8..b2da8e15 100644 --- a/mistral/tests/unit/engine1/test_dataflow.py +++ b/mistral/tests/unit/engine1/test_dataflow.py @@ -21,6 +21,7 @@ from 
mistral.services import workbooks as wb_service from mistral.tests.unit.engine1 import base LOG = logging.getLogger(__name__) + # Use the set_default method to set value otherwise in certain test cases # the change in value is not permanent. cfg.CONF.set_default('auth_enable', False, group='pecan') @@ -28,7 +29,9 @@ cfg.CONF.set_default('auth_enable', False, group='pecan') WORKBOOK = """ --- version: '2.0' + name: wb + workflows: wf1: type: direct @@ -72,10 +75,33 @@ class DataFlowEngineTest(base.EngineTestCase): self.assertEqual(states.SUCCESS, exec_db.state) tasks = exec_db.tasks + + task1 = self._assert_single_item(tasks, name='task1') + task2 = self._assert_single_item(tasks, name='task2') task3 = self._assert_single_item(tasks, name='task3') self.assertEqual(states.SUCCESS, task3.state) + self.assertDictEqual( + { + 'task': { + 'task1': {'hi': 'Hi,'}, + }, + 'hi': 'Hi,', + }, + task1.output + ) + + self.assertDictEqual( + { + 'task': { + 'task2': {'username': 'Morpheus'}, + }, + 'username': 'Morpheus', + }, + task2.output + ) + self.assertDictEqual( { 'task': { diff --git a/mistral/tests/unit/engine1/test_direct_workflow.py b/mistral/tests/unit/engine1/test_direct_workflow.py index f887d931..afa4323b 100644 --- a/mistral/tests/unit/engine1/test_direct_workflow.py +++ b/mistral/tests/unit/engine1/test_direct_workflow.py @@ -22,6 +22,7 @@ from mistral.tests.unit.engine1 import base # TODO(nmakhotkin) Need to write more tests. LOG = logging.getLogger(__name__) + # Use the set_default method to set value otherwise in certain test cases # the change in value is not permanent. cfg.CONF.set_default('auth_enable', False, group='pecan') @@ -89,6 +90,7 @@ class DirectWorkflowEngineTest(base.EngineTestCase): exec_db = db_api.get_execution(exec_db.id) tasks = exec_db.tasks + task3 = self._assert_single_item(tasks, name='task3') task4 = self._assert_single_item(tasks, name='task4') diff --git a/mistral/tests/unit/engine1/test_race_condition.py b/mistral/tests/unit/engine1/test_race_condition.py new file mode 100644 index 00000000..074e66ed --- /dev/null +++ b/mistral/tests/unit/engine1/test_race_condition.py @@ -0,0 +1,208 @@ +# Copyright 2014 - Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from eventlet import corolocal +from eventlet import semaphore +from oslo.config import cfg + +from mistral.actions import base as action_base +from mistral.db.v2 import api as db_api +from mistral.openstack.common import log as logging +from mistral.services import action_manager as a_m +from mistral.services import workflows as wf_service +from mistral.tests.unit.engine1 import base +from mistral.workflow import states + +LOG = logging.getLogger(__name__) + +# Use the set_default method to set value otherwise in certain test cases +# the change in value is not permanent. 
+cfg.CONF.set_default('auth_enable', False, group='pecan') + + +WF_LONG_ACTION = """ +--- +version: '2.0' + +wf: + type: direct + + description: | + The idea is to use action that runs longer than engine.start_workflow() + method. And we need to check that engine handles this situation. + + output: + result: $.result + + tasks: + task1: + action: std.block + publish: + result: $ +""" + +WF_SHORT_ACTION = """ +--- +version: '2.0' + +wf: + type: direct + + description: | + The idea is to use action that runs faster than engine.start_workflow(). + And we need to check that engine handles this situation as well. This was + a situation previously that led to a race condition in engine, method + on_task_result() was called while DB transaction in start_workflow() was + still active (not committed yet). + To emulate a short action we use a workflow with two start tasks so they + run both in parallel on the first engine iteration when we call method + start_workflow(). First task has a short action that just returns a + predefined result and the second task blocks until the test explicitly + unblocks it. So the first action will always end before start_workflow() + methods ends. + + output: + result: $.result + + tasks: + task1: + action: std.echo output=1 + publish: + result1: $ + + task2: + action: std.block +""" + +ACTION_SEMAPHORE = None +TEST_SEMAPHORE = None + + +class BlockingAction(action_base.Action): + def __init__(self): + pass + + @staticmethod + def unblock_test(): + TEST_SEMAPHORE.release() + + @staticmethod + def wait_for_test(): + ACTION_SEMAPHORE.acquire() + + def run(self): + self.unblock_test() + self.wait_for_test() + + print('Action completed [eventlet_id=%s]' % corolocal.get_ident()) + + return 'test' + + def test(self): + pass + + +class LongActionTest(base.EngineTestCase): + def setUp(self): + super(LongActionTest, self).setUp() + + global ACTION_SEMAPHORE + global TEST_SEMAPHORE + + ACTION_SEMAPHORE = semaphore.Semaphore(1) + TEST_SEMAPHORE = semaphore.Semaphore(0) + + a_m.register_action_class( + 'std.block', + '%s.%s' % (BlockingAction.__module__, BlockingAction.__name__), + None + ) + + @staticmethod + def block_action(): + ACTION_SEMAPHORE.acquire() + + @staticmethod + def unblock_action(): + ACTION_SEMAPHORE.release() + + @staticmethod + def wait_for_action(): + TEST_SEMAPHORE.acquire() + + def test_long_action(self): + wf_service.create_workflows(WF_LONG_ACTION) + + self.block_action() + + exec_db = self.engine.start_workflow('wf', None) + + exec_db = db_api.get_execution(exec_db.id) + + self.assertEqual(states.RUNNING, exec_db.state) + self.assertEqual(states.RUNNING, exec_db.tasks[0].state) + + self.wait_for_action() + + # Here's the point when the action is blocked but already running. + # Do the same check again, it should always pass. 
+ exec_db = db_api.get_execution(exec_db.id) + + self.assertEqual(states.RUNNING, exec_db.state) + self.assertEqual(states.RUNNING, exec_db.tasks[0].state) + + self.unblock_action() + + self._await(lambda: self.is_execution_success(exec_db.id)) + + exec_db = db_api.get_execution(exec_db.id) + + self.assertDictEqual({'result': 'test'}, exec_db.output) + + def test_short_action(self): + wf_service.create_workflows(WF_SHORT_ACTION) + + self.block_action() + + exec_db = self.engine.start_workflow('wf', None) + + exec_db = db_api.get_execution(exec_db.id) + + self.assertEqual(states.RUNNING, exec_db.state) + + tasks = exec_db.tasks + + task1 = self._assert_single_item(exec_db.tasks, name='task1') + task2 = self._assert_single_item( + tasks, + name='task2', + state=states.RUNNING + ) + + self._await(lambda: self.is_task_success(task1.id)) + + self.unblock_action() + + self._await(lambda: self.is_task_success(task2.id)) + self._await(lambda: self.is_execution_success(exec_db.id)) + + task1 = db_api.get_task(task1.id) + + self.assertDictEqual( + { + 'result1': 1, + 'task': {'task1': {'result1': 1}} + }, + task1.output + ) diff --git a/mistral/utils/inspect_utils.py b/mistral/utils/inspect_utils.py index 59281940..ba5c96dc 100644 --- a/mistral/utils/inspect_utils.py +++ b/mistral/utils/inspect_utils.py @@ -22,6 +22,7 @@ def get_public_fields(obj): if not attr.startswith("_")] public_fields = {} + for attribute_str in public_attributes: attr = getattr(obj, attribute_str) is_field = not (inspect.isbuiltin(attr) diff --git a/mistral/workflow/base.py b/mistral/workflow/base.py index 31d98a5e..73f11632 100644 --- a/mistral/workflow/base.py +++ b/mistral/workflow/base.py @@ -129,7 +129,8 @@ class WorkflowHandler(object): return cmds - def _determine_task_output(self, task_spec, task_db, raw_result): + @staticmethod + def _determine_task_output(task_spec, task_db, raw_result): for_each = task_spec.get_for_each() t_name = task_spec.get_name() @@ -172,7 +173,8 @@ class WorkflowHandler(object): else: return data_flow.evaluate_task_output(task_spec, raw_result) - def _determine_task_state(self, task_db, task_spec, raw_result): + @staticmethod + def _determine_task_state(task_db, task_spec, raw_result): state = states.ERROR if raw_result.is_error() else states.SUCCESS for_each = task_spec.get_for_each() diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index 10939e28..b18c4d93 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -88,7 +88,7 @@ class DirectWorkflowHandler(base.WorkflowHandler): Expression 'on_complete' is not mutually exclusive to 'on_success' and 'on_error'. :param task_db: Task DB model. - :param remove_incomplete_joins: True if incomplete "join" + :param remove_unsatisfied_joins: True if incomplete "join" tasks must be excluded from the list of commands. :return: List of task specifications. """
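
The heart of this change is splitting every engine command into an "in tx" part (run_local) and a "non tx" part (run_remote). The sketch below illustrates the intended call order; the names process_commands and transaction are illustrative stand-ins, not the actual Mistral API — in the patch the two phases are driven by DefaultEngine._run_local_commands() and _run_remote_commands(), with the transaction scope owned by the engine methods themselves.

from contextlib import contextmanager


@contextmanager
def transaction():
    # Stand-in for the engine's real DB transaction scope.
    yield


class EngineCommand(object):
    """Two-phase command: DB work in run_local(), remote calls in run_remote()."""

    def run_local(self, exec_db, wf_handler, cause_task_db=None):
        # Safe inside an open DB transaction (pure state changes).
        return True

    def run_remote(self, exec_db, wf_handler, cause_task_db=None):
        # RPC to executors / sub-workflow starts; must run after the commit.
        return True


def process_commands(cmds, exec_db, wf_handler, cause_task_db=None):
    # Phase 1: run all local parts while the transaction is open so the
    # execution/task state is committed atomically.
    with transaction():
        for cmd in cmds:
            if not cmd.run_local(exec_db, wf_handler, cause_task_db):
                return

    # Phase 2: only after the commit, perform the remote invocations. Any
    # on_task_result() triggered by a fast executor now reads committed
    # state, which removes the start_workflow()/on_task_result() race.
    for cmd in cmds:
        if not cmd.run_remote(exec_db, wf_handler, cause_task_db):
            return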
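
_calc_for_each_input() is only partially visible in this hunk; the sketch below shows one plausible way to turn a for-each action input such as {'name': ['John', 'Mike']} into a list of per-iteration inputs. It assumes every list under action_input has the same length and is not the verbatim Mistral implementation.

def calc_for_each_input(action_input):
    # Build one input dict per iteration: the Nth dict collects the Nth item
    # of every list in action_input.
    action_input_collection = []

    for key, value in action_input.items():
        for index, item in enumerate(value):
            iter_context = {key: item}

            if index < len(action_input_collection):
                action_input_collection[index].update(iter_context)
            else:
                action_input_collection.append(iter_context)

    return action_input_collection


# Example: two iterations, each receiving one name.
# calc_for_each_input({'name': ['John', 'Mike']})
#   -> [{'name': 'John'}, {'name': 'Mike'}]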
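
The long descriptions in WF_LONG_ACTION and WF_SHORT_ACTION boil down to a handshake between the test and the std.block action built on two semaphores. The sketch below isolates that handshake outside the engine; the semaphore roles mirror test_race_condition.py, but the surrounding harness (blocking_action, test_choreography) is illustrative only.

import eventlet
from eventlet import semaphore

# The blocking action waits on this one; the test releases it when it is
# ready to let the action finish.
ACTION_SEMAPHORE = semaphore.Semaphore(1)

# The test waits on this one; the action releases it as soon as it starts,
# so the test knows the action is running before making its assertions.
TEST_SEMAPHORE = semaphore.Semaphore(0)


def blocking_action():
    TEST_SEMAPHORE.release()    # "I have started" -> unblocks the test
    ACTION_SEMAPHORE.acquire()  # wait until the test unblocks the action
    return 'test'


def test_choreography():
    ACTION_SEMAPHORE.acquire()                # block_action()
    worker = eventlet.spawn(blocking_action)  # engine would start the action
    TEST_SEMAPHORE.acquire()                  # wait_for_action()
    # ... assertions against committed DB state go here ...
    ACTION_SEMAPHORE.release()                # unblock_action()
    assert worker.wait() == 'test'


test_choreography()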