Implementing 'acquire_lock' method and fixing workflow completion

* For sqlite it's based on eventlet semaphores
* For other drivers it's assumed that at least READ_COMMITTED
  transactions are used
* Engine method on_action_complete() now acquires lock on
  workflow execution object to prevent incorrect concurrent
  access causing problems in 'join' and 'timeout' policy

Change-Id: I80bc317de4bfd2547f8529d8f2b3238a004d7522
This commit is contained in:
Renat Akhmerov 2015-03-20 17:55:00 +06:00
parent c72481a66c
commit 8b5d58a24b
12 changed files with 378 additions and 52 deletions

View File

@ -18,6 +18,7 @@ from oslo.config import cfg
from oslo.db import options
from oslo.db.sqlalchemy import session as db_session
from mistral.db.sqlalchemy import sqlite_lock
from mistral import exceptions as exc
from mistral.openstack.common import log as logging
from mistral import utils
@ -111,32 +112,48 @@ def start_tx():
"""Opens new database session and starts new transaction assuming
there wasn't any opened sessions within the same thread.
"""
ses = _get_thread_local_session()
if ses:
raise exc.DataAccessException("Database transaction has already been"
" started.")
if _get_thread_local_session():
raise exc.DataAccessException(
"Database transaction has already been started."
)
_set_thread_local_session(_get_session())
def release_locks_if_sqlite(session):
if get_driver_name() == 'sqlite':
sqlite_lock.release_locks(session)
def commit_tx():
"""Commits previously started database transaction."""
ses = _get_thread_local_session()
if not ses:
raise exc.DataAccessException("Nothing to commit. Database transaction"
" has not been previously started.")
if not ses:
raise exc.DataAccessException(
"Nothing to commit. Database transaction"
" has not been previously started."
)
try:
ses.commit()
finally:
release_locks_if_sqlite(ses)
def rollback_tx():
"""Rolls back previously started database transaction."""
ses = _get_thread_local_session()
if not ses:
raise exc.DataAccessException("Nothing to roll back. Database"
" transaction has not been started.")
if not ses:
raise exc.DataAccessException(
"Nothing to roll back. Database transaction has not been started."
)
try:
ses.rollback()
finally:
release_locks_if_sqlite(ses)
def end_tx():
@ -144,17 +161,24 @@ def end_tx():
It rolls back all uncommitted changes and closes database session.
"""
ses = _get_thread_local_session()
if not ses:
raise exc.DataAccessException("Database transaction has not been"
" started.")
raise exc.DataAccessException(
"Database transaction has not been started."
)
if ses.dirty:
ses.rollback()
rollback_tx()
ses.close()
_set_thread_local_session(None)
@session_aware()
def get_driver_name(session=None):
return session.bind.url.drivername
@session_aware()
def model_query(model, session=None):
"""Query helper.

View File

@ -0,0 +1,54 @@
# Copyright 2015 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from eventlet import semaphore
_mutex = semaphore.Semaphore()
_locks = {}
def acquire_lock(obj_id, session):
with _mutex:
if obj_id not in _locks:
_locks[obj_id] = (session, semaphore.BoundedSemaphore(1))
tup = _locks.get(obj_id)
tup[1].acquire()
# Make sure to update the dictionary once the lock is acquired
# to adjust session ownership.
_locks[obj_id] = (session, tup[1])
def release_locks(session):
with _mutex:
for obj_id, tup in _locks.items():
if tup[0] is session:
tup[1].release()
def get_locks():
return _locks
def cleanup():
with _mutex:
# NOTE: For the sake of simplicity we assume that we remove stale locks
# after all tests because this kind of locking can only be used with
# sqlite database. Supporting fully dynamically allocated (and removed)
# locks is much more complex task. If this method is not called after
# tests it will cause a memory leak.
_locks.clear()

View File

@ -60,6 +60,13 @@ def transaction():
yield
# Locking.
def acquire_lock(model, id):
IMPL.acquire_lock(model, id)
# Workbooks.
def get_workbook(name):

View File

@ -18,10 +18,12 @@ import sys
from oslo.config import cfg
from oslo.db import exception as db_exc
from oslo.utils import timeutils
import sqlalchemy as sa
from mistral.db.sqlalchemy import base as b
from mistral.db.sqlalchemy import model_base as mb
from mistral.db.sqlalchemy import sqlite_lock
from mistral.db.v2.sqlalchemy import models
from mistral import exceptions as exc
from mistral.openstack.common import log as logging
@ -85,6 +87,19 @@ def transaction():
end_tx()
@b.session_aware()
def acquire_lock(model, id, session=None):
if b.get_driver_name() != 'sqlite':
query = _secure_query(model).filter("id = '%s'" % id)
query.update(
{'updated_at': timeutils.utcnow()},
synchronize_session=False
)
else:
sqlite_lock.acquire_lock(id, session)
def _secure_query(model):
query = b.model_query(model)

View File

@ -17,6 +17,7 @@ import copy
import traceback
from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models as db_models
from mistral.engine1 import base
from mistral.engine1 import task_handler
from mistral.engine1 import utils
@ -88,9 +89,13 @@ class DefaultEngine(base.Engine):
def on_task_state_change(self, state, task_ex_id):
with db_api.transaction():
task_ex = db_api.get_task_execution(task_ex_id)
wf_ex = db_api.get_workflow_execution(
task_ex.workflow_execution_id
)
wf_ex_id = task_ex.workflow_execution_id
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(wf_ex_id)
wf_ex = db_api.get_workflow_execution(wf_ex_id)
wf_trace.info(
task_ex,
@ -120,28 +125,47 @@ class DefaultEngine(base.Engine):
task_ex.processed = True
if not cmds:
if task_ex.state == states.SUCCESS:
if not wf_utils.find_running_tasks(wf_ex):
self._dispatch_workflow_commands(wf_ex, cmds)
self._check_workflow_completion(wf_ex, action_ex, wf_ctrl)
@staticmethod
def _check_workflow_completion(wf_ex, action_ex, wf_ctrl):
if states.is_paused_or_completed(wf_ex.state):
return
if wf_utils.find_incomplete_tasks(wf_ex):
return
if wf_ctrl.all_errors_handled():
wf_handler.succeed_workflow(
wf_ex,
wf_ctrl.evaluate_workflow_final_context()
)
else:
wf_handler.fail_workflow(wf_ex, task_ex, action_ex)
else:
self._dispatch_workflow_commands(wf_ex, cmds)
result_str = str(action_ex.output.get('result', "Unknown"))
state_info = (
"Failure caused by error in task '%s': %s" %
(action_ex.task_execution.name, result_str)
)
wf_handler.fail_workflow(wf_ex, state_info)
@u.log_exec(LOG)
def on_action_complete(self, action_ex_id, result):
wf_exec_id = None
wf_ex_id = None
try:
with db_api.transaction():
action_ex = db_api.get_action_execution(action_ex_id)
wf_ex_id = action_ex.task_execution.workflow_execution_id
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(wf_ex_id)
wf_ex = action_ex.task_execution.workflow_execution
wf_exec_id = wf_ex.id
task_ex = task_handler.on_action_complete(action_ex, result)
@ -160,12 +184,15 @@ class DefaultEngine(base.Engine):
"Failed to handle action execution result [id=%s]: %s\n%s",
action_ex_id, e, traceback.format_exc()
)
self._fail_workflow(wf_exec_id, e)
self._fail_workflow(wf_ex_id, e)
raise e
@u.log_exec(LOG)
def pause_workflow(self, execution_id):
with db_api.transaction():
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(execution_id)
wf_ex = db_api.get_workflow_execution(execution_id)
wf_handler.set_execution_state(wf_ex, states.PAUSED)
@ -176,6 +203,9 @@ class DefaultEngine(base.Engine):
def resume_workflow(self, execution_id):
try:
with db_api.transaction():
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(execution_id)
wf_ex = db_api.get_workflow_execution(execution_id)
if wf_ex.state != states.PAUSED:
@ -204,14 +234,14 @@ class DefaultEngine(base.Engine):
if states.is_completed(t_ex.state) and not t_ex.processed:
t_ex.processed = True
self._dispatch_workflow_commands(wf_ex, cmds)
if not cmds:
if not wf_utils.find_running_tasks(wf_ex):
if not wf_utils.find_incomplete_tasks(wf_ex):
wf_handler.succeed_workflow(
wf_ex,
wf_ctrl.evaluate_workflow_final_context()
)
else:
self._dispatch_workflow_commands(wf_ex, cmds)
return wf_ex
except Exception as e:
@ -225,6 +255,9 @@ class DefaultEngine(base.Engine):
@u.log_exec(LOG)
def stop_workflow(self, execution_id, state, message=None):
with db_api.transaction():
# Must be before loading the object itself (see method doc).
self._lock_workflow_execution(execution_id)
wf_ex = db_api.get_execution(execution_id)
wf_handler.set_execution_state(wf_ex, state, message)
@ -280,7 +313,6 @@ class DefaultEngine(base.Engine):
action_ex = db_api.get_action_execution(action_ex_id)
task_handler.on_action_complete(
wf_ex,
action_ex,
wf_utils.Result(error=err_msg)
)
@ -321,3 +353,12 @@ class DefaultEngine(base.Engine):
data_flow.add_environment_to_context(wf_ex, wf_ex.context)
return wf_ex
@staticmethod
def _lock_workflow_execution(wf_exec_id):
# NOTE: Workflow execution object must be locked before
# loading the object itself into the session (either with
# 'get_XXX' or 'load_XXX' methods). Otherwise, there can be
# multiple parallel transactions that see the same state
# and hence the rest of the method logic would not be atomic.
db_api.acquire_lock(db_models.WorkflowExecution, wf_exec_id)

View File

@ -144,6 +144,10 @@ def _create_task_execution(wf_ex, task_spec, ctx):
'project_id': wf_ex.project_id
})
# Add to collection explicitly so that it's in a proper
# state within the current session.
wf_ex.task_executions.append(task_ex)
# TODO(rakhmerov): May be it shouldn't be here. Need to think.
if task_spec.get_with_items():
with_items.prepare_runtime_context(task_ex, task_spec)
@ -152,15 +156,21 @@ def _create_task_execution(wf_ex, task_spec, ctx):
def _create_action_execution(task_ex, action_def, action_input):
return db_api.create_action_execution({
action_ex = db_api.create_action_execution({
'name': action_def.name,
'task_execution_id': task_ex.id,
'workflow_name': task_ex.workflow_name,
'spec': action_def.spec,
'project_id': task_ex.project_id,
'state': states.RUNNING,
'input': action_input
})
'input': action_input}
)
# Add to collection explicitly so that it's in a proper
# state within the current session.
task_ex.executions.append(action_ex)
return action_ex
def _before_task_start(task_ex, task_spec, wf_spec):

View File

@ -34,22 +34,11 @@ def succeed_workflow(wf_ex, final_context):
_schedule_send_result_to_parent_workflow(wf_ex)
def fail_workflow(wf_ex, task_ex, action_ex):
def fail_workflow(wf_ex, state_info):
if states.is_paused_or_completed(wf_ex.state):
return
# TODO(rakhmerov): How do we pass task result correctly?.
if action_ex:
msg = str(action_ex.output.get('result', "Unknown"))
else:
msg = "Unknown"
set_execution_state(
wf_ex,
states.ERROR,
"Failure caused by error in task '%s': %s"
% (task_ex.name, msg)
)
set_execution_state(wf_ex, states.ERROR, state_info)
if wf_ex.task_execution_id:
_schedule_send_result_to_parent_workflow(wf_ex)

View File

@ -27,6 +27,7 @@ import testtools.matchers as ttm
from mistral import context as auth_context
from mistral.db.sqlalchemy import base as db_sa_base
from mistral.db.sqlalchemy import sqlite_lock
from mistral.db.v1 import api as db_api_v1
from mistral.db.v2 import api as db_api_v2
from mistral import engine
@ -215,6 +216,8 @@ class DbTestCase(BaseTest):
db_api_v2.delete_cron_triggers()
db_api_v2.delete_workflow_definitions()
sqlite_lock.cleanup()
def setUp(self):
super(DbTestCase, self).setUp()

View File

@ -0,0 +1,164 @@
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import eventlet
from oslo.config import cfg
import random
from mistral.db.sqlalchemy import sqlite_lock
from mistral.db.v2.sqlalchemy import api as db_api
from mistral.db.v2.sqlalchemy import models as db_models
from mistral.tests import base as test_base
WF_EXEC = {
'name': '1',
'spec': {},
'start_params': {},
'state': 'RUNNING',
'state_info': "Running...",
'created_at': None,
'updated_at': None,
'context': None,
'task_id': None,
'trust_id': None
}
class WorkflowExecutionTest(test_base.DbTestCase):
def setUp(self):
super(WorkflowExecutionTest, self).setUp()
cfg.CONF.set_default('auth_enable', True, group='pecan')
self.addCleanup(
cfg.CONF.set_default,
'auth_enable',
False,
group='pecan'
)
def _random_sleep(self):
eventlet.sleep(random.Random().randint(0, 10) * 0.001)
def _run_acquire_release_sqlite_lock(self, obj_id, session):
self._random_sleep()
sqlite_lock.acquire_lock(obj_id, session)
self._random_sleep()
sqlite_lock.release_locks(session)
def test_acquire_release_sqlite_lock(self):
threads = []
id = "object_id"
number = 500
for i in range(1, number):
threads.append(
eventlet.spawn(self._run_acquire_release_sqlite_lock, id, i)
)
[t.wait() for t in threads]
[t.kill() for t in threads]
self.assertEqual(1, len(sqlite_lock.get_locks()))
sqlite_lock.cleanup()
self.assertEqual(0, len(sqlite_lock.get_locks()))
def _run_correct_locking(self, wf_ex):
self._random_sleep()
with db_api.transaction():
# Here we lock the object before it gets loaded into the
# session and prevent reading the same object state by
# multiple transactions. Hence the rest of the transaction
# body works atomically (in a serialized manner) and the
# result (object name) must be equal to a number of
# transactions.
db_api.acquire_lock(db_models.WorkflowExecution, wf_ex.id)
# Refresh the object.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
wf_ex.name = str(int(wf_ex.name) + 1)
return wf_ex.name
def test_correct_locking(self):
wf_ex = db_api.create_workflow_execution(WF_EXEC)
threads = []
number = 500
for i in range(1, number):
threads.append(
eventlet.spawn(self._run_correct_locking, wf_ex)
)
[t.wait() for t in threads]
[t.kill() for t in threads]
wf_ex = db_api.get_workflow_execution(wf_ex.id)
print("Correct locking test gave object name: %s" % wf_ex.name)
self.assertEqual(str(number), wf_ex.name)
def _run_invalid_locking(self, wf_ex):
self._random_sleep()
with db_api.transaction():
# Load object into the session (transaction).
wf_ex = db_api.get_workflow_execution(wf_ex.id)
# It's too late to lock the object here because it's already
# been loaded into the session so there should be multiple
# threads that read the same object state so they write the
# same value into DB. As a result we won't get a result
# (object name) equal to a number of transactions.
db_api.acquire_lock(db_models.WorkflowExecution, wf_ex.id)
wf_ex.name = str(int(wf_ex.name) + 1)
return wf_ex.name
def test_invalid_locking(self):
wf_ex = db_api.create_workflow_execution(WF_EXEC)
threads = []
number = 500
for i in range(1, number):
threads.append(
eventlet.spawn(self._run_invalid_locking, wf_ex)
)
[t.wait() for t in threads]
[t.kill() for t in threads]
wf_ex = db_api.get_workflow_execution(wf_ex.id)
print("Invalid locking test gave object name: %s" % wf_ex.name)
self.assertNotEqual(str(number), wf_ex.name)

View File

@ -57,6 +57,15 @@ class WorkflowController(object):
return self._find_next_commands()
@abc.abstractmethod
def all_errors_handled(self):
"""Determines if all errors (if any) are handled.
:return: True if either there aren't errors at all or all
errors are considered handled.
"""
raise NotImplementedError
@abc.abstractmethod
def evaluate_workflow_final_context(self):
"""Evaluates final workflow context assuming that workflow has finished.

View File

@ -173,6 +173,13 @@ class DirectWorkflowController(base.WorkflowController):
return ctx
def all_errors_handled(self):
for t_ex in wf_utils.find_error_tasks(self.wf_ex):
if not self.get_on_error_clause(t_ex.name):
return False
return True
def _find_end_tasks(self):
return filter(
lambda t_db: not self._has_outbound_tasks(t_db),

View File

@ -87,6 +87,9 @@ class ReverseWorkflowController(base.WorkflowController):
)
)
def all_errors_handled(self):
return len(wf_utils.find_error_tasks(self.wf_ex)) == 0
def _find_task_specs_with_satisfied_dependencies(self):
"""Given a target task name finds tasks with no dependencies.