Implementing with-items concurrency
* It allows executor to run at most N actions related to current with-items task at a time. Implements blueprint: with-items-concurrency Related-Bug: #1476871 Related-Bug: #1484483 Change-Id: Ifb193f06ae8e9becb6c43f36c07f5ffddb04c71c
This commit is contained in:
parent
3326affb89
commit
723a182a89
@ -142,9 +142,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
||||
task_ex = db_api.get_task_execution(task_ex_id)
|
||||
# TODO(rakhmerov): The method is mostly needed for policy and
|
||||
# we are supposed to get the same action execution as when the
|
||||
# policy worked. But by the moment this method is called the
|
||||
# last execution object may have changed. It's a race condition.
|
||||
execution = task_ex.executions[-1]
|
||||
# policy worked.
|
||||
|
||||
wf_ex_id = task_ex.workflow_execution_id
|
||||
|
||||
@ -161,13 +159,13 @@ class DefaultEngine(base.Engine, coordination.Service):
|
||||
|
||||
task_ex.state = state
|
||||
|
||||
self._on_task_state_change(task_ex, wf_ex, action_ex=execution)
|
||||
self._on_task_state_change(task_ex, wf_ex)
|
||||
|
||||
def _on_task_state_change(self, task_ex, wf_ex, action_ex=None):
|
||||
def _on_task_state_change(self, task_ex, wf_ex):
|
||||
task_spec = spec_parser.get_task_spec(task_ex.spec)
|
||||
wf_spec = spec_parser.get_workflow_spec(wf_ex.spec)
|
||||
|
||||
if states.is_completed(task_ex.state):
|
||||
if task_handler.is_task_completed(task_ex, task_spec):
|
||||
task_handler.after_task_complete(task_ex, task_spec, wf_spec)
|
||||
|
||||
# Ignore DELAYED state.
|
||||
@ -184,6 +182,11 @@ class DefaultEngine(base.Engine, coordination.Service):
|
||||
self._dispatch_workflow_commands(wf_ex, cmds)
|
||||
|
||||
self._check_workflow_completion(wf_ex, wf_ctrl)
|
||||
elif task_handler.need_to_continue(task_ex, task_spec):
|
||||
# Re-run existing task.
|
||||
cmds = [commands.RunExistingTask(task_ex, reset=False)]
|
||||
|
||||
self._dispatch_workflow_commands(wf_ex, cmds)
|
||||
|
||||
@staticmethod
|
||||
def _check_workflow_completion(wf_ex, wf_ctrl):
|
||||
@ -233,7 +236,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
||||
if states.is_paused_or_completed(wf_ex.state):
|
||||
return action_ex
|
||||
|
||||
self._on_task_state_change(task_ex, wf_ex, action_ex)
|
||||
self._on_task_state_change(task_ex, wf_ex)
|
||||
|
||||
return action_ex.get_clone()
|
||||
except Exception as e:
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import operator
|
||||
|
||||
from oslo_log import log as logging
|
||||
|
||||
@ -62,6 +63,8 @@ def run_existing_task(task_ex_id, reset=True):
|
||||
not task_spec.get_with_items()):
|
||||
return
|
||||
|
||||
# Reset nested executions only if task is not already RUNNING.
|
||||
if task_ex.state != states.RUNNING:
|
||||
# Reset state of processed task and related action executions.
|
||||
if reset:
|
||||
action_exs = task_ex.executions
|
||||
@ -90,26 +93,9 @@ def _run_existing_task(task_ex, task_spec, wf_spec):
|
||||
task_ex.in_context
|
||||
)
|
||||
|
||||
# TODO(rakhmerov): May be it shouldn't be here. Need to think.
|
||||
if task_spec.get_with_items():
|
||||
with_items.prepare_runtime_context(task_ex, task_spec, input_dicts)
|
||||
|
||||
action_exs = db_api.get_action_executions(
|
||||
task_execution_id=task_ex.id,
|
||||
state=states.SUCCESS,
|
||||
accepted=True
|
||||
)
|
||||
|
||||
with_items_indices = [
|
||||
action_ex.runtime_context['with_items_index']
|
||||
for action_ex in action_exs
|
||||
if 'with_items_index' in action_ex.runtime_context
|
||||
]
|
||||
|
||||
# In some cases we can have no input, e.g. in case of 'with-items'.
|
||||
if input_dicts:
|
||||
for index, input_d in enumerate(input_dicts):
|
||||
if index not in with_items_indices:
|
||||
for index, input_d in input_dicts:
|
||||
_run_action_or_workflow(task_ex, task_spec, input_d, index)
|
||||
else:
|
||||
_schedule_noop_action(task_ex, task_spec)
|
||||
@ -200,9 +186,13 @@ def on_action_complete(action_ex, result):
|
||||
if not task_spec.get_with_items():
|
||||
_complete_task(task_ex, task_spec, task_state)
|
||||
else:
|
||||
if (task_state == states.ERROR or
|
||||
with_items.iterations_completed(task_ex)):
|
||||
_complete_task(task_ex, task_spec, task_state)
|
||||
with_items.increase_capacity(task_ex)
|
||||
if with_items.is_completed(task_ex):
|
||||
_complete_task(
|
||||
task_ex,
|
||||
task_spec,
|
||||
with_items.get_final_state(task_ex)
|
||||
)
|
||||
|
||||
return task_ex
|
||||
|
||||
@ -245,6 +235,9 @@ def _get_input_dictionaries(wf_spec, task_ex, task_spec, ctx):
|
||||
should run with.
|
||||
In case of 'with-items' the result list will contain input dictionaries
|
||||
for all 'with-items' iterations correspondingly.
|
||||
|
||||
:return the list of tuples containing indexes
|
||||
and the corresponding input dict.
|
||||
"""
|
||||
# TODO(rakhmerov): Think how to get rid of this.
|
||||
ctx = data_flow.extract_task_result_proxies_to_context(ctx)
|
||||
@ -257,7 +250,7 @@ def _get_input_dictionaries(wf_spec, task_ex, task_spec, ctx):
|
||||
ctx
|
||||
)
|
||||
|
||||
return [input_dict]
|
||||
return enumerate([input_dict])
|
||||
else:
|
||||
return _get_with_items_input(wf_spec, task_ex, task_spec, ctx)
|
||||
|
||||
@ -297,7 +290,8 @@ def _get_with_items_input(wf_spec, task_ex, task_spec, ctx):
|
||||
{'itemX': 2, 'itemY': 'b'}
|
||||
]
|
||||
|
||||
:return: list containing dicts of each action input.
|
||||
:return: the list of tuples containing indexes
|
||||
and the corresponding input dict.
|
||||
"""
|
||||
with_items_inputs = expr.evaluate_recursively(
|
||||
task_spec.get_with_items(), ctx
|
||||
@ -325,7 +319,21 @@ def _get_with_items_input(wf_spec, task_ex, task_spec, ctx):
|
||||
wf_spec, task_ex, task_spec, new_ctx
|
||||
))
|
||||
|
||||
return action_inputs
|
||||
with_items.prepare_runtime_context(task_ex, task_spec, action_inputs)
|
||||
|
||||
indices = with_items.get_indices_for_loop(task_ex)
|
||||
with_items.decrease_capacity(task_ex, len(indices))
|
||||
|
||||
if indices:
|
||||
current_inputs = operator.itemgetter(*indices)(action_inputs)
|
||||
|
||||
return zip(
|
||||
indices,
|
||||
current_inputs if isinstance(current_inputs, tuple)
|
||||
else [current_inputs]
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _get_action_input(wf_spec, task_ex, task_spec, ctx):
|
||||
@ -507,3 +515,19 @@ def _set_task_state(task_ex, state):
|
||||
)
|
||||
|
||||
task_ex.state = state
|
||||
|
||||
|
||||
def is_task_completed(task_ex, task_spec):
|
||||
if task_spec.get_with_items():
|
||||
return with_items.is_completed(task_ex)
|
||||
|
||||
return states.is_completed(task_ex.state)
|
||||
|
||||
|
||||
def need_to_continue(task_ex, task_spec):
|
||||
# For now continue is available only for with-items.
|
||||
if task_spec.get_with_items():
|
||||
return (with_items.has_more_iterations(task_ex)
|
||||
and with_items.get_concurrency_spec(task_spec))
|
||||
|
||||
return False
|
||||
|
@ -67,6 +67,28 @@ workflows:
|
||||
action: std.echo output="Task 2"
|
||||
"""
|
||||
|
||||
|
||||
WITH_ITEMS_WORKBOOK_CONCURRENCY = """
|
||||
---
|
||||
version: '2.0'
|
||||
name: wb3
|
||||
workflows:
|
||||
wf1:
|
||||
type: direct
|
||||
tasks:
|
||||
t1:
|
||||
with-items: i in <% list(range(0, 4)) %>
|
||||
action: std.echo output="Task 1.<% $.i %>"
|
||||
concurrency: 2
|
||||
publish:
|
||||
v1: <% $.t1 %>
|
||||
on-success:
|
||||
- t2
|
||||
t2:
|
||||
action: std.echo output="Task 2"
|
||||
"""
|
||||
|
||||
|
||||
JOIN_WORKBOOK = """
|
||||
---
|
||||
version: '2.0'
|
||||
@ -271,6 +293,77 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
|
||||
|
||||
self.assertEqual(1, len(task_2_action_exs))
|
||||
|
||||
@mock.patch.object(
|
||||
std_actions.EchoAction,
|
||||
'run',
|
||||
mock.MagicMock(
|
||||
side_effect=[
|
||||
exc.ActionException(), # Mock task1 exception for initial run.
|
||||
'Task 1.1', # Mock task1 success for initial run.
|
||||
exc.ActionException(), # Mock task1 exception for initial run.
|
||||
'Task 1.3', # Mock task1 success for initial run.
|
||||
'Task 1.0', # Mock task1 success for rerun.
|
||||
'Task 1.2', # Mock task1 success for rerun.
|
||||
'Task 2' # Mock task2 success.
|
||||
]
|
||||
)
|
||||
)
|
||||
def test_rerun_with_items_concurrency(self):
|
||||
wb_service.create_workbook_v2(WITH_ITEMS_WORKBOOK_CONCURRENCY)
|
||||
|
||||
# Run workflow and fail task.
|
||||
wf_ex = self.engine.start_workflow('wb3.wf1', {})
|
||||
self._await(lambda: self.is_execution_error(wf_ex.id))
|
||||
|
||||
wf_ex = db_api.get_workflow_execution(wf_ex.id)
|
||||
|
||||
self.assertEqual(states.ERROR, wf_ex.state)
|
||||
self.assertEqual(1, len(wf_ex.task_executions))
|
||||
|
||||
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
|
||||
|
||||
self.assertEqual(states.ERROR, task_1_ex.state)
|
||||
|
||||
task_1_action_exs = db_api.get_action_executions(
|
||||
task_execution_id=task_1_ex.id
|
||||
)
|
||||
|
||||
self.assertEqual(4, len(task_1_action_exs))
|
||||
|
||||
# Resume workflow and re-run failed task.
|
||||
self.engine.rerun_workflow(wf_ex.id, task_1_ex.id, reset=False)
|
||||
wf_ex = db_api.get_workflow_execution(wf_ex.id)
|
||||
|
||||
self.assertEqual(states.RUNNING, wf_ex.state)
|
||||
|
||||
self._await(lambda: self.is_execution_success(wf_ex.id), delay=10)
|
||||
wf_ex = db_api.get_workflow_execution(wf_ex.id)
|
||||
|
||||
self.assertEqual(states.SUCCESS, wf_ex.state)
|
||||
self.assertEqual(2, len(wf_ex.task_executions))
|
||||
|
||||
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
|
||||
task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')
|
||||
|
||||
# Check action executions of task 1.
|
||||
self.assertEqual(states.SUCCESS, task_1_ex.state)
|
||||
|
||||
task_1_action_exs = db_api.get_action_executions(
|
||||
task_execution_id=task_1_ex.id)
|
||||
|
||||
# The action executions that succeeded should not re-run.
|
||||
self.assertEqual(6, len(task_1_action_exs))
|
||||
self.assertListEqual(['Task 1.0', 'Task 1.1', 'Task 1.2', 'Task 1.3'],
|
||||
task_1_ex.published.get('v1'))
|
||||
|
||||
# Check action executions of task 2.
|
||||
self.assertEqual(states.SUCCESS, task_2_ex.state)
|
||||
|
||||
task_2_action_exs = db_api.get_action_executions(
|
||||
task_execution_id=task_2_ex.id)
|
||||
|
||||
self.assertEqual(1, len(task_2_action_exs))
|
||||
|
||||
@mock.patch.object(
|
||||
std_actions.EchoAction,
|
||||
'run',
|
||||
|
@ -20,6 +20,7 @@ from mistral.actions import base as action_base
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral import exceptions as exc
|
||||
from mistral.services import workbooks as wb_service
|
||||
from mistral.services import workflows as wf_service
|
||||
from mistral.tests import base as test_base
|
||||
from mistral.tests.unit.engine import base
|
||||
from mistral import utils
|
||||
@ -161,6 +162,21 @@ class RandomSleepEchoAction(action_base.Action):
|
||||
|
||||
|
||||
class WithItemsEngineTest(base.EngineTestCase):
|
||||
def assert_capacity(self, capacity, task_ex):
|
||||
self.assertEqual(
|
||||
capacity,
|
||||
task_ex.runtime_context['with_items_context']['capacity']
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_incomplete_action_ex(task_ex):
|
||||
return [ex for ex in task_ex.executions if not ex.accepted][0]
|
||||
|
||||
@staticmethod
|
||||
def get_running_action_exs_number(task_ex):
|
||||
return len([ex for ex in task_ex.executions
|
||||
if ex.state == states.RUNNING])
|
||||
|
||||
def test_with_items_simple(self):
|
||||
wb_service.create_workbook_v2(WORKBOOK)
|
||||
|
||||
@ -176,7 +192,7 @@ class WithItemsEngineTest(base.EngineTestCase):
|
||||
|
||||
tasks = wf_ex.task_executions
|
||||
task1 = self._assert_single_item(tasks, name='task1')
|
||||
with_items_context = task1.runtime_context['with_items']
|
||||
with_items_context = task1.runtime_context['with_items_context']
|
||||
|
||||
self.assertEqual(3, with_items_context['count'])
|
||||
|
||||
@ -197,6 +213,37 @@ class WithItemsEngineTest(base.EngineTestCase):
|
||||
self.assertEqual(1, len(tasks))
|
||||
self.assertEqual(states.SUCCESS, task1.state)
|
||||
|
||||
def test_with_items_fail(self):
|
||||
workflow = """---
|
||||
version: "2.0"
|
||||
|
||||
with_items:
|
||||
type: direct
|
||||
|
||||
tasks:
|
||||
task1:
|
||||
with-items: i in [1, 2, 3]
|
||||
action: std.fail
|
||||
on-error: task2
|
||||
|
||||
task2:
|
||||
action: std.echo output="With-items failed"
|
||||
"""
|
||||
wf_service.create_workflows(workflow)
|
||||
|
||||
# Start workflow.
|
||||
wf_ex = self.engine.start_workflow('with_items', {})
|
||||
|
||||
self._await(
|
||||
lambda: self.is_execution_success(wf_ex.id),
|
||||
)
|
||||
|
||||
# Note: We need to reread execution to access related tasks.
|
||||
wf_ex = db_api.get_workflow_execution(wf_ex.id)
|
||||
|
||||
tasks = wf_ex.task_executions
|
||||
self.assertEqual(2, len(tasks))
|
||||
|
||||
def test_with_items_static_var(self):
|
||||
wb_service.create_workbook_v2(WORKBOOK_WITH_STATIC_VAR)
|
||||
|
||||
@ -468,3 +515,268 @@ class WithItemsEngineTest(base.EngineTestCase):
|
||||
|
||||
self.assertEqual(1, len(tasks))
|
||||
self.assertEqual(states.SUCCESS, task1.state)
|
||||
|
||||
def test_with_items_concurrency_1(self):
|
||||
workflow_with_concurrency_1 = """---
|
||||
version: "2.0"
|
||||
|
||||
concurrency_test:
|
||||
type: direct
|
||||
|
||||
input:
|
||||
- names: ["John", "Ivan", "Mistral"]
|
||||
|
||||
tasks:
|
||||
task1:
|
||||
action: std.async_noop
|
||||
with-items: name in <% $.names %>
|
||||
concurrency: 1
|
||||
|
||||
"""
|
||||
wf_service.create_workflows(workflow_with_concurrency_1)
|
||||
|
||||
# Start workflow.
|
||||
wf_ex = self.engine.start_workflow('concurrency_test', {})
|
||||
wf_ex = db_api.get_execution(wf_ex.id)
|
||||
task_ex = wf_ex.task_executions[0]
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(0, task_ex)
|
||||
self.assertEqual(1, self.get_running_action_exs_number(task_ex))
|
||||
|
||||
# 1st iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("John")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(0, task_ex)
|
||||
self.assertEqual(1, self.get_running_action_exs_number(task_ex))
|
||||
|
||||
# 2nd iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("Ivan")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(0, task_ex)
|
||||
self.assertEqual(1, self.get_running_action_exs_number(task_ex))
|
||||
|
||||
# 3rd iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("Mistral")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(1, task_ex)
|
||||
|
||||
self._await(
|
||||
lambda: self.is_execution_success(wf_ex.id),
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
# Since we know that we can receive results in random order,
|
||||
# check is not depend on order of items.
|
||||
result = data_flow.get_task_execution_result(task_ex)
|
||||
|
||||
self.assertTrue(isinstance(result, list))
|
||||
|
||||
self.assertIn('John', result)
|
||||
self.assertIn('Ivan', result)
|
||||
self.assertIn('Mistral', result)
|
||||
|
||||
self.assertEqual(states.SUCCESS, task_ex.state)
|
||||
|
||||
def test_with_items_concurrency_2(self):
|
||||
workflow_with_concurrency_2 = """---
|
||||
version: "2.0"
|
||||
|
||||
concurrency_test:
|
||||
type: direct
|
||||
|
||||
input:
|
||||
- names: ["John", "Ivan", "Mistral", "Hello"]
|
||||
|
||||
tasks:
|
||||
task1:
|
||||
action: std.async_noop
|
||||
with-items: name in <% $.names %>
|
||||
concurrency: 2
|
||||
|
||||
"""
|
||||
wf_service.create_workflows(workflow_with_concurrency_2)
|
||||
|
||||
# Start workflow.
|
||||
wf_ex = self.engine.start_workflow('concurrency_test', {})
|
||||
wf_ex = db_api.get_execution(wf_ex.id)
|
||||
task_ex = wf_ex.task_executions[0]
|
||||
|
||||
self.assert_capacity(0, task_ex)
|
||||
self.assertEqual(2, self.get_running_action_exs_number(task_ex))
|
||||
|
||||
# 1st iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("John")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(0, task_ex)
|
||||
self.assertEqual(2, self.get_running_action_exs_number(task_ex))
|
||||
|
||||
# 2nd iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("Ivan")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(0, task_ex)
|
||||
self.assertEqual(2, self.get_running_action_exs_number(task_ex))
|
||||
|
||||
# 3rd iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("Mistral")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(1, task_ex)
|
||||
|
||||
# 4th iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("Hello")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(2, task_ex)
|
||||
|
||||
self._await(
|
||||
lambda: self.is_execution_success(wf_ex.id),
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
# Since we know that we can receive results in random order,
|
||||
# check is not depend on order of items.
|
||||
result = data_flow.get_task_execution_result(task_ex)
|
||||
self.assertTrue(isinstance(result, list))
|
||||
|
||||
self.assertIn('John', result)
|
||||
self.assertIn('Ivan', result)
|
||||
self.assertIn('Mistral', result)
|
||||
self.assertIn('Hello', result)
|
||||
|
||||
self.assertEqual(states.SUCCESS, task_ex.state)
|
||||
|
||||
def test_with_items_concurrency_2_fail(self):
|
||||
workflow_with_concurrency_2_fail = """---
|
||||
version: "2.0"
|
||||
|
||||
concurrency_test_fail:
|
||||
type: direct
|
||||
|
||||
tasks:
|
||||
task1:
|
||||
with-items: i in [1, 2, 3, 4]
|
||||
action: std.fail
|
||||
concurrency: 2
|
||||
on-error: task2
|
||||
|
||||
task2:
|
||||
action: std.echo output="With-items failed"
|
||||
|
||||
"""
|
||||
wf_service.create_workflows(workflow_with_concurrency_2_fail)
|
||||
|
||||
# Start workflow.
|
||||
wf_ex = self.engine.start_workflow('concurrency_test_fail', {})
|
||||
|
||||
self._await(
|
||||
lambda: self.is_execution_success(wf_ex.id),
|
||||
)
|
||||
wf_ex = db_api.get_execution(wf_ex.id)
|
||||
|
||||
task_exs = wf_ex.task_executions
|
||||
|
||||
self.assertEqual(2, len(task_exs))
|
||||
|
||||
task_2 = self._assert_single_item(task_exs, name='task2')
|
||||
|
||||
self.assertEqual(
|
||||
"With-items failed",
|
||||
data_flow.get_task_execution_result(task_2)
|
||||
)
|
||||
|
||||
def test_with_items_concurrency_3(self):
|
||||
workflow_with_concurrency_3 = """---
|
||||
version: "2.0"
|
||||
|
||||
concurrency_test:
|
||||
type: direct
|
||||
|
||||
input:
|
||||
- names: ["John", "Ivan", "Mistral"]
|
||||
|
||||
tasks:
|
||||
task1:
|
||||
action: std.async_noop
|
||||
with-items: name in <% $.names %>
|
||||
concurrency: 3
|
||||
|
||||
"""
|
||||
wf_service.create_workflows(workflow_with_concurrency_3)
|
||||
|
||||
# Start workflow.
|
||||
wf_ex = self.engine.start_workflow('concurrency_test', {})
|
||||
wf_ex = db_api.get_execution(wf_ex.id)
|
||||
task_ex = wf_ex.task_executions[0]
|
||||
|
||||
self.assert_capacity(0, task_ex)
|
||||
self.assertEqual(3, self.get_running_action_exs_number(task_ex))
|
||||
|
||||
# 1st iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("John")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(1, task_ex)
|
||||
|
||||
# 2nd iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("Ivan")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(2, task_ex)
|
||||
|
||||
# 3rd iteration complete.
|
||||
self.engine.on_action_complete(
|
||||
self.get_incomplete_action_ex(task_ex).id,
|
||||
wf_utils.Result("Mistral")
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
self.assert_capacity(3, task_ex)
|
||||
|
||||
self._await(
|
||||
lambda: self.is_execution_success(wf_ex.id),
|
||||
)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex.id)
|
||||
# Since we know that we can receive results in random order,
|
||||
# check is not depend on order of items.
|
||||
result = data_flow.get_task_execution_result(task_ex)
|
||||
self.assertTrue(isinstance(result, list))
|
||||
|
||||
self.assertIn('John', result)
|
||||
self.assertIn('Ivan', result)
|
||||
self.assertIn('Mistral', result)
|
||||
|
||||
self.assertEqual(states.SUCCESS, task_ex.state)
|
||||
|
55
mistral/tests/unit/workflow/test_with_items.py
Normal file
55
mistral/tests/unit/workflow/test_with_items.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2015 - Mirantis, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from mistral.db.v2.sqlalchemy import models
|
||||
from mistral.tests import base
|
||||
from mistral.workflow import states
|
||||
from mistral.workflow import with_items
|
||||
|
||||
|
||||
class WithItemsTest(base.BaseTest):
|
||||
@staticmethod
|
||||
def get_action_ex(accepted, state, index):
|
||||
return models.ActionExecution(
|
||||
accepted=accepted,
|
||||
state=state,
|
||||
runtime_context={
|
||||
'with_items_index': index
|
||||
}
|
||||
)
|
||||
|
||||
def test_get_indices(self):
|
||||
# Task execution for running 6 items with concurrency=3.
|
||||
task_ex = models.TaskExecution(
|
||||
runtime_context={
|
||||
'with_items_context': {
|
||||
'capacity': 3,
|
||||
'count': 6
|
||||
}
|
||||
},
|
||||
executions=[]
|
||||
)
|
||||
|
||||
# Set 3 items: 2 success and 1 error unaccepted.
|
||||
task_ex.executions += [
|
||||
self.get_action_ex(True, states.SUCCESS, 0),
|
||||
self.get_action_ex(True, states.SUCCESS, 1),
|
||||
self.get_action_ex(False, states.ERROR, 2)
|
||||
]
|
||||
|
||||
# Then call get_indices and expect [2, 3, 4].
|
||||
indices = with_items.get_indices_for_loop(task_ex)
|
||||
|
||||
self.assertListEqual([2, 3, 4], indices)
|
@ -12,20 +12,42 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import six
|
||||
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral import exceptions as exc
|
||||
from mistral.workflow import states
|
||||
|
||||
|
||||
_CAPACITY = 'capacity'
|
||||
_COUNT = 'count'
|
||||
_WITH_ITEMS = 'with_items_context'
|
||||
|
||||
|
||||
def _get_context(task_ex):
|
||||
return task_ex.runtime_context['with_items']
|
||||
return task_ex.runtime_context[_WITH_ITEMS]
|
||||
|
||||
|
||||
def get_count(task_ex):
|
||||
return _get_context(task_ex)['count']
|
||||
return _get_context(task_ex)[_COUNT]
|
||||
|
||||
|
||||
def is_completed(task_ex):
|
||||
action_exs = db_api.get_action_executions(
|
||||
task_execution_id=task_ex.id,
|
||||
accepted=True
|
||||
)
|
||||
count = get_count(task_ex) or 1
|
||||
|
||||
return count == len(action_exs)
|
||||
|
||||
|
||||
def get_index(task_ex):
|
||||
return _get_context(task_ex)['index']
|
||||
return len(
|
||||
filter(
|
||||
lambda x: x.accepted or states.RUNNING, task_ex.executions
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_concurrency_spec(task_spec):
|
||||
@ -34,38 +56,88 @@ def get_concurrency_spec(task_spec):
|
||||
return policies.get_concurrency() if policies else None
|
||||
|
||||
|
||||
def get_indexes_for_loop(task_ex, task_spec):
|
||||
concurrency_spec = get_concurrency_spec(task_spec)
|
||||
concurrency = task_ex.runtime_context['concurrency']
|
||||
def get_final_state(task_ex):
|
||||
find_error = lambda x: x.accepted and x.state == states.ERROR
|
||||
|
||||
if filter(find_error, task_ex.executions):
|
||||
return states.ERROR
|
||||
else:
|
||||
return states.SUCCESS
|
||||
|
||||
|
||||
def _get_indices_if_rerun(unaccepted_executions):
|
||||
"""Returns a list of indices in case of re-running with-items.
|
||||
|
||||
:param unaccepted_executions: List of executions.
|
||||
:return: a list of numbers.
|
||||
"""
|
||||
|
||||
return [
|
||||
ex.runtime_context['with_items_index']
|
||||
for ex in unaccepted_executions
|
||||
]
|
||||
|
||||
|
||||
def _get_unaccepted_act_exs(task_ex):
|
||||
# Choose only if not accepted but completed.
|
||||
return filter(
|
||||
lambda x: not x.accepted and states.is_completed(x.state),
|
||||
task_ex.executions
|
||||
)
|
||||
|
||||
|
||||
def get_indices_for_loop(task_ex):
|
||||
capacity = _get_context(task_ex)[_CAPACITY]
|
||||
count = get_count(task_ex)
|
||||
|
||||
unaccepted = _get_unaccepted_act_exs(task_ex)
|
||||
|
||||
if unaccepted:
|
||||
indices = _get_indices_if_rerun(unaccepted)
|
||||
|
||||
if max(indices) < count - 1:
|
||||
indices += list(six.moves.range(max(indices) + 1, count))
|
||||
|
||||
return indices[:capacity] if capacity else indices
|
||||
|
||||
index = get_index(task_ex)
|
||||
|
||||
number_to_execute = (get_count(task_ex) - index
|
||||
if not concurrency_spec else concurrency)
|
||||
number_to_execute = capacity if capacity else count - index
|
||||
|
||||
return index, index + number_to_execute
|
||||
return list(six.moves.range(index, index + number_to_execute))
|
||||
|
||||
|
||||
def do_step(task_ex):
|
||||
def decrease_capacity(task_ex, count):
|
||||
with_items_context = _get_context(task_ex)
|
||||
|
||||
if with_items_context['capacity'] > 0:
|
||||
with_items_context['capacity'] -= 1
|
||||
if with_items_context[_CAPACITY] >= count:
|
||||
with_items_context[_CAPACITY] -= count
|
||||
elif with_items_context[_CAPACITY]:
|
||||
raise exc.WorkflowException(
|
||||
"Impossible to apply current with-items concurrency."
|
||||
)
|
||||
|
||||
with_items_context['index'] += 1
|
||||
task_ex.runtime_context.update({_WITH_ITEMS: with_items_context})
|
||||
|
||||
task_ex.runtime_context.update({'with_items': with_items_context})
|
||||
|
||||
def increase_capacity(task_ex):
|
||||
with_items_context = _get_context(task_ex)
|
||||
max_concurrency = task_ex.runtime_context.get('concurrency')
|
||||
|
||||
if max_concurrency and with_items_context[_CAPACITY] < max_concurrency:
|
||||
with_items_context[_CAPACITY] += 1
|
||||
task_ex.runtime_context.update({_WITH_ITEMS: with_items_context})
|
||||
|
||||
|
||||
def prepare_runtime_context(task_ex, task_spec, input_dicts):
|
||||
runtime_context = task_ex.runtime_context
|
||||
with_items_spec = task_spec.get_with_items()
|
||||
|
||||
if with_items_spec:
|
||||
if with_items_spec and not runtime_context.get(_WITH_ITEMS):
|
||||
# Prepare current indexes and parallel limitation.
|
||||
runtime_context['with_items'] = {
|
||||
'capacity': get_concurrency_spec(task_spec),
|
||||
'index': 0,
|
||||
'count': len(input_dicts)
|
||||
runtime_context[_WITH_ITEMS] = {
|
||||
_CAPACITY: get_concurrency_spec(task_spec) or None,
|
||||
_COUNT: len(input_dicts)
|
||||
}
|
||||
|
||||
|
||||
@ -88,7 +160,12 @@ def validate_input(with_items_input):
|
||||
)
|
||||
|
||||
|
||||
def iterations_completed(task_ex):
|
||||
completed = all([states.is_completed(ex.state)
|
||||
for ex in task_ex.executions])
|
||||
return completed
|
||||
def has_more_iterations(task_ex):
|
||||
# See action executions which have been already
|
||||
# accepted or are still running.
|
||||
action_exs = filter(
|
||||
lambda x: x.accepted or x.state == states.RUNNING,
|
||||
task_ex.executions
|
||||
)
|
||||
|
||||
return get_count(task_ex) > len(action_exs)
|
||||
|
Loading…
Reference in New Issue
Block a user