Implementing task execution infrastructure

* Implemented action calls
* Implemented subworkflow calls
* Adding executor interface and default implementation
* Unit tests
* Small refactoring in workflow handlers
* Creating all necessary RPC infrastructure
* Refactoring launch script
* Added __repr__() implementation to MistralContext
* Small fixes in old infrastructure

Change-Id: I134ea526c295ca9bda7214c5403a41966062ff79
This commit is contained in:
Renat Akhmerov 2014-08-19 12:19:35 +07:00
parent 80db103871
commit bc85fbdac8
24 changed files with 971 additions and 314 deletions

View File

@ -40,7 +40,7 @@ from wsgiref import simple_server
from mistral.api import app
from mistral import config
from mistral import context
from mistral import context as ctx
from mistral import engine
from mistral.engine import executor
from mistral.openstack.common import log as logging
@ -50,27 +50,43 @@ LOG = logging.getLogger(__name__)
def launch_executor(transport):
serializer = context.RpcContextSerializer(context.JsonPayloadSerializer())
target = messaging.Target(topic=cfg.CONF.executor.topic,
server=cfg.CONF.executor.host)
target = messaging.Target(
topic=cfg.CONF.executor.topic,
server=cfg.CONF.executor.host
)
# Since engine and executor are tightly coupled, use the engine
# configuration to decide which executor to get.
endpoints = [executor.get_executor(cfg.CONF.engine.engine, transport)]
server = messaging.get_rpc_server(
transport, target, endpoints, executor='eventlet',
serializer=serializer)
transport,
target,
endpoints,
executor='eventlet',
serializer=ctx.RpcContextSerializer(ctx.JsonPayloadSerializer())
)
server.start()
server.wait()
def launch_engine(transport):
serializer = context.RpcContextSerializer(context.JsonPayloadSerializer())
target = messaging.Target(topic=cfg.CONF.engine.topic,
server=cfg.CONF.engine.host)
target = messaging.Target(
topic=cfg.CONF.engine.topic,
server=cfg.CONF.engine.host
)
endpoints = [engine.get_engine(cfg.CONF.engine.engine, transport)]
server = messaging.get_rpc_server(
transport, target, endpoints, executor='eventlet',
serializer=serializer)
transport,
target,
endpoints,
executor='eventlet',
serializer=ctx.RpcContextSerializer(ctx.JsonPayloadSerializer())
)
server.start()
server.wait()
@ -78,10 +94,16 @@ def launch_engine(transport):
def launch_api(transport):
host = cfg.CONF.api.host
port = cfg.CONF.api.port
server = simple_server.make_server(host, port,
app.setup_app(transport=transport))
server = simple_server.make_server(
host,
port,
app.setup_app(transport=transport)
)
LOG.info("Mistral API is serving on http://%s:%s (PID=%s)" %
(host, port, os.getpid()))
server.serve_forever()
@ -89,6 +111,7 @@ def launch_any(transport, options):
# Launch the servers on different threads.
threads = [eventlet.spawn(LAUNCH_OPTIONS[option], transport)
for option in options]
[thread.wait() for thread in threads]
@ -123,7 +146,7 @@ def main():
# processes because the "fake" transport is using an in process queue.
transport = messaging.get_transport(cfg.CONF)
if (cfg.CONF.server == ['all']):
if cfg.CONF.server == ['all']:
# Launch all servers.
launch_any(transport, LAUNCH_OPTIONS.keys())
else:

View File

@ -23,24 +23,19 @@ from oslo.config import cfg
from mistral.openstack.common import log
from mistral import version
launch_opt = cfg.ListOpt(
'server',
default=['all'],
help='Specifies which mistral server to start by the launch script. '
'Valid options are all or any combination of '
'api, engine, and executor.'
)
api_opts = [
cfg.StrOpt('host', default='0.0.0.0', help='Mistral API server host'),
cfg.IntOpt('port', default=8989, help='Mistral API server port')
]
engine_opts = [
cfg.StrOpt('engine', default='default',
help='Mistral engine plugin'),
cfg.StrOpt('host', default='0.0.0.0',
help='Name of the engine node. This can be an opaque '
'identifier. It is not necessarily a hostname, '
'FQDN, or IP address.'),
cfg.StrOpt('topic', default='engine',
help='The message topic that the engine listens on.'),
cfg.StrOpt('version', default='1.0',
help='The version of the engine.')
]
pecan_opts = [
cfg.StrOpt('root', default='mistral.api.controllers.root.RootController',
help='Pecan root controller'),
@ -68,6 +63,19 @@ use_debugger = cfg.BoolOpt(
'Use at your own risk.'
)
engine_opts = [
cfg.StrOpt('engine', default='default',
help='Mistral engine plugin'),
cfg.StrOpt('host', default='0.0.0.0',
help='Name of the engine node. This can be an opaque '
'identifier. It is not necessarily a hostname, '
'FQDN, or IP address.'),
cfg.StrOpt('topic', default='engine',
help='The message topic that the engine listens on.'),
cfg.StrOpt('version', default='1.0',
help='The version of the engine.')
]
executor_opts = [
cfg.StrOpt('host', default='0.0.0.0',
help='Name of the executor node. This can be an opaque '
@ -79,14 +87,6 @@ executor_opts = [
help='The version of the executor.')
]
launch_opt = cfg.ListOpt(
'server',
default=['all'],
help='Specifies which mistral server to start by the launch script. '
'Valid options are all or any combination of '
'api, engine, and executor.'
)
wf_trace_log_name_opt = cfg.StrOpt('workflow_trace_log_name',
default='workflow_trace',
help='Logger name for pretty '

View File

@ -75,6 +75,9 @@ class MistralContext(BaseContext):
"is_admin",
])
def __repr__(self):
return "MistralContext %s" % self.to_dict()
def has_ctx():
return utils.has_thread_local(_CTX_THREAD_LOCAL_NAME)

View File

@ -51,7 +51,11 @@ class Execution(mb.MistralModelBase):
start_params = sa.Column(st.JsonDictType())
state = sa.Column(sa.String(20))
input = sa.Column(st.JsonDictType())
output = sa.Column(st.JsonDictType())
context = sa.Column(st.JsonDictType())
# Can't use ForeignKey constraint here because SqlAlchemy will detect
# a circular dependency and raise an error.
parent_task_id = sa.Column(sa.String(36))
class Task(mb.MistralModelBase):

View File

@ -65,7 +65,7 @@ class Engine(object):
self.transport = get_transport(transport)
@abc.abstractmethod
def _run_task(cls, task_id, action_name, action_params):
def _run_task(self, task_id, action_name, action_params):
raise NotImplementedError()
def start_workflow_execution(self, cntx, **kwargs):

View File

@ -36,6 +36,7 @@ class DefaultEngine(engine.Engine):
exctr = executor.ExecutorClient(self.transport)
LOG.info("Submitted task for execution: '%s'" % task_id)
exctr.handle_task(auth_context.ctx(),
task_id=task_id,
action_name=action_name,

View File

@ -76,6 +76,21 @@ class Engine(object):
raise NotImplementedError
@six.add_metaclass(abc.ABCMeta)
class Executor(object):
"""Action executor interface."""
@abc.abstractmethod
def run_action(self, task_id, action_name, action_params):
"""Runs action.
:param task_id: Corresponding task id.
:param action_name: Action name.
:param action_params: Action parameters.
"""
raise NotImplementedError()
@six.add_metaclass(abc.ABCMeta)
class WorkflowPolicy(object):
"""Workflow policy.

View File

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
#
# Copyright 2014 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -19,9 +17,9 @@ from oslo.config import cfg
from mistral.db.v2 import api as db_api
from mistral.engine1 import base
from mistral import exceptions as exc
from mistral.openstack.common import log as logging
from mistral.workbook import parser as spec_parser
from mistral.workflow import base as wf_base
from mistral.workflow import data_flow
from mistral.workflow import states
from mistral.workflow import workflow_handler_factory as wfh_factory
@ -39,101 +37,11 @@ WF_TRACE = logging.getLogger(cfg.CONF.workflow_trace_log_name)
# TODO(rakhmerov): Add necessary logging including WF_TRACE.
def _apply_task_policies(task_db):
# TODO(rakhmerov): Implement.
pass
def _apply_workflow_policies(exec_db, task_db):
# TODO(rakhmerov): Implement.
pass
def _create_db_execution(wb_db, wf_spec, input, start_params):
exec_db = db_api.create_execution({
'wf_spec': wf_spec.to_dict(),
'start_params': start_params,
'state': states.RUNNING,
'input': input,
'context': copy.copy(input) or {}
})
data_flow.add_openstack_data_to_context(wb_db, exec_db.context)
data_flow.add_execution_to_context(exec_db, exec_db.context)
return exec_db
def _create_db_tasks(exec_db, task_specs):
new_db_tasks = []
for task_spec in task_specs:
t = db_api.create_task({
'execution_id': exec_db.id,
'name': task_spec.get_name(),
'state': states.RUNNING,
'spec': task_spec.to_dict(),
'parameters': None,
'in_context': None,
'output': None,
'runtime_context': None
})
new_db_tasks.append(t)
return new_db_tasks
def _prepare_db_tasks(task_specs, exec_db, wf_handler):
wf_spec = spec_parser.get_workflow_spec(exec_db.wf_spec)
new_db_tasks = _create_db_tasks(exec_db, task_specs)
# Evaluate Data Flow properties ('parameters', 'in_context').
for t_db in new_db_tasks:
task_spec = wf_spec.get_tasks()[t_db.name]
data_flow.prepare_db_task(
t_db,
task_spec,
wf_handler.get_upstream_tasks(task_spec),
exec_db
)
def _run_tasks(task_specs):
for t in task_specs:
if t.get_action_name():
_run_action(t)
elif t.get_workflow_name():
_run_workflow(t)
else:
msg = "Neither 'action' nor 'workflow' is defined in task" \
" specification [task_spec=%s]" % t
raise exc.WorkflowException(msg)
def _run_action(t):
# TODO(rakhmerov): Implement.
pass
def _run_workflow(t):
# TODO(rakhmerov): Implement.
pass
def _process_task_specs(task_specs, exec_db, wf_handler):
LOG.debug('Processing workflow tasks: %s' % task_specs)
# DB tasks & Data Flow properties
_prepare_db_tasks(task_specs, exec_db, wf_handler)
# Running actions/workflows.
_run_tasks(task_specs)
class DefaultEngine(base.Engine):
def __init__(self, engine_client, executor_client):
self._engine_client = engine_client
self._executor_client = executor_client
def start_workflow(self, workbook_name, workflow_name, input, **params):
db_api.start_tx()
@ -144,7 +52,7 @@ class DefaultEngine(base.Engine):
spec_parser.get_workbook_spec_from_yaml(wb_db.definition)
wf_spec = wb_spec.get_workflows()[workflow_name]
exec_db = _create_db_execution(wb_db, wf_spec, input, params)
exec_db = self._create_db_execution(wb_db, wf_spec, input, params)
wf_handler = wfh_factory.create_workflow_handler(exec_db, wf_spec)
@ -152,7 +60,7 @@ class DefaultEngine(base.Engine):
task_specs = wf_handler.start_workflow(**params)
if task_specs:
_process_task_specs(task_specs, exec_db, wf_handler)
self._process_task_specs(task_specs, exec_db, wf_handler)
db_api.commit_tx()
finally:
@ -173,10 +81,13 @@ class DefaultEngine(base.Engine):
task_specs = wf_handler.on_task_result(task_db, raw_result)
if task_specs:
_apply_task_policies(task_db)
_apply_workflow_policies(exec_db, task_db)
self._apply_task_policies(task_db)
self._apply_workflow_policies(exec_db, task_db)
_process_task_specs(task_specs, exec_db, wf_handler)
self._process_task_specs(task_specs, exec_db, wf_handler)
if exec_db.state == states.SUCCESS and exec_db.parent_task_id:
self._process_subworkflow_output(exec_db)
db_api.commit_tx()
finally:
@ -212,7 +123,7 @@ class DefaultEngine(base.Engine):
task_specs = wf_handler.resume_workflow()
if task_specs:
_process_task_specs(task_specs, exec_db, wf_handler)
self._process_task_specs(task_specs, exec_db, wf_handler)
db_api.commit_tx()
finally:
@ -223,3 +134,111 @@ class DefaultEngine(base.Engine):
def rollback_workflow(self, execution_id):
# TODO(rakhmerov): Implement.
raise NotImplementedError
def _apply_task_policies(self, task_db):
# TODO(rakhmerov): Implement.
pass
def _apply_workflow_policies(self, exec_db, task_db):
# TODO(rakhmerov): Implement.
pass
def _process_task_specs(self, task_specs, exec_db, wf_handler):
LOG.debug('Processing workflow tasks: %s' % task_specs)
# DB tasks & Data Flow properties
db_tasks = self._prepare_db_tasks(task_specs, exec_db, wf_handler)
# Running actions/workflows.
self._run_tasks(db_tasks, task_specs)
def _prepare_db_tasks(self, task_specs, exec_db, wf_handler):
wf_spec = spec_parser.get_workflow_spec(exec_db.wf_spec)
new_db_tasks = self._create_db_tasks(exec_db, task_specs)
# Evaluate Data Flow properties ('parameters', 'in_context').
for t_db in new_db_tasks:
task_spec = wf_spec.get_tasks()[t_db.name]
data_flow.prepare_db_task(
t_db,
task_spec,
wf_handler.get_upstream_tasks(task_spec),
exec_db
)
return new_db_tasks
def _create_db_execution(self, wb_db, wf_spec, input, params):
exec_db = db_api.create_execution({
'wf_spec': wf_spec.to_dict(),
'start_params': params or {},
'state': states.RUNNING,
'input': input or {},
'output': {},
'context': copy.copy(input) or {},
'parent_task_id': params.get('parent_task_id')
})
data_flow.add_openstack_data_to_context(wb_db, exec_db.context)
data_flow.add_execution_to_context(exec_db, exec_db.context)
return exec_db
def _create_db_tasks(self, exec_db, task_specs):
new_db_tasks = []
for task_spec in task_specs:
t = db_api.create_task({
'execution_id': exec_db.id,
'name': task_spec.get_name(),
'state': states.RUNNING,
'spec': task_spec.to_dict(),
'parameters': None,
'in_context': None,
'output': None,
'runtime_context': None
})
new_db_tasks.append(t)
return new_db_tasks
def _run_tasks(self, db_tasks, task_specs):
for t_db, t_spec in zip(db_tasks, task_specs):
if t_spec.get_action_name():
self._run_action(t_db, t_spec)
elif t_spec.get_workflow_name():
self._run_workflow(t_db, t_spec)
def _run_action(self, task_db, task_spec):
# TODO(rakhmerov): Take care of ad-hoc actions.
action_name = task_spec.get_action_name()
self._executor_client.run_action(
task_db.id,
action_name,
task_db.parameters or {}
)
def _run_workflow(self, task_db, task_spec):
wb_name = task_spec.get_workflow_namespace()
wf_name = task_spec.get_short_workflow_name()
wf_input = task_db.parameters
start_params = copy.copy(task_spec.get_workflow_parameters())
start_params.update({'parent_task_id': task_db.id})
self._engine_client.start_workflow(
wb_name,
wf_name,
wf_input,
**start_params
)
def _process_subworkflow_output(self, exec_db):
self._engine_client.on_task_result(
exec_db.parent_task_id,
wf_base.TaskResult(data=exec_db.output)
)

View File

@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from oslo.config import cfg
from mistral.actions import action_factory as a_f
from mistral.engine1 import base
from mistral import exceptions as exc
from mistral.openstack.common import log as logging
from mistral.workflow import base as wf_base
LOG = logging.getLogger(__name__)
WORKFLOW_TRACE = logging.getLogger(cfg.CONF.workflow_trace_log_name)
class DefaultExecutor(base.Executor):
def __init__(self, engine_client):
self._engine_client = engine_client
def run_action(self, task_id, action_name, action_params):
"""Runs action.
:param task_id: Corresponding task id.
:param action_name: Action name.
:param action_params: Action parameters.
"""
action_cls = a_f.get_action_class(action_name)
try:
action = action_cls(**action_params)
result = action.run()
if action.is_sync():
self._engine_client.on_task_result(
task_id,
wf_base.TaskResult(data=result)
)
except exc.ActionException as e:
LOG.exception(
"Failed to run action [task_id=%s, action_cls='%s',"
" params='%s']\n %s" %
(task_id, action_cls, action_params, e)
)
self._engine_client.on_task_result(
task_id,
wf_base.TaskResult(error=str(e))
)

View File

@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-
#
# Copyright 2014 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -14,58 +12,195 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from oslo.config import cfg
from oslo import messaging
from mistral import context as auth_ctx
from mistral.engine1 import base
from mistral.engine1 import default_engine as def_eng
from mistral.engine1 import default_executor as def_executor
from mistral.openstack.common import log as logging
from mistral.workflow import base as wf_base
LOG = logging.getLogger(__name__)
# TODO(rakhmerov): Add engine and executor servers so that we don't need to
# adopt them to work with rpc (taking care about transport, signatures etc.).
_TRANSPORT = None
class EngineClient(object):
"""RPC client for the Engine."""
_ENGINE_SERVER = None
_ENGINE_CLIENT = None
_EXECUTOR_SERVER = None
_EXECUTOR_CLIENT = None
def get_transport():
global _TRANSPORT
return _TRANSPORT if _TRANSPORT \
else messaging.get_transport(cfg.CONF)
def get_engine_server():
global _ENGINE_SERVER
if not _ENGINE_SERVER:
# TODO(rakhmerov): It should be configurable.
_ENGINE_SERVER = EngineServer(
def_eng.DefaultEngine(get_engine_client(), get_executor_client())
)
return _ENGINE_SERVER
def get_engine_client():
global _ENGINE_CLIENT
if not _ENGINE_CLIENT:
_ENGINE_CLIENT = EngineClient(get_transport())
return _ENGINE_CLIENT
def get_executor_server():
global _EXECUTOR_SERVER
if not _EXECUTOR_SERVER:
# TODO(rakhmerov): It should be configurable.
_EXECUTOR_SERVER = ExecutorServer(
def_executor.DefaultExecutor(get_engine_client())
)
return _EXECUTOR_SERVER
def get_executor_client():
global _EXECUTOR_CLIENT
if not _ENGINE_CLIENT:
_EXECUTOR_CLIENT = ExecutorClient(get_transport())
return _EXECUTOR_CLIENT
# TODO(rakhmerov): Take care of request context
class EngineServer(object):
"""RPC Engine server."""
def __init__(self, engine):
self._engine = engine
def start_workflow(self, rpc_ctx, workbook_name, workflow_name, input,
params):
"""Receives calls over RPC to start workflows on engine.
:param rpc_ctx: RPC request context.
:return: Workflow execution.
"""
LOG.info(
"Received RPC request 'start_workflow'[rpc_ctx=%s,"
" workbook_name=%s, workflow_name=%s, input=%s, params=%s]"
% (rpc_ctx, workbook_name, workflow_name, input, params)
)
return self._engine.start_workflow(
workbook_name,
workflow_name,
input,
**params
)
def on_task_result(self, rpc_ctx, task_id, result_data, result_error):
"""Receives calls over RPC to communicate task result to engine.
:param rpc_ctx: RPC request context.
:return: Task.
"""
task_result = wf_base.TaskResult(result_data, result_error)
LOG.info(
"Received RPC request 'on_task_result'[rpc_ctx=%s,"
" task_id=%s, task_result=%s]" % (rpc_ctx, task_id, task_result)
)
return self._engine.on_task_result(task_id, task_result)
def stop_workflow(self, rpc_ctx, execution_id):
"""Receives calls over RPC to stop workflows on engine.
:param rpc_ctx: Request context.
:return: Workflow execution.
"""
LOG.info(
"Received RPC request 'stop_workflow'[rpc_ctx=%s,"
" execution_id=%s]" % (rpc_ctx, execution_id)
)
return self._engine.stop_workflow(execution_id)
def resume_workflow(self, rpc_ctx, execution_id):
"""Receives calls over RPC to resume workflows on engine.
:param rpc_ctx: RPC request context.
:return: Workflow execution.
"""
LOG.info(
"Received RPC request 'resume_workflow'[rpc_ctx=%s,"
" execution_id=%s]" % (rpc_ctx, execution_id)
)
return self._engine.resume_workflow(execution_id)
def rollback_workflow(self, rpc_ctx, execution_id):
"""Receives calls over RPC to rollback workflows on engine.
:param rpc_ctx: RPC request context.
:return: Workflow execution.
"""
LOG.info(
"Received RPC request 'rollback_workflow'[rpc_ctx=%s,"
" execution_id=%s]" % (rpc_ctx, execution_id)
)
return self._engine.resume_workflow(execution_id)
class EngineClient(base.Engine):
"""RPC Engine client."""
def __init__(self, transport):
"""Construct an RPC client for the Engine.
"""Constructs an RPC client for engine.
:param transport: Messaging transport.
:type transport: Transport.
"""
serializer = auth_ctx.RpcContextSerializer(
auth_ctx.JsonPayloadSerializer())
# TODO(rakhmerov): Clarify topic.
target = messaging.Target(
topic='mistral.engine1.default_engine:DefaultEngine'
)
self._client = messaging.RPCClient(
transport,
target,
messaging.Target(topic=cfg.CONF.engine.topic),
serializer=serializer
)
def start_workflow(self, workbook_name, workflow_name, task_name, input):
"""Starts a workflow execution based on the specified workbook name
and target task.
def start_workflow(self, workbook_name, workflow_name, input, **params):
"""Starts workflow sending a request to engine over RPC.
:param workbook_name: Workbook name.
:param task_name: Target task name.
:param input: Workflow input data.
:return: Workflow execution.
"""
kwargs = {
'workbook_name': workbook_name,
'workflow_name': workflow_name,
'task_name': task_name,
'input': input
}
return self._client.call(auth_ctx.ctx(), 'start_workflow', **kwargs)
return self._client.call(
auth_ctx.ctx(),
'start_workflow',
workbook_name=workbook_name,
workflow_name=workflow_name,
input=input,
params=params
)
def on_task_result(self, task_id, task_result):
"""Conveys task result to Mistral Engine.
@ -79,78 +214,101 @@ class EngineClient(object):
it possibly needs to move the workflow on, i.e. run other workflow
tasks for which all dependencies are satisfied.
:param task_id: Task id.
:param task_result: Task result data.
:return: Task.
"""
kwargs = {
'task_id': task_id,
'task_result': task_result
}
return self._client.call(auth_ctx.ctx(), 'on_task_result', **kwargs)
return self._client.call(
auth_ctx.ctx(),
'on_task_result',
task_id=task_id,
result_data=task_result.data,
result_error=task_result.error
)
def stop_workflow(self, execution_id):
"""Stops the workflow with the given execution id.
:param execution_id: Workflow execution id.
:return: Workflow execution.
"""
kwargs = {'execution_id': execution_id}
return self._client.call(auth_ctx.ctx(), 'stop_workflow', **kwargs)
return self._client.call(
auth_ctx.ctx(),
'stop_workflow',
execution_id=execution_id
)
def resume_workflow(self, execution_id):
"""Resumes the workflow with the given execution id.
:param execution_id: Workflow execution id.
:return: Workflow execution.
"""
kwargs = {'execution_id': execution_id}
return self._client.call(auth_ctx.ctx(), 'resume_workflow', **kwargs)
return self._client.call(
auth_ctx.ctx(),
'resume_workflow',
execution_id=execution_id
)
def rollback_workflow(self, execution_id):
"""Rolls back the workflow with the given execution id.
:param execution_id: Workflow execution id.
:return: Workflow execution.
"""
kwargs = {'execution_id': execution_id}
return self._client.call(auth_ctx.ctx(), 'rollback_workflow', **kwargs)
return self._client.call(
auth_ctx.ctx(),
'rollback_workflow',
execution_id=execution_id
)
class ExecutorClient(object):
"""RPC client for Executor."""
class ExecutorServer(object):
"""RPC Executor server."""
def __init__(self, executor):
self._executor = executor
def run_action(self, rpc_ctx, task_id, action_name, params):
"""Receives calls over RPC to run task on engine.
:param rpc_ctx: RPC request context dictionary.
"""
LOG.info(
"Received RPC request 'run_action'[rpc_ctx=%s,"
" task_id=%s, action_name=%s, params=%s]"
% (rpc_ctx, task_id, action_name, params)
)
self._executor.run_action(task_id, action_name, params)
class ExecutorClient(base.Executor):
"""RPC Executor client."""
def __init__(self, transport):
"""Construct an RPC client for the Executor.
"""Constructs an RPC client for the Executor.
:param transport: Messaging transport.
:type transport: Transport.
"""
serializer = auth_ctx.RpcContextSerializer(
auth_ctx.JsonPayloadSerializer())
# TODO(rakhmerov): Clarify topic.
target = messaging.Target(
topic='mistral.engine1.default_engine:DefaultExecutor'
auth_ctx.JsonPayloadSerializer()
)
self._client = messaging.RPCClient(
transport,
target,
messaging.Target(topic=cfg.CONF.executor.topic),
serializer=serializer
)
# TODO(rakhmerov): Most likely it will be a different method.
def handle_task(self, cntx, **kwargs):
"""Send the task request to Executor for execution.
def run_action(self, task_id, action_name, action_params):
"""Sends a request to run action to executor."""
:param cntx: a request context dict
:type cntx: MistralContext
:param kwargs: a dict of method arguments
:type kwargs: dict
"""
return self._client.cast(cntx, 'handle_task', **kwargs)
kwargs = {
'task_id': task_id,
'action_name': action_name,
'params': action_params
}
return self._client.cast(auth_ctx.ctx(), 'run_action', **kwargs)

View File

@ -24,6 +24,8 @@ from oslotest import base
from stevedore import driver
import testtools.matchers as ttm
import time
from mistral import context as auth_context
from mistral.db.sqlalchemy import base as db_sa_base
from mistral.db.v1 import api as db_api_v1
@ -32,7 +34,6 @@ from mistral.engine import executor
from mistral.openstack.common import log as logging
from mistral import version
RESOURCES_PATH = 'tests/resources/'
LOG = logging.getLogger(__name__)
@ -43,20 +44,30 @@ def get_resource(resource_name):
RESOURCES_PATH + resource_name)).read()
# TODO(rakhmerov): Remove together with the current engine implementation.
def get_fake_transport():
# Get transport here to let oslo.messaging setup default config
# before changing the rpc_backend to the fake driver; otherwise,
# oslo.messaging will throw exception.
messaging.get_transport(cfg.CONF)
cfg.CONF.set_default('rpc_backend', 'fake')
url = transport.TransportURL.parse(cfg.CONF, None, None)
kwargs = dict(default_exchange=cfg.CONF.control_exchange,
allowed_remote_exmods=[])
mgr = driver.DriverManager('oslo.messaging.drivers',
kwargs = dict(
default_exchange=cfg.CONF.control_exchange,
allowed_remote_exmods=[]
)
mgr = driver.DriverManager(
'oslo.messaging.drivers',
url.transport,
invoke_on_load=True,
invoke_args=[cfg.CONF, url],
invoke_kwds=kwargs)
invoke_kwds=kwargs
)
return transport.Transport(mgr.driver)
@ -133,6 +144,28 @@ class BaseTest(base.BaseTestCase):
self.fail(self._formatMessage(msg, standardMsg))
def _await(self, predicate, delay=1, timeout=30):
"""Awaits for predicate function to evaluate to True.
If within a configured timeout predicate function hasn't evaluated
to True then an exception is raised.
:param predicate: Predication function.
:param delay: Delay in seconds between predicate function calls.
:param timeout: Maximum amount of time to wait for predication
function to evaluate to True.
:return:
"""
end_time = time.time() + timeout
while True:
if predicate():
break
if time.time() + delay > end_time:
raise AssertionError("Failed to wait for expected result.")
time.sleep(delay)
class DbTestCase(BaseTest):
def setUp(self):
@ -155,12 +188,14 @@ class DbTestCase(BaseTest):
return db_sa_base._get_thread_local_session() is not None
# TODO(rakhmerov): Remove together with the current engine implementation.
class EngineTestCase(DbTestCase):
transport = get_fake_transport()
backend = engine.get_engine(cfg.CONF.engine.engine, transport)
def __init__(self, *args, **kwargs):
super(EngineTestCase, self).__init__(*args, **kwargs)
self.engine = engine.EngineClient(self.transport)
@classmethod
@ -168,22 +203,26 @@ class EngineTestCase(DbTestCase):
"""Mock the engine convey_task_results to send request directly
to the engine instead of going through the oslo.messaging transport.
"""
cntx = {}
kwargs = {'task_id': task_id,
kwargs = {
'task_id': task_id,
'state': state,
'result': result}
return cls.backend.convey_task_result(cntx, **kwargs)
'result': result
}
return cls.backend.convey_task_result({}, **kwargs)
@classmethod
def mock_start_workflow(cls, workbook_name, task_name, context=None):
"""Mock the engine start_workflow_execution to send request directly
to the engine instead of going through the oslo.messaging transport.
"""
cntx = {}
kwargs = {'workbook_name': workbook_name,
kwargs = {
'workbook_name': workbook_name,
'task_name': task_name,
'context': context}
return cls.backend.start_workflow_execution(cntx, **kwargs)
'context': context
}
return cls.backend.start_workflow_execution({}, **kwargs)
@classmethod
def mock_get_workflow_state(cls, workbook_name, execution_id):
@ -191,10 +230,12 @@ class EngineTestCase(DbTestCase):
directly to the engine instead of going through the oslo.messaging
transport.
"""
cntx = {}
kwargs = {'workbook_name': workbook_name,
'execution_id': execution_id}
return cls.backend.get_workflow_execution_state(cntx, **kwargs)
kwargs = {
'workbook_name': workbook_name,
'execution_id': execution_id
}
return cls.backend.get_workflow_execution_state({}, **kwargs)
@classmethod
def mock_run_task(cls, task_id, action_name, params):
@ -202,10 +243,13 @@ class EngineTestCase(DbTestCase):
executor instead of going through the oslo.messaging transport.
"""
exctr = executor.get_executor(cfg.CONF.engine.engine, cls.transport)
exctr.handle_task(auth_context.ctx(),
exctr.handle_task(
auth_context.ctx(),
task_id=task_id,
action_name=action_name,
params=params)
params=params
)
@classmethod
def mock_handle_task(cls, cntx, **kwargs):
@ -213,4 +257,5 @@ class EngineTestCase(DbTestCase):
executor instead of going through the oslo.messaging transport.
"""
exctr = executor.get_executor(cfg.CONF.engine.engine, cls.transport)
return exctr.handle_task(cntx, **kwargs)

View File

@ -0,0 +1,129 @@
# Copyright 2014 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import eventlet
from oslo.config import cfg
from oslo import messaging
from oslo.messaging import transport
from mistral import context as ctx
from mistral.engine1 import default_engine as def_eng
from mistral.engine1 import default_executor as def_exec
from mistral.engine1 import rpc
from mistral.openstack.common import log as logging
from mistral.tests import base
from stevedore import driver
eventlet.monkey_patch(
os=True,
select=True,
socket=True,
thread=False if '--use-debugger' in sys.argv else True,
time=True
)
LOG = logging.getLogger(__name__)
def get_fake_transport():
# Get transport here to let oslo.messaging setup default config
# before changing the rpc_backend to the fake driver; otherwise,
# oslo.messaging will throw exception.
messaging.get_transport(cfg.CONF)
cfg.CONF.set_default('rpc_backend', 'fake')
url = transport.TransportURL.parse(cfg.CONF, None)
kwargs = dict(
default_exchange=cfg.CONF.control_exchange,
allowed_remote_exmods=[]
)
mgr = driver.DriverManager(
'oslo.messaging.drivers',
url.transport,
invoke_on_load=True,
invoke_args=(cfg.CONF, url),
invoke_kwds=kwargs
)
return transport.Transport(mgr.driver)
def launch_engine_server(transport, engine):
target = messaging.Target(
topic=cfg.CONF.engine.topic,
server=cfg.CONF.engine.host
)
server = messaging.get_rpc_server(
transport,
target,
[rpc.EngineServer(engine)],
executor='eventlet',
serializer=ctx.RpcContextSerializer(ctx.JsonPayloadSerializer())
)
server.start()
server.wait()
def launch_executor_server(transport, executor):
target = messaging.Target(
topic=cfg.CONF.executor.topic,
server=cfg.CONF.executor.host
)
server = messaging.get_rpc_server(
transport,
target,
[rpc.ExecutorServer(executor)],
executor='eventlet',
serializer=ctx.RpcContextSerializer(ctx.JsonPayloadSerializer())
)
server.start()
server.wait()
class EngineTestCase(base.DbTestCase):
def setUp(self):
super(EngineTestCase, self).setUp()
transport = base.get_fake_transport()
engine_client = rpc.EngineClient(transport)
executor_client = rpc.ExecutorClient(transport)
self.engine = def_eng.DefaultEngine(engine_client, executor_client)
self.executor = def_exec.DefaultExecutor(engine_client)
LOG.info("Starting engine and executor threads...")
self.threads = [
eventlet.spawn(launch_engine_server, transport, self.engine),
eventlet.spawn(launch_executor_server, transport, self.executor),
]
def tearDown(self):
super(EngineTestCase, self).tearDown()
LOG.info("Finishing engine and executor threads...")
[thread.kill() for thread in self.threads]

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
# Copyright 2014 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,6 +13,7 @@
# limitations under the License.
import copy
import mock
from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models
@ -62,7 +61,9 @@ class DefaultEngineTest(base.DbTestCase):
'tags': ['test']
})
self.engine = d_eng.DefaultEngine()
# Note: For purposes of this test we can easily use
# simple magic mocks for engine and executor clients
self.engine = d_eng.DefaultEngine(mock.MagicMock(), mock.MagicMock())
def test_start_workflow(self):
wf_input = {
@ -104,7 +105,11 @@ class DefaultEngineTest(base.DbTestCase):
# Start workflow.
exec_db = self.engine.start_workflow(
'my_wb', 'wf1', wf_input, task_name='task2')
'my_wb',
'wf1',
wf_input,
task_name='task2'
)
self.assertIsNotNone(exec_db)
self.assertEqual(states.RUNNING, exec_db.state)
@ -172,7 +177,7 @@ class DefaultEngineTest(base.DbTestCase):
self.assertEqual('task2', task2_db.name)
self.assertEqual(states.SUCCESS, task2_db.state)
in_context = copy.copy(wf_input)
in_context = copy.deepcopy(wf_input)
in_context.update(task1_db.output)
self._assert_dict_contains_subset(in_context, task2_db.in_context)

View File

@ -0,0 +1,141 @@
# Copyright 2014 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.db.v2 import api as db_api
from mistral.openstack.common import log as logging
from mistral.tests.unit.engine1 import base
from mistral.workbook import parser as spec_parser
from mistral.workflow import states
LOG = logging.getLogger(__name__)
WORKBOOK = """
---
Version: '2.0'
Workflows:
wf1:
type: reverse
parameters:
- param1
- param2
output:
final_result: $.final_result
tasks:
task1:
action: std.echo output="{$.param1}"
publish:
result1: $
task2:
action: std.echo output="'{$.param1} & {$.param2}'"
publish:
final_result: $
requires: [task1]
wf2:
type: linear
start_task: task1
output:
slogan: $.slogan
tasks:
task1:
workflow: my_wb.wf1 param1='Bonnie' param2='Clyde'
workflow_parameters:
task_name: task2
publish:
slogan: "{$.final_result} is a cool movie!"
"""
class SubworkflowsTest(base.EngineTestCase):
def setUp(self):
super(SubworkflowsTest, self).setUp()
wb_spec = spec_parser.get_workbook_spec_from_yaml(WORKBOOK)
db_api.create_workbook({
'name': 'my_wb',
'description': 'Simple workbook for testing engine.',
'definition': WORKBOOK,
'spec': wb_spec.to_dict(),
'tags': ['test']
})
def test_subworkflow(self):
exec1_db = self.engine.start_workflow('my_wb', 'wf2', None)
# Execution 1.
self.assertIsNotNone(exec1_db)
self.assertDictEqual({}, exec1_db.input)
self.assertDictEqual({}, exec1_db.start_params)
db_execs = db_api.get_executions()
self.assertEqual(2, len(db_execs))
# Execution 2.
exec2_db = db_execs[0] if db_execs[0].id != exec1_db.id \
else db_execs[1]
self.assertIsNotNone(exec2_db.parent_task_id)
self.assertDictEqual(
{
'task_name': 'task2',
'parent_task_id': exec2_db.parent_task_id
},
exec2_db.start_params
)
self.assertDictEqual(
{
'param1': 'Bonnie',
'param2': 'Clyde'
},
exec2_db.input
)
# Wait till workflow 'wf1' is completed.
self._await(
lambda: db_api.get_execution(exec2_db.id).state == states.SUCCESS,
)
exec2_db = db_api.get_execution(exec2_db.id)
self.assertDictEqual(
{
'final_result': "'Bonnie & Clyde'"
},
exec2_db.output
)
# Wait till workflow 'wf2' is completed.
self._await(
lambda: db_api.get_execution(exec1_db.id).state == states.SUCCESS,
)
exec1_db = db_api.get_execution(exec1_db.id)
self.assertEqual(
"'Bonnie & Clyde' is a cool movie!",
exec1_db.context['slogan']
)
self.assertDictEqual(
{
'slogan': "'Bonnie & Clyde' is a cool movie!"
},
exec1_db.output
)

View File

@ -52,7 +52,7 @@ Workflows:
tasks:
task3:
workflow: wf1 name="John Doe" age=32
workflow: wf1 name="John Doe" age=32 param1="Bonnie" param2="Clyde"
"""
# TODO(rakhmerov): Add more tests when v2 spec is complete.
@ -163,7 +163,12 @@ class DSLv2ModelTest(base.BaseTest):
self.assertEqual('wf1', task3_spec.get_short_workflow_name())
self.assertIsNone(task3_spec.get_workflow_namespace())
self.assertEqual(
{'name': 'John Doe', 'age': '32'},
{
'name': 'John Doe',
'age': '32',
'param1': 'Bonnie',
'param2': 'Clyde'
},
task3_spec.get_parameters()
)

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
# Copyright 2014 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -17,6 +17,7 @@
import logging
import threading
from eventlet import corolocal
# Thread local storage.
@ -102,6 +103,13 @@ def merge_dicts(left, right):
:param left: Left dictionary.
:param right: Right dictionary.
"""
if left is None:
return right
if right is None:
return left
for k, v in right.iteritems():
if k not in left:
left[k] = v

View File

@ -23,7 +23,7 @@ from mistral.workbook.v2 import retry
# TODO(rakhmerov): In progress.
CMD_PTRN = re.compile("^[\w\.]+[^=\s\"]*")
PARAMS_PTRN = re.compile("([\w]+)=(\".*\"|\'.*'|[\d\.]*)")
PARAMS_PTRN = re.compile("([\w]+)=(\"[^=]*\"|\'[^=]*'|[\d\.]*)")
class TaskSpec(base.BaseSpec):
@ -35,6 +35,7 @@ class TaskSpec(base.BaseSpec):
"name": {"type": "string"},
"action": {"type": ["string", "null"]},
"workflow": {"type": ["string", "null"]},
"workflow_parameters": {"type": ["object", "null"]},
"parameters": {"type": ["object", "null"]},
"publish": {"type": ["object", "null"]},
"retry": {"type": ["object", "null"]},
@ -53,8 +54,9 @@ class TaskSpec(base.BaseSpec):
super(TaskSpec, self).__init__(data)
self._name = data['name']
self._action = data.get('action', None)
self._workflow = data.get('workflow', None)
self._action = data.get('action')
self._workflow = data.get('workflow')
self._workflow_parameters = data.get('workflow_parameters')
self._parameters = data.get('parameters', {})
self._publish = data.get('publish', {})
self._retry = self._spec_property('retry', retry.RetrySpec)
@ -92,7 +94,11 @@ class TaskSpec(base.BaseSpec):
params = {}
for k, v in re.findall(PARAMS_PTRN, cmd_str):
params[k] = v.replace('"', '').replace("'", '')
# Remove embracing quotes.
if v[0] == '"' or v[0] == "'":
v = v[1:-1]
params[k] = v
return cmd, params
@ -127,6 +133,9 @@ class TaskSpec(base.BaseSpec):
def get_short_workflow_name(self):
return self._workflow.split('.')[-1] if self._workflow else None
def get_workflow_parameters(self):
return self._workflow_parameters
def get_parameters(self):
return self._parameters

View File

@ -27,7 +27,7 @@ class WorkflowSpec(base.BaseSpec):
"type": {"enum": ["reverse", "linear"]},
"start_task": {"type": "string"},
"parameters": {"type": ["array", "null"]},
"output": {"type": ["array", "null"]},
"output": {"type": ["string", "object", "array", "null"]},
"tasks": {"type": "object"},
},
"required": ["Version", "name", "type", "tasks"],
@ -41,6 +41,8 @@ class WorkflowSpec(base.BaseSpec):
self._name = data['name']
self._type = data['type']
self._parameters = data.get('parameters')
self._output = data.get('output')
self._start_task_name = data.get('start_task')
self._tasks = self._spec_property('tasks', tasks.TaskSpecList)
@ -61,6 +63,12 @@ class WorkflowSpec(base.BaseSpec):
def get_type(self):
return self._type
def get_parameters(self):
return self._parameters
def get_output(self):
return self._output
def get_start_task_name(self):
return self._start_task_name

View File

@ -15,6 +15,7 @@
import abc
from mistral import exceptions as exc
from mistral.openstack.common import log as logging
from mistral import utils
from mistral.workbook import parser as spec_parser
from mistral.workflow import data_flow
from mistral.workflow import states
@ -51,7 +52,6 @@ class WorkflowHandler(object):
"""
raise NotImplementedError
@abc.abstractmethod
def on_task_result(self, task_db, raw_result):
"""Handles event of arriving a task result.
@ -70,8 +70,45 @@ class WorkflowHandler(object):
task_db.output =\
data_flow.evaluate_task_output(task_spec, raw_result)
if task_db.state == states.ERROR:
# TODO(rakhmerov): Temporary hack, need to use policies.
self._set_execution_state(states.ERROR)
return []
task_specs = self._find_next_tasks(task_db)
if len(task_specs) == 0:
# If there are no running tasks at this point we can conclude that
# the workflow has finished.
if not self._find_running_tasks():
self._set_execution_state(states.SUCCESS)
task_out_ctx = data_flow.evaluate_outbound_context(task_db)
self.exec_db.context = utils.merge_dicts(
self.exec_db.context,
task_out_ctx
)
self.exec_db.output = data_flow.evaluate_workflow_output(
self.wf_spec,
task_out_ctx
)
return task_specs
@abc.abstractmethod
def _find_next_tasks(self, task_db):
"""Finds tasks that should run next.
A concrete algorithm of finding such tasks depends on a concrete
workflow handler.
:param task_db: Task DB model causing the operation (completed).
:return: List of task specifications.
"""
raise NotImplementedError
def is_stopped_or_finished(self):
return states.is_stopped_or_finished(self.exec_db.state)
@ -114,6 +151,10 @@ class WorkflowHandler(object):
" state=%s -> %s]" % (self.exec_db, cur_state, state)
raise exc.WorkflowException(msg)
def _find_running_tasks(self):
return [t_db for t_db in self.exec_db.tasks
if states.RUNNING == t_db.state]
class TaskResult(object):
"""Explicit data structure containing a result of task execution."""
@ -122,6 +163,9 @@ class TaskResult(object):
self.data = data
self.error = error
def __repr__(self):
return 'TaskResult [data=%s, error=%s]' % (self.data, self.error)
def is_error(self):
return self.error is not None

View File

@ -69,7 +69,7 @@ def _evaluate_upstream_context(upstream_db_tasks):
def evaluate_task_output(task_spec, raw_result):
"""Evaluates task output give a raw task result from action/workflow.
"""Evaluates task output given a raw task result from action/workflow.
:param task_spec: Task specification
:param raw_result: Raw task result that comes from action/workflow
@ -87,6 +87,20 @@ def evaluate_task_output(task_spec, raw_result):
return output
def evaluate_workflow_output(wf_spec, context):
"""Evaluates workflow output.
:param wf_spec: Workflow specification.
:param context: Final Data Flow context (cause task's outbound context).
"""
output_dict = wf_spec.get_output()
# Evaluate 'publish' clause using raw result as a context.
output = expr.evaluate_recursively(output_dict, context)
return output or context
def evaluate_outbound_context(task_db):
"""Evaluates task outbound Data Flow context.
@ -95,8 +109,12 @@ def evaluate_outbound_context(task_db):
:param task_db: DB task.
:return: Outbound task Data Flow context.
"""
in_context = copy.deepcopy(dict(task_db.in_context)) \
if task_db.in_context is not None else {}
return utils.merge_dicts(
copy.copy(task_db.in_context) or {},
in_context,
task_db.output
)

View File

@ -38,25 +38,6 @@ class LinearWorkflowHandler(base.WorkflowHandler):
return [self._find_start_task()]
def on_task_result(self, task_db, raw_result):
super(LinearWorkflowHandler, self).on_task_result(task_db, raw_result)
if task_db.state == states.ERROR:
# TODO(rakhmerov): Temporary hack, need to use policies.
self._set_execution_state(states.ERROR)
return []
task_specs = self._find_next_tasks(task_db)
if len(task_specs) == 0:
# If there are no running tasks at this point we can conclude that
# the workflow has finished.
if not self._find_running_tasks():
self._set_execution_state(states.SUCCESS)
return task_specs
def get_upstream_tasks(self, task_spec):
# TODO(rakhmerov): For linear workflow it's pretty hard to do
# so we may need to get rid of it at all.
@ -112,7 +93,3 @@ class LinearWorkflowHandler(base.WorkflowHandler):
task_specs.append(self.wf_spec.get_tasks()[t_name])
return task_specs
def _find_running_tasks(self):
return [t_db for t_db in self.exec_db.tasks
if states.RUNNING == t_db.state]

View File

@ -52,29 +52,35 @@ class ReverseWorkflowHandler(base.WorkflowHandler):
return task_specs
def on_task_result(self, task_db, raw_result):
super(ReverseWorkflowHandler, self).on_task_result(
task_db,
raw_result
)
if task_db.state == states.ERROR:
# TODO(rakhmerov): Temporary hack, need to use policies.
self._set_execution_state(states.ERROR)
return []
task_specs = self._find_resolved_tasks()
if len(task_specs) == 0:
self._set_execution_state(states.SUCCESS)
return task_specs
def get_upstream_tasks(self, task_spec):
return [self.wf_spec.get_tasks()[t_name]
for t_name in task_spec.get_requires() or []]
def _find_next_tasks(self, task_db):
"""Finds all tasks with resolved dependencies.
:param task_db: Task DB model causing the operation
(not used in this handler).
:return: Tasks with resolved dependencies.
"""
# We need to analyse the graph and see which tasks are ready to start.
resolved_task_specs = []
success_task_names = set()
for t in self.exec_db.tasks:
if t.state == states.SUCCESS:
success_task_names.add(t.name)
for t_spec in self.wf_spec.get_tasks():
if not (set(t_spec.get_requires()) - success_task_names):
t_db = self._find_db_task(t_spec.get_name())
if not t_db or t_db.state == states.IDLE:
resolved_task_specs.append(t_spec)
return resolved_task_specs
def _find_tasks_without_dependencies(self, task_spec):
"""Given a target task name finds tasks with no dependencies.
@ -123,29 +129,6 @@ class ReverseWorkflowHandler(base.WorkflowHandler):
return dep_t_specs
def _find_resolved_tasks(self):
"""Finds all tasks with resolved dependencies.
:return: Tasks with resolved dependencies.
"""
# We need to analyse the graph and see which tasks are ready to start.
resolved_task_specs = []
success_task_names = set()
for t in self.exec_db.tasks:
if t.state == states.SUCCESS:
success_task_names.add(t.name)
for t_spec in self.wf_spec.get_tasks():
if not (set(t_spec.get_requires()) - success_task_names):
task_db = self._find_db_task(t_spec.get_name())
if not task_db or task_db.state == states.IDLE:
resolved_task_specs.append(t_spec)
return resolved_task_specs
def _find_db_task(self, name):
db_tasks = filter(lambda t: t.name == name, self.exec_db.tasks)

View File

@ -41,7 +41,7 @@ def _select_workflow_handler(wf_spec):
if wf_type == 'reverse':
return reverse_workflow.ReverseWorkflowHandler
if wf_type == 'direct':
if wf_type == 'linear':
return linear_workflow.LinearWorkflowHandler
return None