Merge "Add "published_global" field to the task execution REST resource"

This commit is contained in:
Zuul 2019-09-05 18:31:26 +00:00 committed by Gerrit Code Review
commit 51293d9b7e
5 changed files with 183 additions and 27 deletions

View File

@ -51,26 +51,14 @@ STATE_TYPES = wtypes.Enum(
)
def _load_deferred_fields(ex, fields):
if not ex:
return ex
# We need to refer lazy-loaded fields explicitly in
# order to make sure that they are correctly loaded.
for f in fields:
hasattr(ex, f)
return ex
def _get_workflow_execution_resource_with_output(wf_ex):
_load_deferred_fields(wf_ex, ['params', 'input', 'output'])
rest_utils.load_deferred_fields(wf_ex, ['params', 'input', 'output'])
return resources.Execution.from_db_model(wf_ex)
def _get_workflow_execution_resource(wf_ex):
_load_deferred_fields(wf_ex, ['params', 'input'])
rest_utils.load_deferred_fields(wf_ex, ['params', 'input'])
return resources.Execution.from_db_model(wf_ex)
@ -84,9 +72,35 @@ def _get_workflow_execution(id, must_exist=True):
else:
wf_ex = db_api.load_workflow_execution(id)
return _load_deferred_fields(wf_ex, ['params', 'input', 'output'])
return rest_utils.load_deferred_fields(wf_ex, ['params', 'input',
'output', 'context'])
# Use retries to prevent possible failures.
@rest_utils.rest_retry_on_db_error
def _get_task_executions(wf_id):
with db_api.transaction():
task_execs = db_api.get_task_executions(workflow_execution_id=wf_id)
for task_ex in task_execs:
rest_utils.load_deferred_fields(task_ex, ['spec'])
return task_execs
def _get_published_global_from_tasks(task_execs, wf_ex):
wf_published_global_vars = {}
for task_ex in task_execs:
published_global_vars = task.get_published_global(task_ex, wf_ex)
if published_global_vars:
merge_dicts(wf_published_global_vars, published_global_vars)
return wf_published_global_vars
def _execution_with_publish_global(wf_ex, wf_published_global_vars):
wf_execution = resources.Execution.from_db_model(wf_ex)
if wf_published_global_vars:
wf_execution.published_global = wf_published_global_vars
return wf_execution
# TODO(rakhmerov): Make sure to make all needed renaming on public API.
@ -107,7 +121,12 @@ class ExecutionsController(rest.RestController):
wf_ex = _get_workflow_execution(id)
return resources.Execution.from_db_model(wf_ex)
task_execs = _get_task_executions(wf_ex.id)
wf_published_global_vars = _get_published_global_from_tasks(task_execs,
wf_ex)
return _execution_with_publish_global(wf_ex, wf_published_global_vars)
@rest_utils.wrap_wsme_controller_exception
@wsme_pecan.wsexpose(

View File

@ -314,6 +314,8 @@ class Execution(resource.Resource):
project_id = wsme.wsattr(wtypes.text, readonly=True)
published_global = types.jsontype
@classmethod
def sample(cls):
return cls(
@ -326,6 +328,7 @@ class Execution(resource.Resource):
state='SUCCESS',
input={},
output={},
published_global={'key': 'value'},
params={
'env': {'k1': 'abc', 'k2': 123},
'notify': [
@ -399,6 +402,7 @@ class Task(resource.Resource):
result = wtypes.text
published = types.jsontype
published_global = types.jsontype
processed = bool
created_at = wtypes.text
@ -432,6 +436,7 @@ class Task(resource.Resource):
},
result='task result',
published={'key': 'value'},
published_global={'key': 'value'},
processed=True,
created_at='1970-01-01T00:00:00.000000',
updated_at='1970-01-01T00:00:00.000000',

View File

@ -27,14 +27,17 @@ from mistral.api.controllers.v2 import types
from mistral import context
from mistral.db.v2 import api as db_api
from mistral import exceptions as exc
from mistral import expressions as expr
from mistral.lang import parser as spec_parser
from mistral.rpc import clients as rpc
from mistral.utils import filter_utils
from mistral.utils import rest_utils
from mistral.workflow import data_flow
from mistral.workflow.data_flow import ContextView
from mistral.workflow.data_flow import get_current_task_dict
from mistral.workflow.data_flow import get_workflow_environment_dict
from mistral.workflow import states
LOG = logging.getLogger(__name__)
STATE_TYPES = wtypes.Enum(
@ -55,6 +58,49 @@ def _get_task_resource_with_result(task_ex):
return task
# Use retries to prevent possible failures.
@rest_utils.rest_retry_on_db_error
def _get_task_execution(id):
with db_api.transaction():
task_ex = db_api.get_task_execution(id)
rest_utils.load_deferred_fields(task_ex, ['workflow_execution'])
rest_utils.load_deferred_fields(task_ex.workflow_execution,
['context', 'input',
'params', 'root_execution'])
return _get_task_resource_with_result(task_ex), task_ex
def get_published_global(task_ex, wf_ex=None):
if task_ex.state not in [states.SUCCESS, states.ERROR]:
return
if wf_ex is None:
wf_ex = task_ex.workflow_execution
expr_ctx = ContextView(
get_current_task_dict(task_ex),
task_ex.in_context,
get_workflow_environment_dict(wf_ex),
wf_ex.context,
wf_ex.input
)
task_spec = spec_parser.get_task_spec(task_ex.spec)
publish_spec = task_spec.get_publish(task_ex.state)
if not publish_spec:
return
global_vars = publish_spec.get_global()
return expr.evaluate_recursively(global_vars, expr_ctx)
def _task_with_published_global(task, task_ex):
published_global_vars = get_published_global(task_ex)
if published_global_vars:
task.published_global = published_global_vars
return task
class TaskExecutionsController(rest.RestController):
@rest_utils.wrap_wsme_controller_exception
@wsme_pecan.wsexpose(resources.Executions, types.uuid, types.uuid, int,
@ -137,15 +183,6 @@ class TaskExecutionsController(rest.RestController):
)
# Use retries to prevent possible failures.
@rest_utils.rest_retry_on_db_error
def _get_task_execution(id):
with db_api.transaction():
task_ex = db_api.get_task_execution(id)
return _get_task_resource_with_result(task_ex)
class TasksController(rest.RestController):
action_executions = action_execution.TasksActionExecutionController()
workflow_executions = TaskExecutionsController()
@ -160,7 +197,8 @@ class TasksController(rest.RestController):
acl.enforce('tasks:get', context.ctx())
LOG.debug("Fetch task [id=%s]", id)
return _get_task_execution(id)
task, task_ex = _get_task_execution(id)
return _task_with_published_global(task, task_ex)
@rest_utils.wrap_wsme_controller_exception
@wsme_pecan.wsexpose(resources.Tasks, types.uuid, int, types.uniquelist,

View File

@ -0,0 +1,82 @@
# Copyright 2019 - Nokia Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from mistral.services import workflows as wf_service
from mistral.tests.unit.api import base
from mistral.tests.unit.engine import base as engine_base
WF_TEXT = """---
version: '2.0'
wf:
tasks:
task1:
action: std.noop
on-success:
publish:
branch:
my_var: Branch local value
global:
my_var: Global value
next:
- task2
task2:
action: std.noop
publish:
local: <% $.my_var %>
global: <% global(my_var) %>
"""
def _find_task(task_name, tasks):
return next(
(
task
for task in tasks
if task['name'] == task_name
), None
)
class TestGlobalPublish(base.APITest, engine_base.EngineTestCase):
def setUp(self):
super(TestGlobalPublish, self).setUp()
wf_service.create_workflows(WF_TEXT)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_success(wf_ex.id)
self.wf_id = wf_ex.id
def test_global_publish_in_task_exec(self):
resp = self.app.get('/v2/tasks/')
tasks = resp.json['tasks']
task = _find_task('task1', tasks)
self.assertIsNotNone(task, 'task1 not found')
resp = self.app.get('/v2/tasks/%s/' % task['id'])
self.assert_for_published_global(resp)
def test_global_publish_in_wf_exec(self):
resp = self.app.get('/v2/executions/%s/' % self.wf_id)
self.assert_for_published_global(resp)
def assert_for_published_global(self, resp):
self.assertEqual(200, resp.status_int)
self.assertEqual(resp.json['published_global'],
'{"my_var": "Global value"}')

View File

@ -267,3 +267,15 @@ def create_db_retry_object():
def rest_retry_on_db_error(func):
return db_utils.retry_on_db_error(func, create_db_retry_object())
def load_deferred_fields(ex, fields):
if not ex:
return ex
# We need to refer lazy-loaded fields explicitly in
# order to make sure that they are correctly loaded.
for f in fields:
hasattr(ex, f)
return ex