Implementing 'acquire_lock' method and fixing workflow completion

* For sqlite it's based on eventlet semaphores
* For other drivers it's assumed that at least READ_COMMITTED
  transactions are used
* Engine method on_action_complete() now acquires lock on
  workflow execution object to prevent incorrect concurrent
  access causing problems in 'join' and 'timeout' policy

Change-Id: I80bc317de4bfd2547f8529d8f2b3238a004d7522
This commit is contained in:
Renat Akhmerov
2015-03-20 17:55:00 +06:00
parent c72481a66c
commit 8b5d58a24b
12 changed files with 378 additions and 52 deletions

View File

@@ -18,6 +18,7 @@ from oslo.config import cfg
from oslo.db import options from oslo.db import options
from oslo.db.sqlalchemy import session as db_session from oslo.db.sqlalchemy import session as db_session
from mistral.db.sqlalchemy import sqlite_lock
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral.openstack.common import log as logging from mistral.openstack.common import log as logging
from mistral import utils from mistral import utils
@@ -111,32 +112,48 @@ def start_tx():
"""Opens new database session and starts new transaction assuming """Opens new database session and starts new transaction assuming
there wasn't any opened sessions within the same thread. there wasn't any opened sessions within the same thread.
""" """
ses = _get_thread_local_session() if _get_thread_local_session():
if ses: raise exc.DataAccessException(
raise exc.DataAccessException("Database transaction has already been" "Database transaction has already been started."
" started.") )
_set_thread_local_session(_get_session()) _set_thread_local_session(_get_session())
def release_locks_if_sqlite(session):
if get_driver_name() == 'sqlite':
sqlite_lock.release_locks(session)
def commit_tx(): def commit_tx():
"""Commits previously started database transaction.""" """Commits previously started database transaction."""
ses = _get_thread_local_session() ses = _get_thread_local_session()
if not ses:
raise exc.DataAccessException("Nothing to commit. Database transaction"
" has not been previously started.")
ses.commit() if not ses:
raise exc.DataAccessException(
"Nothing to commit. Database transaction"
" has not been previously started."
)
try:
ses.commit()
finally:
release_locks_if_sqlite(ses)
def rollback_tx(): def rollback_tx():
"""Rolls back previously started database transaction.""" """Rolls back previously started database transaction."""
ses = _get_thread_local_session() ses = _get_thread_local_session()
if not ses:
raise exc.DataAccessException("Nothing to roll back. Database"
" transaction has not been started.")
ses.rollback() if not ses:
raise exc.DataAccessException(
"Nothing to roll back. Database transaction has not been started."
)
try:
ses.rollback()
finally:
release_locks_if_sqlite(ses)
def end_tx(): def end_tx():
@@ -144,17 +161,24 @@ def end_tx():
It rolls back all uncommitted changes and closes database session. It rolls back all uncommitted changes and closes database session.
""" """
ses = _get_thread_local_session() ses = _get_thread_local_session()
if not ses: if not ses:
raise exc.DataAccessException("Database transaction has not been" raise exc.DataAccessException(
" started.") "Database transaction has not been started."
)
if ses.dirty: if ses.dirty:
ses.rollback() rollback_tx()
ses.close() ses.close()
_set_thread_local_session(None) _set_thread_local_session(None)
@session_aware()
def get_driver_name(session=None):
return session.bind.url.drivername
@session_aware() @session_aware()
def model_query(model, session=None): def model_query(model, session=None):
"""Query helper. """Query helper.

View File

@@ -0,0 +1,54 @@
# Copyright 2015 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from eventlet import semaphore
_mutex = semaphore.Semaphore()
_locks = {}
def acquire_lock(obj_id, session):
with _mutex:
if obj_id not in _locks:
_locks[obj_id] = (session, semaphore.BoundedSemaphore(1))
tup = _locks.get(obj_id)
tup[1].acquire()
# Make sure to update the dictionary once the lock is acquired
# to adjust session ownership.
_locks[obj_id] = (session, tup[1])
def release_locks(session):
with _mutex:
for obj_id, tup in _locks.items():
if tup[0] is session:
tup[1].release()
def get_locks():
return _locks
def cleanup():
with _mutex:
# NOTE: For the sake of simplicity we assume that we remove stale locks
# after all tests because this kind of locking can only be used with
# sqlite database. Supporting fully dynamically allocated (and removed)
# locks is much more complex task. If this method is not called after
# tests it will cause a memory leak.
_locks.clear()

View File

@@ -60,6 +60,13 @@ def transaction():
yield yield
# Locking.
def acquire_lock(model, id):
IMPL.acquire_lock(model, id)
# Workbooks. # Workbooks.
def get_workbook(name): def get_workbook(name):

View File

@@ -18,10 +18,12 @@ import sys
from oslo.config import cfg from oslo.config import cfg
from oslo.db import exception as db_exc from oslo.db import exception as db_exc
from oslo.utils import timeutils
import sqlalchemy as sa import sqlalchemy as sa
from mistral.db.sqlalchemy import base as b from mistral.db.sqlalchemy import base as b
from mistral.db.sqlalchemy import model_base as mb from mistral.db.sqlalchemy import model_base as mb
from mistral.db.sqlalchemy import sqlite_lock
from mistral.db.v2.sqlalchemy import models from mistral.db.v2.sqlalchemy import models
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral.openstack.common import log as logging from mistral.openstack.common import log as logging
@@ -85,6 +87,19 @@ def transaction():
end_tx() end_tx()
@b.session_aware()
def acquire_lock(model, id, session=None):
if b.get_driver_name() != 'sqlite':
query = _secure_query(model).filter("id = '%s'" % id)
query.update(
{'updated_at': timeutils.utcnow()},
synchronize_session=False
)
else:
sqlite_lock.acquire_lock(id, session)
def _secure_query(model): def _secure_query(model):
query = b.model_query(model) query = b.model_query(model)

View File

@@ -17,6 +17,7 @@ import copy
import traceback import traceback
from mistral.db.v2 import api as db_api from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models as db_models
from mistral.engine1 import base from mistral.engine1 import base
from mistral.engine1 import task_handler from mistral.engine1 import task_handler
from mistral.engine1 import utils from mistral.engine1 import utils
@@ -88,9 +89,13 @@ class DefaultEngine(base.Engine):
def on_task_state_change(self, state, task_ex_id): def on_task_state_change(self, state, task_ex_id):
with db_api.transaction(): with db_api.transaction():
task_ex = db_api.get_task_execution(task_ex_id) task_ex = db_api.get_task_execution(task_ex_id)
wf_ex = db_api.get_workflow_execution(
task_ex.workflow_execution_id wf_ex_id = task_ex.workflow_execution_id
)
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(wf_ex_id)
wf_ex = db_api.get_workflow_execution(wf_ex_id)
wf_trace.info( wf_trace.info(
task_ex, task_ex,
@@ -120,28 +125,47 @@ class DefaultEngine(base.Engine):
task_ex.processed = True task_ex.processed = True
if not cmds: self._dispatch_workflow_commands(wf_ex, cmds)
if task_ex.state == states.SUCCESS:
if not wf_utils.find_running_tasks(wf_ex): self._check_workflow_completion(wf_ex, action_ex, wf_ctrl)
wf_handler.succeed_workflow(
wf_ex, @staticmethod
wf_ctrl.evaluate_workflow_final_context() def _check_workflow_completion(wf_ex, action_ex, wf_ctrl):
) if states.is_paused_or_completed(wf_ex.state):
else: return
wf_handler.fail_workflow(wf_ex, task_ex, action_ex)
else: if wf_utils.find_incomplete_tasks(wf_ex):
self._dispatch_workflow_commands(wf_ex, cmds) return
if wf_ctrl.all_errors_handled():
wf_handler.succeed_workflow(
wf_ex,
wf_ctrl.evaluate_workflow_final_context()
)
else:
result_str = str(action_ex.output.get('result', "Unknown"))
state_info = (
"Failure caused by error in task '%s': %s" %
(action_ex.task_execution.name, result_str)
)
wf_handler.fail_workflow(wf_ex, state_info)
@u.log_exec(LOG) @u.log_exec(LOG)
def on_action_complete(self, action_ex_id, result): def on_action_complete(self, action_ex_id, result):
wf_exec_id = None wf_ex_id = None
try: try:
with db_api.transaction(): with db_api.transaction():
action_ex = db_api.get_action_execution(action_ex_id) action_ex = db_api.get_action_execution(action_ex_id)
wf_ex_id = action_ex.task_execution.workflow_execution_id
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(wf_ex_id)
wf_ex = action_ex.task_execution.workflow_execution wf_ex = action_ex.task_execution.workflow_execution
wf_exec_id = wf_ex.id
task_ex = task_handler.on_action_complete(action_ex, result) task_ex = task_handler.on_action_complete(action_ex, result)
@@ -160,12 +184,15 @@ class DefaultEngine(base.Engine):
"Failed to handle action execution result [id=%s]: %s\n%s", "Failed to handle action execution result [id=%s]: %s\n%s",
action_ex_id, e, traceback.format_exc() action_ex_id, e, traceback.format_exc()
) )
self._fail_workflow(wf_exec_id, e) self._fail_workflow(wf_ex_id, e)
raise e raise e
@u.log_exec(LOG) @u.log_exec(LOG)
def pause_workflow(self, execution_id): def pause_workflow(self, execution_id):
with db_api.transaction(): with db_api.transaction():
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(execution_id)
wf_ex = db_api.get_workflow_execution(execution_id) wf_ex = db_api.get_workflow_execution(execution_id)
wf_handler.set_execution_state(wf_ex, states.PAUSED) wf_handler.set_execution_state(wf_ex, states.PAUSED)
@@ -176,6 +203,9 @@ class DefaultEngine(base.Engine):
def resume_workflow(self, execution_id): def resume_workflow(self, execution_id):
try: try:
with db_api.transaction(): with db_api.transaction():
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(execution_id)
wf_ex = db_api.get_workflow_execution(execution_id) wf_ex = db_api.get_workflow_execution(execution_id)
if wf_ex.state != states.PAUSED: if wf_ex.state != states.PAUSED:
@@ -204,14 +234,14 @@ class DefaultEngine(base.Engine):
if states.is_completed(t_ex.state) and not t_ex.processed: if states.is_completed(t_ex.state) and not t_ex.processed:
t_ex.processed = True t_ex.processed = True
self._dispatch_workflow_commands(wf_ex, cmds)
if not cmds: if not cmds:
if not wf_utils.find_running_tasks(wf_ex): if not wf_utils.find_incomplete_tasks(wf_ex):
wf_handler.succeed_workflow( wf_handler.succeed_workflow(
wf_ex, wf_ex,
wf_ctrl.evaluate_workflow_final_context() wf_ctrl.evaluate_workflow_final_context()
) )
else:
self._dispatch_workflow_commands(wf_ex, cmds)
return wf_ex return wf_ex
except Exception as e: except Exception as e:
@@ -225,6 +255,9 @@ class DefaultEngine(base.Engine):
@u.log_exec(LOG) @u.log_exec(LOG)
def stop_workflow(self, execution_id, state, message=None): def stop_workflow(self, execution_id, state, message=None):
with db_api.transaction(): with db_api.transaction():
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(execution_id)
wf_ex = db_api.get_execution(execution_id) wf_ex = db_api.get_execution(execution_id)
wf_handler.set_execution_state(wf_ex, state, message) wf_handler.set_execution_state(wf_ex, state, message)
@@ -280,7 +313,6 @@ class DefaultEngine(base.Engine):
action_ex = db_api.get_action_execution(action_ex_id) action_ex = db_api.get_action_execution(action_ex_id)
task_handler.on_action_complete( task_handler.on_action_complete(
wf_ex,
action_ex, action_ex,
wf_utils.Result(error=err_msg) wf_utils.Result(error=err_msg)
) )
@@ -321,3 +353,12 @@ class DefaultEngine(base.Engine):
data_flow.add_environment_to_context(wf_ex, wf_ex.context) data_flow.add_environment_to_context(wf_ex, wf_ex.context)
return wf_ex return wf_ex
@staticmethod
def _lock_workflow_execution(wf_exec_id):
# NOTE: Workflow execution object must be locked before
# loading the object itself into the session (either with
# 'get_XXX' or 'load_XXX' methods). Otherwise, there can be
# multiple parallel transactions that see the same state
# and hence the rest of the method logic would not be atomic.
db_api.acquire_lock(db_models.WorkflowExecution, wf_exec_id)

View File

@@ -144,6 +144,10 @@ def _create_task_execution(wf_ex, task_spec, ctx):
'project_id': wf_ex.project_id 'project_id': wf_ex.project_id
}) })
# Add to collection explicitly so that it's in a proper
# state within the current session.
wf_ex.task_executions.append(task_ex)
# TODO(rakhmerov): May be it shouldn't be here. Need to think. # TODO(rakhmerov): May be it shouldn't be here. Need to think.
if task_spec.get_with_items(): if task_spec.get_with_items():
with_items.prepare_runtime_context(task_ex, task_spec) with_items.prepare_runtime_context(task_ex, task_spec)
@@ -152,15 +156,21 @@ def _create_task_execution(wf_ex, task_spec, ctx):
def _create_action_execution(task_ex, action_def, action_input): def _create_action_execution(task_ex, action_def, action_input):
return db_api.create_action_execution({ action_ex = db_api.create_action_execution({
'name': action_def.name, 'name': action_def.name,
'task_execution_id': task_ex.id, 'task_execution_id': task_ex.id,
'workflow_name': task_ex.workflow_name, 'workflow_name': task_ex.workflow_name,
'spec': action_def.spec, 'spec': action_def.spec,
'project_id': task_ex.project_id, 'project_id': task_ex.project_id,
'state': states.RUNNING, 'state': states.RUNNING,
'input': action_input 'input': action_input}
}) )
# Add to collection explicitly so that it's in a proper
# state within the current session.
task_ex.executions.append(action_ex)
return action_ex
def _before_task_start(task_ex, task_spec, wf_spec): def _before_task_start(task_ex, task_spec, wf_spec):

View File

@@ -34,22 +34,11 @@ def succeed_workflow(wf_ex, final_context):
_schedule_send_result_to_parent_workflow(wf_ex) _schedule_send_result_to_parent_workflow(wf_ex)
def fail_workflow(wf_ex, task_ex, action_ex): def fail_workflow(wf_ex, state_info):
if states.is_paused_or_completed(wf_ex.state): if states.is_paused_or_completed(wf_ex.state):
return return
# TODO(rakhmerov): How do we pass task result correctly?. set_execution_state(wf_ex, states.ERROR, state_info)
if action_ex:
msg = str(action_ex.output.get('result', "Unknown"))
else:
msg = "Unknown"
set_execution_state(
wf_ex,
states.ERROR,
"Failure caused by error in task '%s': %s"
% (task_ex.name, msg)
)
if wf_ex.task_execution_id: if wf_ex.task_execution_id:
_schedule_send_result_to_parent_workflow(wf_ex) _schedule_send_result_to_parent_workflow(wf_ex)

View File

@@ -27,6 +27,7 @@ import testtools.matchers as ttm
from mistral import context as auth_context from mistral import context as auth_context
from mistral.db.sqlalchemy import base as db_sa_base from mistral.db.sqlalchemy import base as db_sa_base
from mistral.db.sqlalchemy import sqlite_lock
from mistral.db.v1 import api as db_api_v1 from mistral.db.v1 import api as db_api_v1
from mistral.db.v2 import api as db_api_v2 from mistral.db.v2 import api as db_api_v2
from mistral import engine from mistral import engine
@@ -215,6 +216,8 @@ class DbTestCase(BaseTest):
db_api_v2.delete_cron_triggers() db_api_v2.delete_cron_triggers()
db_api_v2.delete_workflow_definitions() db_api_v2.delete_workflow_definitions()
sqlite_lock.cleanup()
def setUp(self): def setUp(self):
super(DbTestCase, self).setUp() super(DbTestCase, self).setUp()

View File

@@ -0,0 +1,164 @@
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import eventlet
from oslo.config import cfg
import random
from mistral.db.sqlalchemy import sqlite_lock
from mistral.db.v2.sqlalchemy import api as db_api
from mistral.db.v2.sqlalchemy import models as db_models
from mistral.tests import base as test_base
WF_EXEC = {
'name': '1',
'spec': {},
'start_params': {},
'state': 'RUNNING',
'state_info': "Running...",
'created_at': None,
'updated_at': None,
'context': None,
'task_id': None,
'trust_id': None
}
class WorkflowExecutionTest(test_base.DbTestCase):
def setUp(self):
super(WorkflowExecutionTest, self).setUp()
cfg.CONF.set_default('auth_enable', True, group='pecan')
self.addCleanup(
cfg.CONF.set_default,
'auth_enable',
False,
group='pecan'
)
def _random_sleep(self):
eventlet.sleep(random.Random().randint(0, 10) * 0.001)
def _run_acquire_release_sqlite_lock(self, obj_id, session):
self._random_sleep()
sqlite_lock.acquire_lock(obj_id, session)
self._random_sleep()
sqlite_lock.release_locks(session)
def test_acquire_release_sqlite_lock(self):
threads = []
id = "object_id"
number = 500
for i in range(1, number):
threads.append(
eventlet.spawn(self._run_acquire_release_sqlite_lock, id, i)
)
[t.wait() for t in threads]
[t.kill() for t in threads]
self.assertEqual(1, len(sqlite_lock.get_locks()))
sqlite_lock.cleanup()
self.assertEqual(0, len(sqlite_lock.get_locks()))
def _run_correct_locking(self, wf_ex):
self._random_sleep()
with db_api.transaction():
# Here we lock the object before it gets loaded into the
# session and prevent reading the same object state by
# multiple transactions. Hence the rest of the transaction
# body works atomically (in a serialized manner) and the
# result (object name) must be equal to a number of
# transactions.
db_api.acquire_lock(db_models.WorkflowExecution, wf_ex.id)
# Refresh the object.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
wf_ex.name = str(int(wf_ex.name) + 1)
return wf_ex.name
def test_correct_locking(self):
wf_ex = db_api.create_workflow_execution(WF_EXEC)
threads = []
number = 500
for i in range(1, number):
threads.append(
eventlet.spawn(self._run_correct_locking, wf_ex)
)
[t.wait() for t in threads]
[t.kill() for t in threads]
wf_ex = db_api.get_workflow_execution(wf_ex.id)
print("Correct locking test gave object name: %s" % wf_ex.name)
self.assertEqual(str(number), wf_ex.name)
def _run_invalid_locking(self, wf_ex):
self._random_sleep()
with db_api.transaction():
# Load object into the session (transaction).
wf_ex = db_api.get_workflow_execution(wf_ex.id)
# It's too late to lock the object here because it's already
# been loaded into the session so there should be multiple
# threads that read the same object state so they write the
# same value into DB. As a result we won't get a result
# (object name) equal to a number of transactions.
db_api.acquire_lock(db_models.WorkflowExecution, wf_ex.id)
wf_ex.name = str(int(wf_ex.name) + 1)
return wf_ex.name
def test_invalid_locking(self):
wf_ex = db_api.create_workflow_execution(WF_EXEC)
threads = []
number = 500
for i in range(1, number):
threads.append(
eventlet.spawn(self._run_invalid_locking, wf_ex)
)
[t.wait() for t in threads]
[t.kill() for t in threads]
wf_ex = db_api.get_workflow_execution(wf_ex.id)
print("Invalid locking test gave object name: %s" % wf_ex.name)
self.assertNotEqual(str(number), wf_ex.name)

View File

@@ -57,6 +57,15 @@ class WorkflowController(object):
return self._find_next_commands() return self._find_next_commands()
@abc.abstractmethod
def all_errors_handled(self):
"""Determines if all errors (if any) are handled.
:return: True if either there aren't errors at all or all
errors are considered handled.
"""
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def evaluate_workflow_final_context(self): def evaluate_workflow_final_context(self):
"""Evaluates final workflow context assuming that workflow has finished. """Evaluates final workflow context assuming that workflow has finished.

View File

@@ -173,6 +173,13 @@ class DirectWorkflowController(base.WorkflowController):
return ctx return ctx
def all_errors_handled(self):
for t_ex in wf_utils.find_error_tasks(self.wf_ex):
if not self.get_on_error_clause(t_ex.name):
return False
return True
def _find_end_tasks(self): def _find_end_tasks(self):
return filter( return filter(
lambda t_db: not self._has_outbound_tasks(t_db), lambda t_db: not self._has_outbound_tasks(t_db),

View File

@@ -87,6 +87,9 @@ class ReverseWorkflowController(base.WorkflowController):
) )
) )
def all_errors_handled(self):
return len(wf_utils.find_error_tasks(self.wf_ex)) == 0
def _find_task_specs_with_satisfied_dependencies(self): def _find_task_specs_with_satisfied_dependencies(self):
"""Given a target task name finds tasks with no dependencies. """Given a target task name finds tasks with no dependencies.