Merge "Avoid storing workflow input in task inbound context"

This commit is contained in:
Jenkins 2016-09-22 04:36:17 +00:00 committed by Gerrit Code Review
commit 0e120ed41f
14 changed files with 346 additions and 80 deletions

View File

@ -208,7 +208,12 @@ class WorkflowExecution(Execution):
output = sa.orm.deferred(sa.Column(st.JsonLongDictType(), nullable=True))
params = sa.Column(st.JsonLongDictType())
# TODO(rakhmerov): We need to get rid of this field at all.
# Initial workflow context containing workflow variables, environment,
# openstack security context etc.
# NOTES:
# * Data stored in this structure should not be copied into inbound
# contexts of tasks. No need to duplicate it.
# * This structure does not contain workflow input.
context = sa.Column(st.JsonLongDictType())

View File

@ -175,8 +175,15 @@ class TaskPolicy(object):
:param task_ex: DB model for task that is about to start.
:param task_spec: Task specification.
"""
# No-op by default.
data_flow.evaluate_object_fields(self, task_ex.in_context)
wf_ex = task_ex.workflow_execution
ctx_view = data_flow.ContextView(
task_ex.in_context,
wf_ex.context,
wf_ex.input
)
data_flow.evaluate_object_fields(self, ctx_view)
self._validate()
@ -186,8 +193,15 @@ class TaskPolicy(object):
:param task_ex: Completed task DB model.
:param task_spec: Completed task specification.
"""
# No-op by default.
data_flow.evaluate_object_fields(self, task_ex.in_context)
wf_ex = task_ex.workflow_execution
ctx_view = data_flow.ContextView(
task_ex.in_context,
wf_ex.context,
wf_ex.input
)
data_flow.evaluate_object_fields(self, ctx_view)
self._validate()

View File

@ -305,9 +305,17 @@ class RetryPolicy(base.TaskPolicy):
context_key
)
wf_ex = task_ex.workflow_execution
ctx_view = data_flow.ContextView(
data_flow.evaluate_task_outbound_context(task_ex),
wf_ex.context,
wf_ex.input
)
continue_on_evaluation = expressions.evaluate(
self._continue_on_clause,
data_flow.evaluate_task_outbound_context(task_ex)
ctx_view
)
task_ex.runtime_context = runtime_context

View File

@ -362,7 +362,16 @@ class RegularTask(Task):
def _get_action_input(self, ctx=None):
ctx = ctx or self.ctx
input_dict = expr.evaluate_recursively(self.task_spec.get_input(), ctx)
ctx_view = data_flow.ContextView(
ctx,
self.wf_ex.context,
self.wf_ex.input
)
input_dict = expr.evaluate_recursively(
self.task_spec.get_input(),
ctx_view
)
return utils.merge_dicts(
input_dict,
@ -478,9 +487,15 @@ class WithItemsTask(RegularTask):
:return: the list of tuples containing indexes
and the corresponding input dict.
"""
ctx_view = data_flow.ContextView(
self.ctx,
self.wf_ex.context,
self.wf_ex.input
)
with_items_inputs = expr.evaluate_recursively(
self.task_spec.get_with_items(),
self.ctx
ctx_view
)
with_items.validate_input(with_items_inputs)

View File

@ -14,7 +14,6 @@
# limitations under the License.
import abc
import copy
from oslo_config import cfg
from oslo_log import log as logging
from osprofiler import profiler
@ -220,7 +219,6 @@ class Workflow(object):
})
self.wf_ex.input = input_dict or {}
self.wf_ex.context = copy.deepcopy(input_dict) or {}
env = _get_environment(params)
@ -309,18 +307,23 @@ class Workflow(object):
wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec)
if wf_ctrl.any_cancels():
self._cancel_workflow(
_build_cancel_info_message(wf_ctrl, self.wf_ex)
)
msg = _build_cancel_info_message(wf_ctrl, self.wf_ex)
self._cancel_workflow(msg)
elif wf_ctrl.all_errors_handled():
self._succeed_workflow(wf_ctrl.evaluate_workflow_final_context())
ctx = wf_ctrl.evaluate_workflow_final_context()
self._succeed_workflow(ctx)
else:
self._fail_workflow(_build_fail_info_message(wf_ctrl, self.wf_ex))
msg = _build_fail_info_message(wf_ctrl, self.wf_ex)
self._fail_workflow(msg)
return 0
def _succeed_workflow(self, final_context, msg=None):
self.wf_ex.output = data_flow.evaluate_workflow_output(
self.wf_ex,
self.wf_spec,
final_context
)

View File

@ -17,6 +17,8 @@ from oslo_config import cfg
from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models
from mistral import exceptions as exc
from mistral import expressions as expr
from mistral.services import workflows as wf_service
from mistral.tests.unit import base as test_base
from mistral.tests.unit.engine import base as engine_test_base
@ -84,12 +86,8 @@ class DataFlowEngineTest(engine_test_base.EngineTestCase):
)
# Make sure that task inbound context doesn't contain workflow
# specification, input and params.
ctx = task1.in_context
self.assertFalse('spec' in ctx['__execution'])
self.assertFalse('input' in ctx['__execution'])
self.assertFalse('params' in ctx['__execution'])
# execution info.
self.assertFalse('__execution' in task1.in_context)
def test_linear_with_branches_dataflow(self):
linear_with_branches_wf = """---
@ -595,3 +593,106 @@ class DataFlowTest(test_base.BaseTest):
[1, 1],
data_flow.get_task_execution_result(task_ex)
)
    def test_context_view(self):
        # Build a view over two dicts. Dictionaries are passed in the
        # order of decreasing priority, so for the colliding key 'k3'
        # the first dictionary's value 'v3' must win.
        ctx = data_flow.ContextView(
            {
                'k1': 'v1',
                'k11': 'v11',
                'k3': 'v3'
            },
            {
                'k2': 'v2',
                'k21': 'v21',
                'k3': 'v32'
            }
        )

        self.assertIsInstance(ctx, dict)

        self.assertEqual(5, len(ctx))

        self.assertIn('k1', ctx)
        self.assertIn('k11', ctx)
        self.assertIn('k3', ctx)
        self.assertIn('k2', ctx)
        self.assertIn('k21', ctx)

        self.assertEqual('v1', ctx['k1'])
        self.assertEqual('v1', ctx.get('k1'))
        self.assertEqual('v11', ctx['k11'])
        self.assertEqual('v11', ctx.get('k11'))
        self.assertEqual('v3', ctx['k3'])
        self.assertEqual('v2', ctx['k2'])
        self.assertEqual('v2', ctx.get('k2'))
        self.assertEqual('v21', ctx['k21'])
        self.assertEqual('v21', ctx.get('k21'))

        self.assertIsNone(ctx.get('Not existing key'))

        # The view is read-only: every mutating operation must raise.
        self.assertRaises(exc.MistralError, ctx.update)
        self.assertRaises(exc.MistralError, ctx.clear)
        self.assertRaises(exc.MistralError, ctx.pop, 'k1')
        self.assertRaises(exc.MistralError, ctx.popitem)
        self.assertRaises(exc.MistralError, ctx.__setitem__, 'k5', 'v5')
        self.assertRaises(exc.MistralError, ctx.__delitem__, 'k2')

        # The view must also work as a context for YAQL evaluation.
        self.assertEqual('v1', expr.evaluate('<% $.k1 %>', ctx))
        self.assertEqual('v2', expr.evaluate('<% $.k2 %>', ctx))
        self.assertEqual('v3', expr.evaluate('<% $.k3 %>', ctx))

        # Now change the order of dictionaries and make sure to have
        # a different value for key 'k3'.
        ctx = data_flow.ContextView(
            {
                'k2': 'v2',
                'k21': 'v21',
                'k3': 'v32'
            },
            {
                'k1': 'v1',
                'k11': 'v11',
                'k3': 'v3'
            }
        )

        self.assertEqual('v32', expr.evaluate('<% $.k3 %>', ctx))
    def test_context_view_eval_root_with_yaql(self):
        # Evaluating the YAQL root expression '$' against the view must
        # return the merged content of all underlying dicts as a dict.
        ctx = data_flow.ContextView(
            {'k1': 'v1'},
            {'k2': 'v2'}
        )

        res = expr.evaluate('<% $ %>', ctx)

        self.assertIsNotNone(res)
        self.assertIsInstance(res, dict)
        self.assertEqual(2, len(res))
    def test_context_view_eval_keys(self):
        # keys() must expose the union of keys from all underlying dicts
        # when evaluated through a YAQL expression.
        ctx = data_flow.ContextView(
            {'k1': 'v1'},
            {'k2': 'v2'}
        )

        res = expr.evaluate('<% $.keys() %>', ctx)

        self.assertIsNotNone(res)
        self.assertIsInstance(res, list)
        self.assertEqual(2, len(res))
        self.assertIn('k1', res)
        self.assertIn('k2', res)
    def test_context_view_eval_values(self):
        # values() must expose the values from all underlying dicts
        # when evaluated through a YAQL expression.
        ctx = data_flow.ContextView(
            {'k1': 'v1'},
            {'k2': 'v2'}
        )

        res = expr.evaluate('<% $.values() %>', ctx)

        self.assertIsNotNone(res)
        self.assertIsInstance(res, list)
        self.assertEqual(2, len(res))
        self.assertIn('v1', res)
        self.assertIn('v2', res)

View File

@ -116,7 +116,6 @@ class DefaultEngineTest(base.DbTestCase):
self.assertIsNotNone(wf_ex)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual('my execution', wf_ex.description)
self._assert_dict_contains_subset(wf_input, wf_ex.context)
self.assertIn('__execution', wf_ex.context)
# Note: We need to reread execution to access related tasks.
@ -133,9 +132,6 @@ class DefaultEngineTest(base.DbTestCase):
self.assertDictEqual({}, task_ex.runtime_context)
# Data Flow properties.
self._assert_dict_contains_subset(wf_input, task_ex.in_context)
self.assertIn('__execution', task_ex.in_context)
action_execs = db_api.get_action_executions(
task_execution_id=task_ex.id
)
@ -159,7 +155,6 @@ class DefaultEngineTest(base.DbTestCase):
self.assertIsNotNone(wf_ex)
self.assertEqual(states.RUNNING, wf_ex.state)
self._assert_dict_contains_subset(wf_input, wf_ex.context)
self.assertIn('__execution', wf_ex.context)
# Note: We need to reread execution to access related tasks.
@ -176,9 +171,6 @@ class DefaultEngineTest(base.DbTestCase):
self.assertDictEqual({}, task_ex.runtime_context)
# Data Flow properties.
self._assert_dict_contains_subset(wf_input, task_ex.in_context)
self.assertIn('__execution', task_ex.in_context)
action_execs = db_api.get_action_executions(
task_execution_id=task_ex.id
)
@ -318,8 +310,7 @@ class DefaultEngineTest(base.DbTestCase):
self.assertEqual(states.RUNNING, task1_ex.state)
self.assertIsNotNone(task1_ex.spec)
self.assertDictEqual({}, task1_ex.runtime_context)
self._assert_dict_contains_subset(wf_input, task1_ex.in_context)
self.assertIn('__execution', task1_ex.in_context)
self.assertNotIn('__execution', task1_ex.in_context)
action_execs = db_api.get_action_executions(
task_execution_id=task1_ex.id
@ -345,8 +336,6 @@ class DefaultEngineTest(base.DbTestCase):
# Data Flow properties.
task1_ex = db_api.get_task_execution(task1_ex.id) # Re-read the state.
self._assert_dict_contains_subset(wf_input, task1_ex.in_context)
self.assertIn('__execution', task1_ex.in_context)
self.assertDictEqual({'var': 'Hey'}, task1_ex.published)
self.assertDictEqual({'output': 'Hey'}, task1_action_ex.input)
self.assertDictEqual({'result': 'Hey'}, task1_action_ex.output)
@ -396,7 +385,6 @@ class DefaultEngineTest(base.DbTestCase):
self.assertEqual(states.SUCCESS, task2_action_ex.state)
# Data Flow properties.
self.assertIn('__execution', task1_ex.in_context)
self.assertDictEqual({'output': 'Hi'}, task2_action_ex.input)
self.assertDictEqual({}, task2_ex.published)
self.assertDictEqual({'output': 'Hi'}, task2_action_ex.input)

View File

@ -17,6 +17,7 @@ from oslo_config import cfg
from mistral.actions import std_actions
from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models
from mistral.engine import policies
from mistral import exceptions as exc
from mistral.services import workbooks as wb_service
@ -364,18 +365,29 @@ class PoliciesTest(base.EngineTestCase):
"delay": {"type": "integer"}
}
}
task_db = type('Task', (object,), {'in_context': {'int_var': 5}})
wf_ex = models.WorkflowExecution(
id='1-2-3-4',
context={},
input={}
)
task_ex = models.TaskExecution(in_context={'int_var': 5})
task_ex.workflow_execution = wf_ex
policy.delay = "<% $.int_var %>"
# Validation is ok.
policy.before_task_start(task_db, None)
policy.before_task_start(task_ex, None)
policy.delay = "some_string"
# Validation is failing now.
exception = self.assertRaises(
exc.InvalidModelException,
policy.before_task_start, task_db, None
policy.before_task_start,
task_ex,
None
)
self.assertIn("Invalid data type in TaskPolicy", str(exception))
@ -494,7 +506,11 @@ class PoliciesTest(base.EngineTestCase):
self.assertEqual(states.RUNNING, task_ex.state)
self.assertDictEqual({}, task_ex.runtime_context)
self.assertEqual(2, task_ex.in_context['wait_after'])
# TODO(rakhmerov): This check doesn't make sense anymore because
# we don't store evaluated value anywhere.
# Need to create a better test.
# self.assertEqual(2, task_ex.in_context['wait_after'])
def test_retry_policy(self):
wb_service.create_workbook_v2(RETRY_WB)
@ -535,8 +551,11 @@ class PoliciesTest(base.EngineTestCase):
self.assertEqual(states.RUNNING, task_ex.state)
self.assertDictEqual({}, task_ex.runtime_context)
self.assertEqual(3, task_ex.in_context["count"])
self.assertEqual(1, task_ex.in_context["delay"])
# TODO(rakhmerov): This check doesn't make sense anymore because
# we don't store evaluated values anywhere.
# Need to create a better test.
# self.assertEqual(3, task_ex.in_context["count"])
# self.assertEqual(1, task_ex.in_context["delay"])
def test_retry_policy_never_happen(self):
retry_wb = """---
@ -908,7 +927,10 @@ class PoliciesTest(base.EngineTestCase):
self.assertEqual(states.RUNNING, task_ex.state)
self.assertEqual(1, task_ex.in_context['timeout'])
# TODO(rakhmerov): This check doesn't make sense anymore because
# we don't store evaluated 'timeout' value anywhere.
# Need to create a better test.
# self.assertEqual(1, task_ex.in_context['timeout'])
def test_pause_before_policy(self):
wb_service.create_workbook_v2(PAUSE_BEFORE_WB)
@ -1012,9 +1034,7 @@ class PoliciesTest(base.EngineTestCase):
self.assertEqual(states.SUCCESS, task_ex.state)
runtime_context = task_ex.runtime_context
self.assertEqual(4, runtime_context['concurrency'])
self.assertEqual(4, task_ex.runtime_context['concurrency'])
def test_concurrency_is_in_runtime_context_from_var(self):
wb_service.create_workbook_v2(CONCURRENCY_WB_FROM_VAR)
@ -1023,12 +1043,13 @@ class PoliciesTest(base.EngineTestCase):
wf_ex = self.engine.start_workflow('wb.wf1', {'concurrency': 4})
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
self.assertEqual(4, task_ex.in_context['concurrency'])
self.assertEqual(4, task_ex.runtime_context['concurrency'])
def test_wrong_policy_prop_type(self):
wb = """---

View File

@ -64,14 +64,6 @@ class WorkflowVariablesTest(base.EngineTestCase):
self.assertEqual(states.SUCCESS, task1.state)
self._assert_dict_contains_subset(
{
'literal_var': 'Literal value',
'yaql_var': 'Hello Renat'
},
task1.in_context
)
self.assertDictEqual(
{
'literal_var': 'Literal value',

View File

@ -158,7 +158,7 @@ class SpecificationCachingTest(base.DbTestCase):
self.assertEqual(0, spec_parser.get_wf_execution_spec_cache_size())
self.assertEqual(2, spec_parser.get_wf_definition_spec_cache_size())
def test_update_workflow_spec_for_execution(self):
def test_cache_workflow_spec_by_execution_id(self):
wf_text = """
version: '2.0'

View File

@ -36,7 +36,9 @@ class DirectWorkflowControllerTest(base.DbTestCase):
id='1-2-3-4',
spec=wf_spec.to_dict(),
state=states.RUNNING,
workflow_id=wfs[0].id
workflow_id=wfs[0].id,
input={},
context={}
)
self.wf_ex = wf_ex

View File

@ -182,16 +182,15 @@ class WorkflowController(object):
# to cover 'split' (aka 'merge') use case.
upstream_task_execs = self._get_upstream_task_executions(task_spec)
upstream_ctx = data_flow.evaluate_upstream_context(upstream_task_execs)
ctx = u.merge_dicts(
copy.deepcopy(self.wf_ex.context),
upstream_ctx
)
ctx = data_flow.evaluate_upstream_context(upstream_task_execs)
# TODO(rakhmerov): Seems like we can fully get rid of '__env' in
# task context if we are OK to have it only in workflow execution
# object (wf_ex.context). Now we can selectively modify env
# for some tasks if we resume or re-run a workflow.
if self.wf_ex.context:
ctx['__env'] = u.merge_dicts(
copy.deepcopy(upstream_ctx.get('__env', {})),
copy.deepcopy(ctx.get('__env', {})),
copy.deepcopy(self.wf_ex.context.get('__env', {}))
)

View File

@ -20,6 +20,7 @@ from oslo_log import log as logging
from mistral import context as auth_ctx
from mistral.db.v2.sqlalchemy import models
from mistral import exceptions as exc
from mistral import expressions as expr
from mistral import utils
from mistral.utils import inspect_utils
@ -31,6 +32,101 @@ LOG = logging.getLogger(__name__)
CONF = cfg.CONF
class ContextView(dict):
    """Workflow context view.

    It's essentially an immutable composite structure providing fast lookup
    over multiple dictionaries w/o having to merge those dictionaries every
    time. The lookup algorithm simply iterates over provided dictionaries
    one by one and returns a value taken from the first dictionary where
    the provided key exists. This means that these dictionaries must be
    provided in the order of decreasing priorities.

    Note: Although this class extends built-in 'dict' it shouldn't be
    considered a normal dictionary because it may not implement all
    methods and account for all corner cases. It's only a read-only view.
    """

    def __init__(self, *dicts):
        super(ContextView, self).__init__()

        self.dicts = dicts or []

    def __getitem__(self, key):
        for d in self.dicts:
            if key in d:
                return d[key]

        raise KeyError(key)

    def get(self, key, default=None):
        for d in self.dicts:
            if key in d:
                return d[key]

        return default

    def __contains__(self, key):
        return any(key in d for d in self.dicts)

    def __iter__(self):
        # NOTE: The underlying 'dict' storage is always empty (nothing is
        # ever inserted into it), so the inherited dict.__iter__ would
        # yield no keys at all. Iterate over the composite key set instead
        # so that 'for k in ctx' and list(ctx) are consistent with keys(),
        # items(), __contains__ and __len__.
        return iter(self.keys())

    def keys(self):
        keys = set()

        for d in self.dicts:
            keys.update(d.keys())

        return keys

    def items(self):
        return [(k, self[k]) for k in self.keys()]

    def values(self):
        return [self[k] for k in self.keys()]

    def iteritems(self):
        # NOTE: This is for compatibility with Python 2.7.
        # YAQL converts output objects after they are evaluated
        # to basic types and it uses six.iteritems() internally
        # which calls d.items() in case of Python 3 and d.iteritems()
        # in case of Python 2.7.
        return iter(self.items())

    def iterkeys(self):
        # NOTE: This is for compatibility with Python 2.7.
        # See the comment for iteritems().
        return iter(self.keys())

    def itervalues(self):
        # NOTE: This is for compatibility with Python 2.7.
        # See the comment for iteritems().
        return iter(self.values())

    def __len__(self):
        return len(self.keys())

    @staticmethod
    def _raise_immutable_error():
        raise exc.MistralError('Context view is immutable.')

    def __setitem__(self, key, value):
        self._raise_immutable_error()

    def update(self, E=None, **F):
        self._raise_immutable_error()

    def clear(self):
        self._raise_immutable_error()

    def pop(self, k, d=None):
        self._raise_immutable_error()

    def popitem(self):
        self._raise_immutable_error()

    def __delitem__(self, key):
        self._raise_immutable_error()
def evaluate_upstream_context(upstream_task_execs):
published_vars = {}
ctx = {}
@ -90,7 +186,13 @@ def publish_variables(task_ex, task_spec):
if task_ex.state != states.SUCCESS:
return
expr_ctx = task_ex.in_context
wf_ex = task_ex.workflow_execution
expr_ctx = ContextView(
task_ex.in_context,
wf_ex.context,
wf_ex.input
)
if task_ex.name in expr_ctx:
LOG.warning(
@ -112,26 +214,28 @@ def evaluate_task_outbound_context(task_ex):
:param task_ex: DB task.
:return: Outbound task Data Flow context.
"""
in_context = (copy.deepcopy(dict(task_ex.in_context))
if task_ex.in_context is not None else {})
in_context = (
copy.deepcopy(dict(task_ex.in_context))
if task_ex.in_context is not None else {}
)
return utils.update_dict(in_context, task_ex.published)
def evaluate_workflow_output(wf_spec, ctx):
def evaluate_workflow_output(wf_ex, wf_spec, ctx):
"""Evaluates workflow output.
:param wf_ex: Workflow execution.
:param wf_spec: Workflow specification.
:param ctx: Final Data Flow context (cause task's outbound context).
"""
ctx = copy.deepcopy(ctx)
output_dict = wf_spec.get_output()
# Evaluate workflow 'publish' clause using the final workflow context.
output = expr.evaluate_recursively(output_dict, ctx)
# Evaluate workflow 'output' clause using the final workflow context.
ctx_view = ContextView(ctx, wf_ex.context, wf_ex.input)
output = expr.evaluate_recursively(output_dict, ctx_view)
# TODO(rakhmerov): Many don't like that we return the whole context
# if 'output' is not explicitly defined.
@ -168,6 +272,7 @@ def add_execution_to_context(wf_ex):
def add_environment_to_context(wf_ex):
# TODO(rakhmerov): This is redundant, we can always get env from WF params
wf_ex.context = wf_ex.context or {}
# If env variables are provided, add an evaluated copy into the context.
@ -181,10 +286,13 @@ def add_environment_to_context(wf_ex):
def add_workflow_variables_to_context(wf_ex, wf_spec):
wf_ex.context = wf_ex.context or {}
return utils.merge_dicts(
wf_ex.context,
expr.evaluate_recursively(wf_spec.get_vars(), wf_ex.context)
)
# The context for calculating workflow variables is workflow input
# and other data already stored in workflow initial context.
ctx_view = ContextView(wf_ex.context, wf_ex.input)
wf_vars = expr.evaluate_recursively(wf_spec.get_vars(), ctx_view)
utils.merge_dicts(wf_ex.context, wf_vars)
def evaluate_object_fields(obj, context):

View File

@ -183,9 +183,15 @@ class DirectWorkflowController(base.WorkflowController):
def all_errors_handled(self):
for t_ex in lookup_utils.find_error_task_executions(self.wf_ex.id):
ctx_view = data_flow.ContextView(
data_flow.evaluate_task_outbound_context(t_ex),
self.wf_ex.context,
self.wf_ex.input
)
tasks_on_error = self._find_next_tasks_for_clause(
self.wf_spec.get_on_error_clause(t_ex.name),
data_flow.evaluate_task_outbound_context(t_ex)
ctx_view
)
if not tasks_on_error:
@ -218,7 +224,11 @@ class DirectWorkflowController(base.WorkflowController):
t_state = task_ex.state
t_name = task_ex.name
ctx = data_flow.evaluate_task_outbound_context(task_ex)
ctx_view = data_flow.ContextView(
data_flow.evaluate_task_outbound_context(task_ex),
self.wf_ex.context,
self.wf_ex.input
)
t_names_and_params = []
@ -226,7 +236,7 @@ class DirectWorkflowController(base.WorkflowController):
t_names_and_params += (
self._find_next_tasks_for_clause(
self.wf_spec.get_on_complete_clause(t_name),
ctx
ctx_view
)
)
@ -234,7 +244,7 @@ class DirectWorkflowController(base.WorkflowController):
t_names_and_params += (
self._find_next_tasks_for_clause(
self.wf_spec.get_on_error_clause(t_name),
ctx
ctx_view
)
)
@ -242,7 +252,7 @@ class DirectWorkflowController(base.WorkflowController):
t_names_and_params += (
self._find_next_tasks_for_clause(
self.wf_spec.get_on_success_clause(t_name),
ctx
ctx_view
)
)