Implementing 'acquire_lock' method and fixing workflow completion
* For sqlite it's based on eventlet semaphores * For other drivers it's assumed that at least READ_COMMITTED transactions are used * Engine method on_action_complete() now acquires lock on workflow execution object to prevent incorrect concurrent access causing problems in 'join' and 'timeout' policy Change-Id: I80bc317de4bfd2547f8529d8f2b3238a004d7522
This commit is contained in:
		@@ -18,6 +18,7 @@ from oslo.config import cfg
 | 
				
			|||||||
from oslo.db import options
 | 
					from oslo.db import options
 | 
				
			||||||
from oslo.db.sqlalchemy import session as db_session
 | 
					from oslo.db.sqlalchemy import session as db_session
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mistral.db.sqlalchemy import sqlite_lock
 | 
				
			||||||
from mistral import exceptions as exc
 | 
					from mistral import exceptions as exc
 | 
				
			||||||
from mistral.openstack.common import log as logging
 | 
					from mistral.openstack.common import log as logging
 | 
				
			||||||
from mistral import utils
 | 
					from mistral import utils
 | 
				
			||||||
@@ -111,32 +112,48 @@ def start_tx():
 | 
				
			|||||||
    """Opens new database session and starts new transaction assuming
 | 
					    """Opens new database session and starts new transaction assuming
 | 
				
			||||||
        there wasn't any opened sessions within the same thread.
 | 
					        there wasn't any opened sessions within the same thread.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    ses = _get_thread_local_session()
 | 
					    if _get_thread_local_session():
 | 
				
			||||||
    if ses:
 | 
					        raise exc.DataAccessException(
 | 
				
			||||||
        raise exc.DataAccessException("Database transaction has already been"
 | 
					            "Database transaction has already been started."
 | 
				
			||||||
                                      " started.")
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _set_thread_local_session(_get_session())
 | 
					    _set_thread_local_session(_get_session())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def release_locks_if_sqlite(session):
 | 
				
			||||||
 | 
					    if get_driver_name() == 'sqlite':
 | 
				
			||||||
 | 
					        sqlite_lock.release_locks(session)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def commit_tx():
 | 
					def commit_tx():
 | 
				
			||||||
    """Commits previously started database transaction."""
 | 
					    """Commits previously started database transaction."""
 | 
				
			||||||
    ses = _get_thread_local_session()
 | 
					    ses = _get_thread_local_session()
 | 
				
			||||||
    if not ses:
 | 
					 | 
				
			||||||
        raise exc.DataAccessException("Nothing to commit. Database transaction"
 | 
					 | 
				
			||||||
                                      " has not been previously started.")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ses.commit()
 | 
					    if not ses:
 | 
				
			||||||
 | 
					        raise exc.DataAccessException(
 | 
				
			||||||
 | 
					            "Nothing to commit. Database transaction"
 | 
				
			||||||
 | 
					            " has not been previously started."
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        ses.commit()
 | 
				
			||||||
 | 
					    finally:
 | 
				
			||||||
 | 
					        release_locks_if_sqlite(ses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rollback_tx():
 | 
					def rollback_tx():
 | 
				
			||||||
    """Rolls back previously started database transaction."""
 | 
					    """Rolls back previously started database transaction."""
 | 
				
			||||||
    ses = _get_thread_local_session()
 | 
					    ses = _get_thread_local_session()
 | 
				
			||||||
    if not ses:
 | 
					 | 
				
			||||||
        raise exc.DataAccessException("Nothing to roll back. Database"
 | 
					 | 
				
			||||||
                                      " transaction has not been started.")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ses.rollback()
 | 
					    if not ses:
 | 
				
			||||||
 | 
					        raise exc.DataAccessException(
 | 
				
			||||||
 | 
					            "Nothing to roll back. Database transaction has not been started."
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        ses.rollback()
 | 
				
			||||||
 | 
					    finally:
 | 
				
			||||||
 | 
					        release_locks_if_sqlite(ses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def end_tx():
 | 
					def end_tx():
 | 
				
			||||||
@@ -144,17 +161,24 @@ def end_tx():
 | 
				
			|||||||
        It rolls back all uncommitted changes and closes database session.
 | 
					        It rolls back all uncommitted changes and closes database session.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    ses = _get_thread_local_session()
 | 
					    ses = _get_thread_local_session()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not ses:
 | 
					    if not ses:
 | 
				
			||||||
        raise exc.DataAccessException("Database transaction has not been"
 | 
					        raise exc.DataAccessException(
 | 
				
			||||||
                                      " started.")
 | 
					            "Database transaction has not been started."
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if ses.dirty:
 | 
					    if ses.dirty:
 | 
				
			||||||
        ses.rollback()
 | 
					        rollback_tx()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ses.close()
 | 
					    ses.close()
 | 
				
			||||||
    _set_thread_local_session(None)
 | 
					    _set_thread_local_session(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@session_aware()
 | 
				
			||||||
 | 
					def get_driver_name(session=None):
 | 
				
			||||||
 | 
					    return session.bind.url.drivername
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@session_aware()
 | 
					@session_aware()
 | 
				
			||||||
def model_query(model, session=None):
 | 
					def model_query(model, session=None):
 | 
				
			||||||
    """Query helper.
 | 
					    """Query helper.
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										54
									
								
								mistral/db/sqlalchemy/sqlite_lock.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								mistral/db/sqlalchemy/sqlite_lock.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,54 @@
 | 
				
			|||||||
 | 
					# Copyright 2015 - Mirantis, Inc.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					#    you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					#    You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#        http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					#    distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					#    See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					#    limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from eventlet import semaphore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_mutex = semaphore.Semaphore()
 | 
				
			||||||
 | 
					_locks = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def acquire_lock(obj_id, session):
 | 
				
			||||||
 | 
					    with _mutex:
 | 
				
			||||||
 | 
					        if obj_id not in _locks:
 | 
				
			||||||
 | 
					            _locks[obj_id] = (session, semaphore.BoundedSemaphore(1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        tup = _locks.get(obj_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tup[1].acquire()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Make sure to update the dictionary once the lock is acquired
 | 
				
			||||||
 | 
					    # to adjust session ownership.
 | 
				
			||||||
 | 
					    _locks[obj_id] = (session, tup[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def release_locks(session):
 | 
				
			||||||
 | 
					    with _mutex:
 | 
				
			||||||
 | 
					        for obj_id, tup in _locks.items():
 | 
				
			||||||
 | 
					            if tup[0] is session:
 | 
				
			||||||
 | 
					                tup[1].release()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_locks():
 | 
				
			||||||
 | 
					    return _locks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cleanup():
 | 
				
			||||||
 | 
					    with _mutex:
 | 
				
			||||||
 | 
					        # NOTE: For the sake of simplicity we assume that we remove stale locks
 | 
				
			||||||
 | 
					        # after all tests because this kind of locking can only be used with
 | 
				
			||||||
 | 
					        # sqlite database. Supporting fully dynamically allocated (and removed)
 | 
				
			||||||
 | 
					        # locks is much more complex task. If this method is not called after
 | 
				
			||||||
 | 
					        # tests it will cause a memory leak.
 | 
				
			||||||
 | 
					        _locks.clear()
 | 
				
			||||||
@@ -60,6 +60,13 @@ def transaction():
 | 
				
			|||||||
        yield
 | 
					        yield
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Locking.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def acquire_lock(model, id):
 | 
				
			||||||
 | 
					    IMPL.acquire_lock(model, id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Workbooks.
 | 
					# Workbooks.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_workbook(name):
 | 
					def get_workbook(name):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,10 +18,12 @@ import sys
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from oslo.config import cfg
 | 
					from oslo.config import cfg
 | 
				
			||||||
from oslo.db import exception as db_exc
 | 
					from oslo.db import exception as db_exc
 | 
				
			||||||
 | 
					from oslo.utils import timeutils
 | 
				
			||||||
import sqlalchemy as sa
 | 
					import sqlalchemy as sa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mistral.db.sqlalchemy import base as b
 | 
					from mistral.db.sqlalchemy import base as b
 | 
				
			||||||
from mistral.db.sqlalchemy import model_base as mb
 | 
					from mistral.db.sqlalchemy import model_base as mb
 | 
				
			||||||
 | 
					from mistral.db.sqlalchemy import sqlite_lock
 | 
				
			||||||
from mistral.db.v2.sqlalchemy import models
 | 
					from mistral.db.v2.sqlalchemy import models
 | 
				
			||||||
from mistral import exceptions as exc
 | 
					from mistral import exceptions as exc
 | 
				
			||||||
from mistral.openstack.common import log as logging
 | 
					from mistral.openstack.common import log as logging
 | 
				
			||||||
@@ -85,6 +87,19 @@ def transaction():
 | 
				
			|||||||
        end_tx()
 | 
					        end_tx()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@b.session_aware()
 | 
				
			||||||
 | 
					def acquire_lock(model, id, session=None):
 | 
				
			||||||
 | 
					    if b.get_driver_name() != 'sqlite':
 | 
				
			||||||
 | 
					        query = _secure_query(model).filter("id = '%s'" % id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query.update(
 | 
				
			||||||
 | 
					            {'updated_at': timeutils.utcnow()},
 | 
				
			||||||
 | 
					            synchronize_session=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        sqlite_lock.acquire_lock(id, session)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _secure_query(model):
 | 
					def _secure_query(model):
 | 
				
			||||||
    query = b.model_query(model)
 | 
					    query = b.model_query(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,6 +17,7 @@ import copy
 | 
				
			|||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mistral.db.v2 import api as db_api
 | 
					from mistral.db.v2 import api as db_api
 | 
				
			||||||
 | 
					from mistral.db.v2.sqlalchemy import models as db_models
 | 
				
			||||||
from mistral.engine1 import base
 | 
					from mistral.engine1 import base
 | 
				
			||||||
from mistral.engine1 import task_handler
 | 
					from mistral.engine1 import task_handler
 | 
				
			||||||
from mistral.engine1 import utils
 | 
					from mistral.engine1 import utils
 | 
				
			||||||
@@ -88,9 +89,13 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
    def on_task_state_change(self, state, task_ex_id):
 | 
					    def on_task_state_change(self, state, task_ex_id):
 | 
				
			||||||
        with db_api.transaction():
 | 
					        with db_api.transaction():
 | 
				
			||||||
            task_ex = db_api.get_task_execution(task_ex_id)
 | 
					            task_ex = db_api.get_task_execution(task_ex_id)
 | 
				
			||||||
            wf_ex = db_api.get_workflow_execution(
 | 
					
 | 
				
			||||||
                task_ex.workflow_execution_id
 | 
					            wf_ex_id = task_ex.workflow_execution_id
 | 
				
			||||||
            )
 | 
					
 | 
				
			||||||
 | 
					            # Must be before loading the object itself (see method doc).
 | 
				
			||||||
 | 
					            self._lock_workflow_execution(wf_ex_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            wf_ex = db_api.get_workflow_execution(wf_ex_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            wf_trace.info(
 | 
					            wf_trace.info(
 | 
				
			||||||
                task_ex,
 | 
					                task_ex,
 | 
				
			||||||
@@ -120,28 +125,47 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            task_ex.processed = True
 | 
					            task_ex.processed = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if not cmds:
 | 
					            self._dispatch_workflow_commands(wf_ex, cmds)
 | 
				
			||||||
                if task_ex.state == states.SUCCESS:
 | 
					
 | 
				
			||||||
                    if not wf_utils.find_running_tasks(wf_ex):
 | 
					            self._check_workflow_completion(wf_ex, action_ex, wf_ctrl)
 | 
				
			||||||
                        wf_handler.succeed_workflow(
 | 
					
 | 
				
			||||||
                            wf_ex,
 | 
					    @staticmethod
 | 
				
			||||||
                            wf_ctrl.evaluate_workflow_final_context()
 | 
					    def _check_workflow_completion(wf_ex, action_ex, wf_ctrl):
 | 
				
			||||||
                        )
 | 
					        if states.is_paused_or_completed(wf_ex.state):
 | 
				
			||||||
                else:
 | 
					            return
 | 
				
			||||||
                    wf_handler.fail_workflow(wf_ex, task_ex, action_ex)
 | 
					
 | 
				
			||||||
            else:
 | 
					        if wf_utils.find_incomplete_tasks(wf_ex):
 | 
				
			||||||
                self._dispatch_workflow_commands(wf_ex, cmds)
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if wf_ctrl.all_errors_handled():
 | 
				
			||||||
 | 
					            wf_handler.succeed_workflow(
 | 
				
			||||||
 | 
					                wf_ex,
 | 
				
			||||||
 | 
					                wf_ctrl.evaluate_workflow_final_context()
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            result_str = str(action_ex.output.get('result', "Unknown"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            state_info = (
 | 
				
			||||||
 | 
					                "Failure caused by error in task '%s': %s" %
 | 
				
			||||||
 | 
					                (action_ex.task_execution.name, result_str)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            wf_handler.fail_workflow(wf_ex, state_info)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @u.log_exec(LOG)
 | 
					    @u.log_exec(LOG)
 | 
				
			||||||
    def on_action_complete(self, action_ex_id, result):
 | 
					    def on_action_complete(self, action_ex_id, result):
 | 
				
			||||||
        wf_exec_id = None
 | 
					        wf_ex_id = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            with db_api.transaction():
 | 
					            with db_api.transaction():
 | 
				
			||||||
                action_ex = db_api.get_action_execution(action_ex_id)
 | 
					                action_ex = db_api.get_action_execution(action_ex_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                wf_ex_id = action_ex.task_execution.workflow_execution_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # Must be before loading the object itself (see method doc).
 | 
				
			||||||
 | 
					                self._lock_workflow_execution(wf_ex_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                wf_ex = action_ex.task_execution.workflow_execution
 | 
					                wf_ex = action_ex.task_execution.workflow_execution
 | 
				
			||||||
                wf_exec_id = wf_ex.id
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                task_ex = task_handler.on_action_complete(action_ex, result)
 | 
					                task_ex = task_handler.on_action_complete(action_ex, result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -160,12 +184,15 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
                "Failed to handle action execution result [id=%s]: %s\n%s",
 | 
					                "Failed to handle action execution result [id=%s]: %s\n%s",
 | 
				
			||||||
                action_ex_id, e, traceback.format_exc()
 | 
					                action_ex_id, e, traceback.format_exc()
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            self._fail_workflow(wf_exec_id, e)
 | 
					            self._fail_workflow(wf_ex_id, e)
 | 
				
			||||||
            raise e
 | 
					            raise e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @u.log_exec(LOG)
 | 
					    @u.log_exec(LOG)
 | 
				
			||||||
    def pause_workflow(self, execution_id):
 | 
					    def pause_workflow(self, execution_id):
 | 
				
			||||||
        with db_api.transaction():
 | 
					        with db_api.transaction():
 | 
				
			||||||
 | 
					            # Must be before loading the object itself (see method doc).
 | 
				
			||||||
 | 
					            self._lock_workflow_execution(execution_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            wf_ex = db_api.get_workflow_execution(execution_id)
 | 
					            wf_ex = db_api.get_workflow_execution(execution_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            wf_handler.set_execution_state(wf_ex, states.PAUSED)
 | 
					            wf_handler.set_execution_state(wf_ex, states.PAUSED)
 | 
				
			||||||
@@ -176,6 +203,9 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
    def resume_workflow(self, execution_id):
 | 
					    def resume_workflow(self, execution_id):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            with db_api.transaction():
 | 
					            with db_api.transaction():
 | 
				
			||||||
 | 
					                # Must be before loading the object itself (see method doc).
 | 
				
			||||||
 | 
					                self._lock_workflow_execution(execution_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                wf_ex = db_api.get_workflow_execution(execution_id)
 | 
					                wf_ex = db_api.get_workflow_execution(execution_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if wf_ex.state != states.PAUSED:
 | 
					                if wf_ex.state != states.PAUSED:
 | 
				
			||||||
@@ -204,14 +234,14 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
                    if states.is_completed(t_ex.state) and not t_ex.processed:
 | 
					                    if states.is_completed(t_ex.state) and not t_ex.processed:
 | 
				
			||||||
                        t_ex.processed = True
 | 
					                        t_ex.processed = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                self._dispatch_workflow_commands(wf_ex, cmds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if not cmds:
 | 
					                if not cmds:
 | 
				
			||||||
                    if not wf_utils.find_running_tasks(wf_ex):
 | 
					                    if not wf_utils.find_incomplete_tasks(wf_ex):
 | 
				
			||||||
                        wf_handler.succeed_workflow(
 | 
					                        wf_handler.succeed_workflow(
 | 
				
			||||||
                            wf_ex,
 | 
					                            wf_ex,
 | 
				
			||||||
                            wf_ctrl.evaluate_workflow_final_context()
 | 
					                            wf_ctrl.evaluate_workflow_final_context()
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    self._dispatch_workflow_commands(wf_ex, cmds)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                return wf_ex
 | 
					                return wf_ex
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
@@ -225,6 +255,9 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
    @u.log_exec(LOG)
 | 
					    @u.log_exec(LOG)
 | 
				
			||||||
    def stop_workflow(self, execution_id, state, message=None):
 | 
					    def stop_workflow(self, execution_id, state, message=None):
 | 
				
			||||||
        with db_api.transaction():
 | 
					        with db_api.transaction():
 | 
				
			||||||
 | 
					            # Must be before loading the object itself (see method doc).
 | 
				
			||||||
 | 
					            self._lock_workflow_execution(execution_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            wf_ex = db_api.get_execution(execution_id)
 | 
					            wf_ex = db_api.get_execution(execution_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            wf_handler.set_execution_state(wf_ex, state, message)
 | 
					            wf_handler.set_execution_state(wf_ex, state, message)
 | 
				
			||||||
@@ -280,7 +313,6 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
                action_ex = db_api.get_action_execution(action_ex_id)
 | 
					                action_ex = db_api.get_action_execution(action_ex_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                task_handler.on_action_complete(
 | 
					                task_handler.on_action_complete(
 | 
				
			||||||
                    wf_ex,
 | 
					 | 
				
			||||||
                    action_ex,
 | 
					                    action_ex,
 | 
				
			||||||
                    wf_utils.Result(error=err_msg)
 | 
					                    wf_utils.Result(error=err_msg)
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
@@ -321,3 +353,12 @@ class DefaultEngine(base.Engine):
 | 
				
			|||||||
        data_flow.add_environment_to_context(wf_ex, wf_ex.context)
 | 
					        data_flow.add_environment_to_context(wf_ex, wf_ex.context)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return wf_ex
 | 
					        return wf_ex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def _lock_workflow_execution(wf_exec_id):
 | 
				
			||||||
 | 
					        # NOTE: Workflow execution object must be locked before
 | 
				
			||||||
 | 
					        # loading the object itself into the session (either with
 | 
				
			||||||
 | 
					        # 'get_XXX' or 'load_XXX' methods). Otherwise, there can be
 | 
				
			||||||
 | 
					        # multiple parallel transactions that see the same state
 | 
				
			||||||
 | 
					        # and hence the rest of the method logic would not be atomic.
 | 
				
			||||||
 | 
					        db_api.acquire_lock(db_models.WorkflowExecution, wf_exec_id)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -144,6 +144,10 @@ def _create_task_execution(wf_ex, task_spec, ctx):
 | 
				
			|||||||
        'project_id': wf_ex.project_id
 | 
					        'project_id': wf_ex.project_id
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Add to collection explicitly so that it's in a proper
 | 
				
			||||||
 | 
					    # state within the current session.
 | 
				
			||||||
 | 
					    wf_ex.task_executions.append(task_ex)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO(rakhmerov): May be it shouldn't be here. Need to think.
 | 
					    # TODO(rakhmerov): May be it shouldn't be here. Need to think.
 | 
				
			||||||
    if task_spec.get_with_items():
 | 
					    if task_spec.get_with_items():
 | 
				
			||||||
        with_items.prepare_runtime_context(task_ex, task_spec)
 | 
					        with_items.prepare_runtime_context(task_ex, task_spec)
 | 
				
			||||||
@@ -152,15 +156,21 @@ def _create_task_execution(wf_ex, task_spec, ctx):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _create_action_execution(task_ex, action_def, action_input):
 | 
					def _create_action_execution(task_ex, action_def, action_input):
 | 
				
			||||||
    return db_api.create_action_execution({
 | 
					    action_ex = db_api.create_action_execution({
 | 
				
			||||||
        'name': action_def.name,
 | 
					        'name': action_def.name,
 | 
				
			||||||
        'task_execution_id': task_ex.id,
 | 
					        'task_execution_id': task_ex.id,
 | 
				
			||||||
        'workflow_name': task_ex.workflow_name,
 | 
					        'workflow_name': task_ex.workflow_name,
 | 
				
			||||||
        'spec': action_def.spec,
 | 
					        'spec': action_def.spec,
 | 
				
			||||||
        'project_id': task_ex.project_id,
 | 
					        'project_id': task_ex.project_id,
 | 
				
			||||||
        'state': states.RUNNING,
 | 
					        'state': states.RUNNING,
 | 
				
			||||||
        'input': action_input
 | 
					        'input': action_input}
 | 
				
			||||||
    })
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Add to collection explicitly so that it's in a proper
 | 
				
			||||||
 | 
					    # state within the current session.
 | 
				
			||||||
 | 
					    task_ex.executions.append(action_ex)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return action_ex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _before_task_start(task_ex, task_spec, wf_spec):
 | 
					def _before_task_start(task_ex, task_spec, wf_spec):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -34,22 +34,11 @@ def succeed_workflow(wf_ex, final_context):
 | 
				
			|||||||
        _schedule_send_result_to_parent_workflow(wf_ex)
 | 
					        _schedule_send_result_to_parent_workflow(wf_ex)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def fail_workflow(wf_ex, task_ex, action_ex):
 | 
					def fail_workflow(wf_ex, state_info):
 | 
				
			||||||
    if states.is_paused_or_completed(wf_ex.state):
 | 
					    if states.is_paused_or_completed(wf_ex.state):
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO(rakhmerov): How do we pass task result correctly?.
 | 
					    set_execution_state(wf_ex, states.ERROR, state_info)
 | 
				
			||||||
    if action_ex:
 | 
					 | 
				
			||||||
        msg = str(action_ex.output.get('result', "Unknown"))
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        msg = "Unknown"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    set_execution_state(
 | 
					 | 
				
			||||||
        wf_ex,
 | 
					 | 
				
			||||||
        states.ERROR,
 | 
					 | 
				
			||||||
        "Failure caused by error in task '%s': %s"
 | 
					 | 
				
			||||||
        % (task_ex.name, msg)
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if wf_ex.task_execution_id:
 | 
					    if wf_ex.task_execution_id:
 | 
				
			||||||
        _schedule_send_result_to_parent_workflow(wf_ex)
 | 
					        _schedule_send_result_to_parent_workflow(wf_ex)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -27,6 +27,7 @@ import testtools.matchers as ttm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from mistral import context as auth_context
 | 
					from mistral import context as auth_context
 | 
				
			||||||
from mistral.db.sqlalchemy import base as db_sa_base
 | 
					from mistral.db.sqlalchemy import base as db_sa_base
 | 
				
			||||||
 | 
					from mistral.db.sqlalchemy import sqlite_lock
 | 
				
			||||||
from mistral.db.v1 import api as db_api_v1
 | 
					from mistral.db.v1 import api as db_api_v1
 | 
				
			||||||
from mistral.db.v2 import api as db_api_v2
 | 
					from mistral.db.v2 import api as db_api_v2
 | 
				
			||||||
from mistral import engine
 | 
					from mistral import engine
 | 
				
			||||||
@@ -215,6 +216,8 @@ class DbTestCase(BaseTest):
 | 
				
			|||||||
            db_api_v2.delete_cron_triggers()
 | 
					            db_api_v2.delete_cron_triggers()
 | 
				
			||||||
            db_api_v2.delete_workflow_definitions()
 | 
					            db_api_v2.delete_workflow_definitions()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sqlite_lock.cleanup()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setUp(self):
 | 
					    def setUp(self):
 | 
				
			||||||
        super(DbTestCase, self).setUp()
 | 
					        super(DbTestCase, self).setUp()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										164
									
								
								mistral/tests/unit/db/v2/test_locking.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								mistral/tests/unit/db/v2/test_locking.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,164 @@
 | 
				
			|||||||
 | 
					# Copyright 2015 - Mirantis, Inc.
 | 
				
			||||||
 | 
					# Copyright 2015 - StackStorm, Inc.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					#    you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					#    You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#        http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					#    distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					#    See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					#    limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import eventlet
 | 
				
			||||||
 | 
					from oslo.config import cfg
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mistral.db.sqlalchemy import sqlite_lock
 | 
				
			||||||
 | 
					from mistral.db.v2.sqlalchemy import api as db_api
 | 
				
			||||||
 | 
					from mistral.db.v2.sqlalchemy import models as db_models
 | 
				
			||||||
 | 
					from mistral.tests import base as test_base
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					WF_EXEC = {
 | 
				
			||||||
 | 
					    'name': '1',
 | 
				
			||||||
 | 
					    'spec': {},
 | 
				
			||||||
 | 
					    'start_params': {},
 | 
				
			||||||
 | 
					    'state': 'RUNNING',
 | 
				
			||||||
 | 
					    'state_info': "Running...",
 | 
				
			||||||
 | 
					    'created_at': None,
 | 
				
			||||||
 | 
					    'updated_at': None,
 | 
				
			||||||
 | 
					    'context': None,
 | 
				
			||||||
 | 
					    'task_id': None,
 | 
				
			||||||
 | 
					    'trust_id': None
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WorkflowExecutionTest(test_base.DbTestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        super(WorkflowExecutionTest, self).setUp()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cfg.CONF.set_default('auth_enable', True, group='pecan')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.addCleanup(
 | 
				
			||||||
 | 
					            cfg.CONF.set_default,
 | 
				
			||||||
 | 
					            'auth_enable',
 | 
				
			||||||
 | 
					            False,
 | 
				
			||||||
 | 
					            group='pecan'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _random_sleep(self):
 | 
				
			||||||
 | 
					        eventlet.sleep(random.Random().randint(0, 10) * 0.001)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _run_acquire_release_sqlite_lock(self, obj_id, session):
 | 
				
			||||||
 | 
					        self._random_sleep()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sqlite_lock.acquire_lock(obj_id, session)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._random_sleep()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sqlite_lock.release_locks(session)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_acquire_release_sqlite_lock(self):
 | 
				
			||||||
 | 
					        threads = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        id = "object_id"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        number = 500
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i in range(1, number):
 | 
				
			||||||
 | 
					            threads.append(
 | 
				
			||||||
 | 
					                eventlet.spawn(self._run_acquire_release_sqlite_lock, id, i)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        [t.wait() for t in threads]
 | 
				
			||||||
 | 
					        [t.kill() for t in threads]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(1, len(sqlite_lock.get_locks()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sqlite_lock.cleanup()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(0, len(sqlite_lock.get_locks()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _run_correct_locking(self, wf_ex):
 | 
				
			||||||
 | 
					        self._random_sleep()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with db_api.transaction():
 | 
				
			||||||
 | 
					            # Here we lock the object before it gets loaded into the
 | 
				
			||||||
 | 
					            # session and prevent reading the same object state by
 | 
				
			||||||
 | 
					            # multiple transactions. Hence the rest of the transaction
 | 
				
			||||||
 | 
					            # body works atomically (in a serialized manner) and the
 | 
				
			||||||
 | 
					            # result (object name) must be equal to a number of
 | 
				
			||||||
 | 
					            # transactions.
 | 
				
			||||||
 | 
					            db_api.acquire_lock(db_models.WorkflowExecution, wf_ex.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Refresh the object.
 | 
				
			||||||
 | 
					            wf_ex = db_api.get_workflow_execution(wf_ex.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            wf_ex.name = str(int(wf_ex.name) + 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return wf_ex.name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_correct_locking(self):
 | 
				
			||||||
 | 
					        wf_ex = db_api.create_workflow_execution(WF_EXEC)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        threads = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        number = 500
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i in range(1, number):
 | 
				
			||||||
 | 
					            threads.append(
 | 
				
			||||||
 | 
					                eventlet.spawn(self._run_correct_locking, wf_ex)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        [t.wait() for t in threads]
 | 
				
			||||||
 | 
					        [t.kill() for t in threads]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        wf_ex = db_api.get_workflow_execution(wf_ex.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        print("Correct locking test gave object name: %s" % wf_ex.name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(str(number), wf_ex.name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _run_invalid_locking(self, wf_ex):
 | 
				
			||||||
 | 
					        self._random_sleep()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with db_api.transaction():
 | 
				
			||||||
 | 
					            # Load object into the session (transaction).
 | 
				
			||||||
 | 
					            wf_ex = db_api.get_workflow_execution(wf_ex.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # It's too late to lock the object here because it's already
 | 
				
			||||||
 | 
					            # been loaded into the session so there should be multiple
 | 
				
			||||||
 | 
					            # threads that read the same object state so they write the
 | 
				
			||||||
 | 
					            # same value into DB. As a result we won't get a result
 | 
				
			||||||
 | 
					            # (object name) equal to a number of transactions.
 | 
				
			||||||
 | 
					            db_api.acquire_lock(db_models.WorkflowExecution, wf_ex.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            wf_ex.name = str(int(wf_ex.name) + 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return wf_ex.name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_locking(self):
 | 
				
			||||||
 | 
					        wf_ex = db_api.create_workflow_execution(WF_EXEC)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        threads = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        number = 500
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i in range(1, number):
 | 
				
			||||||
 | 
					            threads.append(
 | 
				
			||||||
 | 
					                eventlet.spawn(self._run_invalid_locking, wf_ex)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        [t.wait() for t in threads]
 | 
				
			||||||
 | 
					        [t.kill() for t in threads]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        wf_ex = db_api.get_workflow_execution(wf_ex.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        print("Invalid locking test gave object name: %s" % wf_ex.name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertNotEqual(str(number), wf_ex.name)
 | 
				
			||||||
@@ -57,6 +57,15 @@ class WorkflowController(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        return self._find_next_commands()
 | 
					        return self._find_next_commands()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abc.abstractmethod
 | 
				
			||||||
 | 
					    def all_errors_handled(self):
 | 
				
			||||||
 | 
					        """Determines if all errors (if any) are handled.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :return: True if either there aren't errors at all or all
 | 
				
			||||||
 | 
					            errors are considered handled.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abc.abstractmethod
 | 
					    @abc.abstractmethod
 | 
				
			||||||
    def evaluate_workflow_final_context(self):
 | 
					    def evaluate_workflow_final_context(self):
 | 
				
			||||||
        """Evaluates final workflow context assuming that workflow has finished.
 | 
					        """Evaluates final workflow context assuming that workflow has finished.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -173,6 +173,13 @@ class DirectWorkflowController(base.WorkflowController):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        return ctx
 | 
					        return ctx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def all_errors_handled(self):
 | 
				
			||||||
 | 
					        for t_ex in wf_utils.find_error_tasks(self.wf_ex):
 | 
				
			||||||
 | 
					            if not self.get_on_error_clause(t_ex.name):
 | 
				
			||||||
 | 
					                return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _find_end_tasks(self):
 | 
					    def _find_end_tasks(self):
 | 
				
			||||||
        return filter(
 | 
					        return filter(
 | 
				
			||||||
            lambda t_db: not self._has_outbound_tasks(t_db),
 | 
					            lambda t_db: not self._has_outbound_tasks(t_db),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -87,6 +87,9 @@ class ReverseWorkflowController(base.WorkflowController):
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def all_errors_handled(self):
 | 
				
			||||||
 | 
					        return len(wf_utils.find_error_tasks(self.wf_ex)) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _find_task_specs_with_satisfied_dependencies(self):
 | 
					    def _find_task_specs_with_satisfied_dependencies(self):
 | 
				
			||||||
        """Given a target task name finds tasks with no dependencies.
 | 
					        """Given a target task name finds tasks with no dependencies.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user