Implementing with-items concurrency

* It allows executor to run at most N actions
   related to current with-items task at a time.

Implements blueprint: with-items-concurrency
Related-Bug: #1476871
Related-Bug: #1484483

Change-Id: Ifb193f06ae8e9becb6c43f36c07f5ffddb04c71c
This commit is contained in:
Nikolay Mahotkin 2015-07-24 14:11:26 +03:00
parent 3326affb89
commit 723a182a89
6 changed files with 631 additions and 67 deletions

View File

@ -142,9 +142,7 @@ class DefaultEngine(base.Engine, coordination.Service):
task_ex = db_api.get_task_execution(task_ex_id) task_ex = db_api.get_task_execution(task_ex_id)
# TODO(rakhmerov): The method is mostly needed for policy and # TODO(rakhmerov): The method is mostly needed for policy and
# we are supposed to get the same action execution as when the # we are supposed to get the same action execution as when the
# policy worked. But by the moment this method is called the # policy worked.
# last execution object may have changed. It's a race condition.
execution = task_ex.executions[-1]
wf_ex_id = task_ex.workflow_execution_id wf_ex_id = task_ex.workflow_execution_id
@ -161,13 +159,13 @@ class DefaultEngine(base.Engine, coordination.Service):
task_ex.state = state task_ex.state = state
self._on_task_state_change(task_ex, wf_ex, action_ex=execution) self._on_task_state_change(task_ex, wf_ex)
def _on_task_state_change(self, task_ex, wf_ex, action_ex=None): def _on_task_state_change(self, task_ex, wf_ex):
task_spec = spec_parser.get_task_spec(task_ex.spec) task_spec = spec_parser.get_task_spec(task_ex.spec)
wf_spec = spec_parser.get_workflow_spec(wf_ex.spec) wf_spec = spec_parser.get_workflow_spec(wf_ex.spec)
if states.is_completed(task_ex.state): if task_handler.is_task_completed(task_ex, task_spec):
task_handler.after_task_complete(task_ex, task_spec, wf_spec) task_handler.after_task_complete(task_ex, task_spec, wf_spec)
# Ignore DELAYED state. # Ignore DELAYED state.
@ -184,6 +182,11 @@ class DefaultEngine(base.Engine, coordination.Service):
self._dispatch_workflow_commands(wf_ex, cmds) self._dispatch_workflow_commands(wf_ex, cmds)
self._check_workflow_completion(wf_ex, wf_ctrl) self._check_workflow_completion(wf_ex, wf_ctrl)
elif task_handler.need_to_continue(task_ex, task_spec):
# Re-run existing task.
cmds = [commands.RunExistingTask(task_ex, reset=False)]
self._dispatch_workflow_commands(wf_ex, cmds)
@staticmethod @staticmethod
def _check_workflow_completion(wf_ex, wf_ctrl): def _check_workflow_completion(wf_ex, wf_ctrl):
@ -233,7 +236,7 @@ class DefaultEngine(base.Engine, coordination.Service):
if states.is_paused_or_completed(wf_ex.state): if states.is_paused_or_completed(wf_ex.state):
return action_ex return action_ex
self._on_task_state_change(task_ex, wf_ex, action_ex) self._on_task_state_change(task_ex, wf_ex)
return action_ex.get_clone() return action_ex.get_clone()
except Exception as e: except Exception as e:

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
import operator
from oslo_log import log as logging from oslo_log import log as logging
@ -62,6 +63,8 @@ def run_existing_task(task_ex_id, reset=True):
not task_spec.get_with_items()): not task_spec.get_with_items()):
return return
# Reset nested executions only if task is not already RUNNING.
if task_ex.state != states.RUNNING:
# Reset state of processed task and related action executions. # Reset state of processed task and related action executions.
if reset: if reset:
action_exs = task_ex.executions action_exs = task_ex.executions
@ -90,26 +93,9 @@ def _run_existing_task(task_ex, task_spec, wf_spec):
task_ex.in_context task_ex.in_context
) )
# TODO(rakhmerov): May be it shouldn't be here. Need to think.
if task_spec.get_with_items():
with_items.prepare_runtime_context(task_ex, task_spec, input_dicts)
action_exs = db_api.get_action_executions(
task_execution_id=task_ex.id,
state=states.SUCCESS,
accepted=True
)
with_items_indices = [
action_ex.runtime_context['with_items_index']
for action_ex in action_exs
if 'with_items_index' in action_ex.runtime_context
]
# In some cases we can have no input, e.g. in case of 'with-items'. # In some cases we can have no input, e.g. in case of 'with-items'.
if input_dicts: if input_dicts:
for index, input_d in enumerate(input_dicts): for index, input_d in input_dicts:
if index not in with_items_indices:
_run_action_or_workflow(task_ex, task_spec, input_d, index) _run_action_or_workflow(task_ex, task_spec, input_d, index)
else: else:
_schedule_noop_action(task_ex, task_spec) _schedule_noop_action(task_ex, task_spec)
@ -200,9 +186,13 @@ def on_action_complete(action_ex, result):
if not task_spec.get_with_items(): if not task_spec.get_with_items():
_complete_task(task_ex, task_spec, task_state) _complete_task(task_ex, task_spec, task_state)
else: else:
if (task_state == states.ERROR or with_items.increase_capacity(task_ex)
with_items.iterations_completed(task_ex)): if with_items.is_completed(task_ex):
_complete_task(task_ex, task_spec, task_state) _complete_task(
task_ex,
task_spec,
with_items.get_final_state(task_ex)
)
return task_ex return task_ex
@ -245,6 +235,9 @@ def _get_input_dictionaries(wf_spec, task_ex, task_spec, ctx):
should run with. should run with.
In case of 'with-items' the result list will contain input dictionaries In case of 'with-items' the result list will contain input dictionaries
for all 'with-items' iterations correspondingly. for all 'with-items' iterations correspondingly.
:return the list of tuples containing indexes
and the corresponding input dict.
""" """
# TODO(rakhmerov): Think how to get rid of this. # TODO(rakhmerov): Think how to get rid of this.
ctx = data_flow.extract_task_result_proxies_to_context(ctx) ctx = data_flow.extract_task_result_proxies_to_context(ctx)
@ -257,7 +250,7 @@ def _get_input_dictionaries(wf_spec, task_ex, task_spec, ctx):
ctx ctx
) )
return [input_dict] return enumerate([input_dict])
else: else:
return _get_with_items_input(wf_spec, task_ex, task_spec, ctx) return _get_with_items_input(wf_spec, task_ex, task_spec, ctx)
@ -297,7 +290,8 @@ def _get_with_items_input(wf_spec, task_ex, task_spec, ctx):
{'itemX': 2, 'itemY': 'b'} {'itemX': 2, 'itemY': 'b'}
] ]
:return: list containing dicts of each action input. :return: the list of tuples containing indexes
and the corresponding input dict.
""" """
with_items_inputs = expr.evaluate_recursively( with_items_inputs = expr.evaluate_recursively(
task_spec.get_with_items(), ctx task_spec.get_with_items(), ctx
@ -325,7 +319,21 @@ def _get_with_items_input(wf_spec, task_ex, task_spec, ctx):
wf_spec, task_ex, task_spec, new_ctx wf_spec, task_ex, task_spec, new_ctx
)) ))
return action_inputs with_items.prepare_runtime_context(task_ex, task_spec, action_inputs)
indices = with_items.get_indices_for_loop(task_ex)
with_items.decrease_capacity(task_ex, len(indices))
if indices:
current_inputs = operator.itemgetter(*indices)(action_inputs)
return zip(
indices,
current_inputs if isinstance(current_inputs, tuple)
else [current_inputs]
)
return []
def _get_action_input(wf_spec, task_ex, task_spec, ctx): def _get_action_input(wf_spec, task_ex, task_spec, ctx):
@ -507,3 +515,19 @@ def _set_task_state(task_ex, state):
) )
task_ex.state = state task_ex.state = state
def is_task_completed(task_ex, task_spec):
if task_spec.get_with_items():
return with_items.is_completed(task_ex)
return states.is_completed(task_ex.state)
def need_to_continue(task_ex, task_spec):
# For now continue is available only for with-items.
if task_spec.get_with_items():
return (with_items.has_more_iterations(task_ex)
and with_items.get_concurrency_spec(task_spec))
return False

View File

@ -67,6 +67,28 @@ workflows:
action: std.echo output="Task 2" action: std.echo output="Task 2"
""" """
WITH_ITEMS_WORKBOOK_CONCURRENCY = """
---
version: '2.0'
name: wb3
workflows:
wf1:
type: direct
tasks:
t1:
with-items: i in <% list(range(0, 4)) %>
action: std.echo output="Task 1.<% $.i %>"
concurrency: 2
publish:
v1: <% $.t1 %>
on-success:
- t2
t2:
action: std.echo output="Task 2"
"""
JOIN_WORKBOOK = """ JOIN_WORKBOOK = """
--- ---
version: '2.0' version: '2.0'
@ -271,6 +293,77 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
self.assertEqual(1, len(task_2_action_exs)) self.assertEqual(1, len(task_2_action_exs))
@mock.patch.object(
std_actions.EchoAction,
'run',
mock.MagicMock(
side_effect=[
exc.ActionException(), # Mock task1 exception for initial run.
'Task 1.1', # Mock task1 success for initial run.
exc.ActionException(), # Mock task1 exception for initial run.
'Task 1.3', # Mock task1 success for initial run.
'Task 1.0', # Mock task1 success for rerun.
'Task 1.2', # Mock task1 success for rerun.
'Task 2' # Mock task2 success.
]
)
)
def test_rerun_with_items_concurrency(self):
wb_service.create_workbook_v2(WITH_ITEMS_WORKBOOK_CONCURRENCY)
# Run workflow and fail task.
wf_ex = self.engine.start_workflow('wb3.wf1', {})
self._await(lambda: self.is_execution_error(wf_ex.id))
wf_ex = db_api.get_workflow_execution(wf_ex.id)
self.assertEqual(states.ERROR, wf_ex.state)
self.assertEqual(1, len(wf_ex.task_executions))
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
self.assertEqual(states.ERROR, task_1_ex.state)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(4, len(task_1_action_exs))
# Resume workflow and re-run failed task.
self.engine.rerun_workflow(wf_ex.id, task_1_ex.id, reset=False)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
self.assertEqual(states.RUNNING, wf_ex.state)
self._await(lambda: self.is_execution_success(wf_ex.id), delay=10)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
self.assertEqual(states.SUCCESS, wf_ex.state)
self.assertEqual(2, len(wf_ex.task_executions))
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')
# Check action executions of task 1.
self.assertEqual(states.SUCCESS, task_1_ex.state)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id)
# The action executions that succeeded should not re-run.
self.assertEqual(6, len(task_1_action_exs))
self.assertListEqual(['Task 1.0', 'Task 1.1', 'Task 1.2', 'Task 1.3'],
task_1_ex.published.get('v1'))
# Check action executions of task 2.
self.assertEqual(states.SUCCESS, task_2_ex.state)
task_2_action_exs = db_api.get_action_executions(
task_execution_id=task_2_ex.id)
self.assertEqual(1, len(task_2_action_exs))
@mock.patch.object( @mock.patch.object(
std_actions.EchoAction, std_actions.EchoAction,
'run', 'run',

View File

@ -20,6 +20,7 @@ from mistral.actions import base as action_base
from mistral.db.v2 import api as db_api from mistral.db.v2 import api as db_api
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral.services import workbooks as wb_service from mistral.services import workbooks as wb_service
from mistral.services import workflows as wf_service
from mistral.tests import base as test_base from mistral.tests import base as test_base
from mistral.tests.unit.engine import base from mistral.tests.unit.engine import base
from mistral import utils from mistral import utils
@ -161,6 +162,21 @@ class RandomSleepEchoAction(action_base.Action):
class WithItemsEngineTest(base.EngineTestCase): class WithItemsEngineTest(base.EngineTestCase):
def assert_capacity(self, capacity, task_ex):
self.assertEqual(
capacity,
task_ex.runtime_context['with_items_context']['capacity']
)
@staticmethod
def get_incomplete_action_ex(task_ex):
return [ex for ex in task_ex.executions if not ex.accepted][0]
@staticmethod
def get_running_action_exs_number(task_ex):
return len([ex for ex in task_ex.executions
if ex.state == states.RUNNING])
def test_with_items_simple(self): def test_with_items_simple(self):
wb_service.create_workbook_v2(WORKBOOK) wb_service.create_workbook_v2(WORKBOOK)
@ -176,7 +192,7 @@ class WithItemsEngineTest(base.EngineTestCase):
tasks = wf_ex.task_executions tasks = wf_ex.task_executions
task1 = self._assert_single_item(tasks, name='task1') task1 = self._assert_single_item(tasks, name='task1')
with_items_context = task1.runtime_context['with_items'] with_items_context = task1.runtime_context['with_items_context']
self.assertEqual(3, with_items_context['count']) self.assertEqual(3, with_items_context['count'])
@ -197,6 +213,37 @@ class WithItemsEngineTest(base.EngineTestCase):
self.assertEqual(1, len(tasks)) self.assertEqual(1, len(tasks))
self.assertEqual(states.SUCCESS, task1.state) self.assertEqual(states.SUCCESS, task1.state)
def test_with_items_fail(self):
workflow = """---
version: "2.0"
with_items:
type: direct
tasks:
task1:
with-items: i in [1, 2, 3]
action: std.fail
on-error: task2
task2:
action: std.echo output="With-items failed"
"""
wf_service.create_workflows(workflow)
# Start workflow.
wf_ex = self.engine.start_workflow('with_items', {})
self._await(
lambda: self.is_execution_success(wf_ex.id),
)
# Note: We need to reread execution to access related tasks.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
tasks = wf_ex.task_executions
self.assertEqual(2, len(tasks))
def test_with_items_static_var(self): def test_with_items_static_var(self):
wb_service.create_workbook_v2(WORKBOOK_WITH_STATIC_VAR) wb_service.create_workbook_v2(WORKBOOK_WITH_STATIC_VAR)
@ -468,3 +515,268 @@ class WithItemsEngineTest(base.EngineTestCase):
self.assertEqual(1, len(tasks)) self.assertEqual(1, len(tasks))
self.assertEqual(states.SUCCESS, task1.state) self.assertEqual(states.SUCCESS, task1.state)
def test_with_items_concurrency_1(self):
workflow_with_concurrency_1 = """---
version: "2.0"
concurrency_test:
type: direct
input:
- names: ["John", "Ivan", "Mistral"]
tasks:
task1:
action: std.async_noop
with-items: name in <% $.names %>
concurrency: 1
"""
wf_service.create_workflows(workflow_with_concurrency_1)
# Start workflow.
wf_ex = self.engine.start_workflow('concurrency_test', {})
wf_ex = db_api.get_execution(wf_ex.id)
task_ex = wf_ex.task_executions[0]
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(0, task_ex)
self.assertEqual(1, self.get_running_action_exs_number(task_ex))
# 1st iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("John")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(0, task_ex)
self.assertEqual(1, self.get_running_action_exs_number(task_ex))
# 2nd iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("Ivan")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(0, task_ex)
self.assertEqual(1, self.get_running_action_exs_number(task_ex))
# 3rd iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("Mistral")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(1, task_ex)
self._await(
lambda: self.is_execution_success(wf_ex.id),
)
task_ex = db_api.get_task_execution(task_ex.id)
# Since we know that we can receive results in random order,
# check is not depend on order of items.
result = data_flow.get_task_execution_result(task_ex)
self.assertTrue(isinstance(result, list))
self.assertIn('John', result)
self.assertIn('Ivan', result)
self.assertIn('Mistral', result)
self.assertEqual(states.SUCCESS, task_ex.state)
def test_with_items_concurrency_2(self):
workflow_with_concurrency_2 = """---
version: "2.0"
concurrency_test:
type: direct
input:
- names: ["John", "Ivan", "Mistral", "Hello"]
tasks:
task1:
action: std.async_noop
with-items: name in <% $.names %>
concurrency: 2
"""
wf_service.create_workflows(workflow_with_concurrency_2)
# Start workflow.
wf_ex = self.engine.start_workflow('concurrency_test', {})
wf_ex = db_api.get_execution(wf_ex.id)
task_ex = wf_ex.task_executions[0]
self.assert_capacity(0, task_ex)
self.assertEqual(2, self.get_running_action_exs_number(task_ex))
# 1st iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("John")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(0, task_ex)
self.assertEqual(2, self.get_running_action_exs_number(task_ex))
# 2nd iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("Ivan")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(0, task_ex)
self.assertEqual(2, self.get_running_action_exs_number(task_ex))
# 3rd iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("Mistral")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(1, task_ex)
# 4th iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("Hello")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(2, task_ex)
self._await(
lambda: self.is_execution_success(wf_ex.id),
)
task_ex = db_api.get_task_execution(task_ex.id)
# Since we know that we can receive results in random order,
# check is not depend on order of items.
result = data_flow.get_task_execution_result(task_ex)
self.assertTrue(isinstance(result, list))
self.assertIn('John', result)
self.assertIn('Ivan', result)
self.assertIn('Mistral', result)
self.assertIn('Hello', result)
self.assertEqual(states.SUCCESS, task_ex.state)
def test_with_items_concurrency_2_fail(self):
workflow_with_concurrency_2_fail = """---
version: "2.0"
concurrency_test_fail:
type: direct
tasks:
task1:
with-items: i in [1, 2, 3, 4]
action: std.fail
concurrency: 2
on-error: task2
task2:
action: std.echo output="With-items failed"
"""
wf_service.create_workflows(workflow_with_concurrency_2_fail)
# Start workflow.
wf_ex = self.engine.start_workflow('concurrency_test_fail', {})
self._await(
lambda: self.is_execution_success(wf_ex.id),
)
wf_ex = db_api.get_execution(wf_ex.id)
task_exs = wf_ex.task_executions
self.assertEqual(2, len(task_exs))
task_2 = self._assert_single_item(task_exs, name='task2')
self.assertEqual(
"With-items failed",
data_flow.get_task_execution_result(task_2)
)
def test_with_items_concurrency_3(self):
workflow_with_concurrency_3 = """---
version: "2.0"
concurrency_test:
type: direct
input:
- names: ["John", "Ivan", "Mistral"]
tasks:
task1:
action: std.async_noop
with-items: name in <% $.names %>
concurrency: 3
"""
wf_service.create_workflows(workflow_with_concurrency_3)
# Start workflow.
wf_ex = self.engine.start_workflow('concurrency_test', {})
wf_ex = db_api.get_execution(wf_ex.id)
task_ex = wf_ex.task_executions[0]
self.assert_capacity(0, task_ex)
self.assertEqual(3, self.get_running_action_exs_number(task_ex))
# 1st iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("John")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(1, task_ex)
# 2nd iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("Ivan")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(2, task_ex)
# 3rd iteration complete.
self.engine.on_action_complete(
self.get_incomplete_action_ex(task_ex).id,
wf_utils.Result("Mistral")
)
task_ex = db_api.get_task_execution(task_ex.id)
self.assert_capacity(3, task_ex)
self._await(
lambda: self.is_execution_success(wf_ex.id),
)
task_ex = db_api.get_task_execution(task_ex.id)
# Since we know that we can receive results in random order,
# check is not depend on order of items.
result = data_flow.get_task_execution_result(task_ex)
self.assertTrue(isinstance(result, list))
self.assertIn('John', result)
self.assertIn('Ivan', result)
self.assertIn('Mistral', result)
self.assertEqual(states.SUCCESS, task_ex.state)

View File

@ -0,0 +1,55 @@
# Copyright 2015 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.db.v2.sqlalchemy import models
from mistral.tests import base
from mistral.workflow import states
from mistral.workflow import with_items
class WithItemsTest(base.BaseTest):
@staticmethod
def get_action_ex(accepted, state, index):
return models.ActionExecution(
accepted=accepted,
state=state,
runtime_context={
'with_items_index': index
}
)
def test_get_indices(self):
# Task execution for running 6 items with concurrency=3.
task_ex = models.TaskExecution(
runtime_context={
'with_items_context': {
'capacity': 3,
'count': 6
}
},
executions=[]
)
# Set 3 items: 2 success and 1 error unaccepted.
task_ex.executions += [
self.get_action_ex(True, states.SUCCESS, 0),
self.get_action_ex(True, states.SUCCESS, 1),
self.get_action_ex(False, states.ERROR, 2)
]
# Then call get_indices and expect [2, 3, 4].
indices = with_items.get_indices_for_loop(task_ex)
self.assertListEqual([2, 3, 4], indices)

View File

@ -12,20 +12,42 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import six
from mistral.db.v2 import api as db_api
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral.workflow import states from mistral.workflow import states
_CAPACITY = 'capacity'
_COUNT = 'count'
_WITH_ITEMS = 'with_items_context'
def _get_context(task_ex): def _get_context(task_ex):
return task_ex.runtime_context['with_items'] return task_ex.runtime_context[_WITH_ITEMS]
def get_count(task_ex): def get_count(task_ex):
return _get_context(task_ex)['count'] return _get_context(task_ex)[_COUNT]
def is_completed(task_ex):
action_exs = db_api.get_action_executions(
task_execution_id=task_ex.id,
accepted=True
)
count = get_count(task_ex) or 1
return count == len(action_exs)
def get_index(task_ex): def get_index(task_ex):
return _get_context(task_ex)['index'] return len(
filter(
lambda x: x.accepted or states.RUNNING, task_ex.executions
)
)
def get_concurrency_spec(task_spec): def get_concurrency_spec(task_spec):
@ -34,38 +56,88 @@ def get_concurrency_spec(task_spec):
return policies.get_concurrency() if policies else None return policies.get_concurrency() if policies else None
def get_indexes_for_loop(task_ex, task_spec): def get_final_state(task_ex):
concurrency_spec = get_concurrency_spec(task_spec) find_error = lambda x: x.accepted and x.state == states.ERROR
concurrency = task_ex.runtime_context['concurrency']
if filter(find_error, task_ex.executions):
return states.ERROR
else:
return states.SUCCESS
def _get_indices_if_rerun(unaccepted_executions):
"""Returns a list of indices in case of re-running with-items.
:param unaccepted_executions: List of executions.
:return: a list of numbers.
"""
return [
ex.runtime_context['with_items_index']
for ex in unaccepted_executions
]
def _get_unaccepted_act_exs(task_ex):
# Choose only if not accepted but completed.
return filter(
lambda x: not x.accepted and states.is_completed(x.state),
task_ex.executions
)
def get_indices_for_loop(task_ex):
capacity = _get_context(task_ex)[_CAPACITY]
count = get_count(task_ex)
unaccepted = _get_unaccepted_act_exs(task_ex)
if unaccepted:
indices = _get_indices_if_rerun(unaccepted)
if max(indices) < count - 1:
indices += list(six.moves.range(max(indices) + 1, count))
return indices[:capacity] if capacity else indices
index = get_index(task_ex) index = get_index(task_ex)
number_to_execute = (get_count(task_ex) - index number_to_execute = capacity if capacity else count - index
if not concurrency_spec else concurrency)
return index, index + number_to_execute return list(six.moves.range(index, index + number_to_execute))
def do_step(task_ex): def decrease_capacity(task_ex, count):
with_items_context = _get_context(task_ex) with_items_context = _get_context(task_ex)
if with_items_context['capacity'] > 0: if with_items_context[_CAPACITY] >= count:
with_items_context['capacity'] -= 1 with_items_context[_CAPACITY] -= count
elif with_items_context[_CAPACITY]:
raise exc.WorkflowException(
"Impossible to apply current with-items concurrency."
)
with_items_context['index'] += 1 task_ex.runtime_context.update({_WITH_ITEMS: with_items_context})
task_ex.runtime_context.update({'with_items': with_items_context})
def increase_capacity(task_ex):
with_items_context = _get_context(task_ex)
max_concurrency = task_ex.runtime_context.get('concurrency')
if max_concurrency and with_items_context[_CAPACITY] < max_concurrency:
with_items_context[_CAPACITY] += 1
task_ex.runtime_context.update({_WITH_ITEMS: with_items_context})
def prepare_runtime_context(task_ex, task_spec, input_dicts): def prepare_runtime_context(task_ex, task_spec, input_dicts):
runtime_context = task_ex.runtime_context runtime_context = task_ex.runtime_context
with_items_spec = task_spec.get_with_items() with_items_spec = task_spec.get_with_items()
if with_items_spec: if with_items_spec and not runtime_context.get(_WITH_ITEMS):
# Prepare current indexes and parallel limitation. # Prepare current indexes and parallel limitation.
runtime_context['with_items'] = { runtime_context[_WITH_ITEMS] = {
'capacity': get_concurrency_spec(task_spec), _CAPACITY: get_concurrency_spec(task_spec) or None,
'index': 0, _COUNT: len(input_dicts)
'count': len(input_dicts)
} }
@ -88,7 +160,12 @@ def validate_input(with_items_input):
) )
def iterations_completed(task_ex): def has_more_iterations(task_ex):
completed = all([states.is_completed(ex.state) # See action executions which have been already
for ex in task_ex.executions]) # accepted or are still running.
return completed action_exs = filter(
lambda x: x.accepted or x.state == states.RUNNING,
task_ex.executions
)
return get_count(task_ex) > len(action_exs)