diff --git a/mistral/engine/default_engine.py b/mistral/engine/default_engine.py index 181a81c0..451ac1cb 100644 --- a/mistral/engine/default_engine.py +++ b/mistral/engine/default_engine.py @@ -142,9 +142,7 @@ class DefaultEngine(base.Engine, coordination.Service): task_ex = db_api.get_task_execution(task_ex_id) # TODO(rakhmerov): The method is mostly needed for policy and # we are supposed to get the same action execution as when the - # policy worked. But by the moment this method is called the - # last execution object may have changed. It's a race condition. - execution = task_ex.executions[-1] + # policy worked. wf_ex_id = task_ex.workflow_execution_id @@ -161,13 +159,13 @@ class DefaultEngine(base.Engine, coordination.Service): task_ex.state = state - self._on_task_state_change(task_ex, wf_ex, action_ex=execution) + self._on_task_state_change(task_ex, wf_ex) - def _on_task_state_change(self, task_ex, wf_ex, action_ex=None): + def _on_task_state_change(self, task_ex, wf_ex): task_spec = spec_parser.get_task_spec(task_ex.spec) wf_spec = spec_parser.get_workflow_spec(wf_ex.spec) - if states.is_completed(task_ex.state): + if task_handler.is_task_completed(task_ex, task_spec): task_handler.after_task_complete(task_ex, task_spec, wf_spec) # Ignore DELAYED state. @@ -184,6 +182,11 @@ class DefaultEngine(base.Engine, coordination.Service): self._dispatch_workflow_commands(wf_ex, cmds) self._check_workflow_completion(wf_ex, wf_ctrl) + elif task_handler.need_to_continue(task_ex, task_spec): + # Re-run existing task. + cmds = [commands.RunExistingTask(task_ex, reset=False)] + + self._dispatch_workflow_commands(wf_ex, cmds) @staticmethod def _check_workflow_completion(wf_ex, wf_ctrl): @@ -233,7 +236,7 @@ class DefaultEngine(base.Engine, coordination.Service): if states.is_paused_or_completed(wf_ex.state): return action_ex - self._on_task_state_change(task_ex, wf_ex, action_ex) + self._on_task_state_change(task_ex, wf_ex) return action_ex.get_clone() except Exception as e: diff --git a/mistral/engine/task_handler.py b/mistral/engine/task_handler.py index 26a750f4..33bd4a84 100644 --- a/mistral/engine/task_handler.py +++ b/mistral/engine/task_handler.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import operator from oslo_log import log as logging @@ -62,18 +63,20 @@ def run_existing_task(task_ex_id, reset=True): not task_spec.get_with_items()): return - # Reset state of processed task and related action executions. - if reset: - action_exs = task_ex.executions - else: - action_exs = db_api.get_action_executions( - task_execution_id=task_ex.id, - state=states.ERROR, - accepted=True - ) + # Reset nested executions only if task is not already RUNNING. + if task_ex.state != states.RUNNING: + # Reset state of processed task and related action executions. + if reset: + action_exs = task_ex.executions + else: + action_exs = db_api.get_action_executions( + task_execution_id=task_ex.id, + state=states.ERROR, + accepted=True + ) - for action_ex in action_exs: - action_ex.accepted = False + for action_ex in action_exs: + action_ex.accepted = False # Explicitly change task state to RUNNING. task_ex.state = states.RUNNING @@ -90,27 +93,10 @@ def _run_existing_task(task_ex, task_spec, wf_spec): task_ex.in_context ) - # TODO(rakhmerov): May be it shouldn't be here. Need to think. - if task_spec.get_with_items(): - with_items.prepare_runtime_context(task_ex, task_spec, input_dicts) - - action_exs = db_api.get_action_executions( - task_execution_id=task_ex.id, - state=states.SUCCESS, - accepted=True - ) - - with_items_indices = [ - action_ex.runtime_context['with_items_index'] - for action_ex in action_exs - if 'with_items_index' in action_ex.runtime_context - ] - # In some cases we can have no input, e.g. in case of 'with-items'. if input_dicts: - for index, input_d in enumerate(input_dicts): - if index not in with_items_indices: - _run_action_or_workflow(task_ex, task_spec, input_d, index) + for index, input_d in input_dicts: + _run_action_or_workflow(task_ex, task_spec, input_d, index) else: _schedule_noop_action(task_ex, task_spec) @@ -200,9 +186,13 @@ def on_action_complete(action_ex, result): if not task_spec.get_with_items(): _complete_task(task_ex, task_spec, task_state) else: - if (task_state == states.ERROR or - with_items.iterations_completed(task_ex)): - _complete_task(task_ex, task_spec, task_state) + with_items.increase_capacity(task_ex) + if with_items.is_completed(task_ex): + _complete_task( + task_ex, + task_spec, + with_items.get_final_state(task_ex) + ) return task_ex @@ -245,6 +235,9 @@ def _get_input_dictionaries(wf_spec, task_ex, task_spec, ctx): should run with. In case of 'with-items' the result list will contain input dictionaries for all 'with-items' iterations correspondingly. + + :return the list of tuples containing indexes + and the corresponding input dict. """ # TODO(rakhmerov): Think how to get rid of this. ctx = data_flow.extract_task_result_proxies_to_context(ctx) @@ -257,7 +250,7 @@ def _get_input_dictionaries(wf_spec, task_ex, task_spec, ctx): ctx ) - return [input_dict] + return enumerate([input_dict]) else: return _get_with_items_input(wf_spec, task_ex, task_spec, ctx) @@ -297,7 +290,8 @@ def _get_with_items_input(wf_spec, task_ex, task_spec, ctx): {'itemX': 2, 'itemY': 'b'} ] - :return: list containing dicts of each action input. + :return: the list of tuples containing indexes + and the corresponding input dict. """ with_items_inputs = expr.evaluate_recursively( task_spec.get_with_items(), ctx @@ -325,7 +319,21 @@ def _get_with_items_input(wf_spec, task_ex, task_spec, ctx): wf_spec, task_ex, task_spec, new_ctx )) - return action_inputs + with_items.prepare_runtime_context(task_ex, task_spec, action_inputs) + + indices = with_items.get_indices_for_loop(task_ex) + with_items.decrease_capacity(task_ex, len(indices)) + + if indices: + current_inputs = operator.itemgetter(*indices)(action_inputs) + + return zip( + indices, + current_inputs if isinstance(current_inputs, tuple) + else [current_inputs] + ) + + return [] def _get_action_input(wf_spec, task_ex, task_spec, ctx): @@ -507,3 +515,19 @@ def _set_task_state(task_ex, state): ) task_ex.state = state + + +def is_task_completed(task_ex, task_spec): + if task_spec.get_with_items(): + return with_items.is_completed(task_ex) + + return states.is_completed(task_ex.state) + + +def need_to_continue(task_ex, task_spec): + # For now continue is available only for with-items. + if task_spec.get_with_items(): + return (with_items.has_more_iterations(task_ex) + and with_items.get_concurrency_spec(task_spec)) + + return False diff --git a/mistral/tests/unit/engine/test_direct_workflow_rerun.py b/mistral/tests/unit/engine/test_direct_workflow_rerun.py index 3c297d7d..62e5d6c4 100644 --- a/mistral/tests/unit/engine/test_direct_workflow_rerun.py +++ b/mistral/tests/unit/engine/test_direct_workflow_rerun.py @@ -67,6 +67,28 @@ workflows: action: std.echo output="Task 2" """ + +WITH_ITEMS_WORKBOOK_CONCURRENCY = """ +--- +version: '2.0' +name: wb3 +workflows: + wf1: + type: direct + tasks: + t1: + with-items: i in <% list(range(0, 4)) %> + action: std.echo output="Task 1.<% $.i %>" + concurrency: 2 + publish: + v1: <% $.t1 %> + on-success: + - t2 + t2: + action: std.echo output="Task 2" +""" + + JOIN_WORKBOOK = """ --- version: '2.0' @@ -271,6 +293,77 @@ class DirectWorkflowRerunTest(base.EngineTestCase): self.assertEqual(1, len(task_2_action_exs)) + @mock.patch.object( + std_actions.EchoAction, + 'run', + mock.MagicMock( + side_effect=[ + exc.ActionException(), # Mock task1 exception for initial run. + 'Task 1.1', # Mock task1 success for initial run. + exc.ActionException(), # Mock task1 exception for initial run. + 'Task 1.3', # Mock task1 success for initial run. + 'Task 1.0', # Mock task1 success for rerun. + 'Task 1.2', # Mock task1 success for rerun. + 'Task 2' # Mock task2 success. + ] + ) + ) + def test_rerun_with_items_concurrency(self): + wb_service.create_workbook_v2(WITH_ITEMS_WORKBOOK_CONCURRENCY) + + # Run workflow and fail task. + wf_ex = self.engine.start_workflow('wb3.wf1', {}) + self._await(lambda: self.is_execution_error(wf_ex.id)) + + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + self.assertEqual(states.ERROR, wf_ex.state) + self.assertEqual(1, len(wf_ex.task_executions)) + + task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1') + + self.assertEqual(states.ERROR, task_1_ex.state) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(4, len(task_1_action_exs)) + + # Resume workflow and re-run failed task. + self.engine.rerun_workflow(wf_ex.id, task_1_ex.id, reset=False) + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + self.assertEqual(states.RUNNING, wf_ex.state) + + self._await(lambda: self.is_execution_success(wf_ex.id), delay=10) + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + self.assertEqual(states.SUCCESS, wf_ex.state) + self.assertEqual(2, len(wf_ex.task_executions)) + + task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1') + task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2') + + # Check action executions of task 1. + self.assertEqual(states.SUCCESS, task_1_ex.state) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id) + + # The action executions that succeeded should not re-run. + self.assertEqual(6, len(task_1_action_exs)) + self.assertListEqual(['Task 1.0', 'Task 1.1', 'Task 1.2', 'Task 1.3'], + task_1_ex.published.get('v1')) + + # Check action executions of task 2. + self.assertEqual(states.SUCCESS, task_2_ex.state) + + task_2_action_exs = db_api.get_action_executions( + task_execution_id=task_2_ex.id) + + self.assertEqual(1, len(task_2_action_exs)) + @mock.patch.object( std_actions.EchoAction, 'run', diff --git a/mistral/tests/unit/engine/test_with_items.py b/mistral/tests/unit/engine/test_with_items.py index cb322d5c..0c7d79ca 100644 --- a/mistral/tests/unit/engine/test_with_items.py +++ b/mistral/tests/unit/engine/test_with_items.py @@ -20,6 +20,7 @@ from mistral.actions import base as action_base from mistral.db.v2 import api as db_api from mistral import exceptions as exc from mistral.services import workbooks as wb_service +from mistral.services import workflows as wf_service from mistral.tests import base as test_base from mistral.tests.unit.engine import base from mistral import utils @@ -161,6 +162,21 @@ class RandomSleepEchoAction(action_base.Action): class WithItemsEngineTest(base.EngineTestCase): + def assert_capacity(self, capacity, task_ex): + self.assertEqual( + capacity, + task_ex.runtime_context['with_items_context']['capacity'] + ) + + @staticmethod + def get_incomplete_action_ex(task_ex): + return [ex for ex in task_ex.executions if not ex.accepted][0] + + @staticmethod + def get_running_action_exs_number(task_ex): + return len([ex for ex in task_ex.executions + if ex.state == states.RUNNING]) + def test_with_items_simple(self): wb_service.create_workbook_v2(WORKBOOK) @@ -176,7 +192,7 @@ class WithItemsEngineTest(base.EngineTestCase): tasks = wf_ex.task_executions task1 = self._assert_single_item(tasks, name='task1') - with_items_context = task1.runtime_context['with_items'] + with_items_context = task1.runtime_context['with_items_context'] self.assertEqual(3, with_items_context['count']) @@ -197,6 +213,37 @@ class WithItemsEngineTest(base.EngineTestCase): self.assertEqual(1, len(tasks)) self.assertEqual(states.SUCCESS, task1.state) + def test_with_items_fail(self): + workflow = """--- + version: "2.0" + + with_items: + type: direct + + tasks: + task1: + with-items: i in [1, 2, 3] + action: std.fail + on-error: task2 + + task2: + action: std.echo output="With-items failed" + """ + wf_service.create_workflows(workflow) + + # Start workflow. + wf_ex = self.engine.start_workflow('with_items', {}) + + self._await( + lambda: self.is_execution_success(wf_ex.id), + ) + + # Note: We need to reread execution to access related tasks. + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + tasks = wf_ex.task_executions + self.assertEqual(2, len(tasks)) + def test_with_items_static_var(self): wb_service.create_workbook_v2(WORKBOOK_WITH_STATIC_VAR) @@ -468,3 +515,268 @@ class WithItemsEngineTest(base.EngineTestCase): self.assertEqual(1, len(tasks)) self.assertEqual(states.SUCCESS, task1.state) + + def test_with_items_concurrency_1(self): + workflow_with_concurrency_1 = """--- + version: "2.0" + + concurrency_test: + type: direct + + input: + - names: ["John", "Ivan", "Mistral"] + + tasks: + task1: + action: std.async_noop + with-items: name in <% $.names %> + concurrency: 1 + + """ + wf_service.create_workflows(workflow_with_concurrency_1) + + # Start workflow. + wf_ex = self.engine.start_workflow('concurrency_test', {}) + wf_ex = db_api.get_execution(wf_ex.id) + task_ex = wf_ex.task_executions[0] + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(0, task_ex) + self.assertEqual(1, self.get_running_action_exs_number(task_ex)) + + # 1st iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("John") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(0, task_ex) + self.assertEqual(1, self.get_running_action_exs_number(task_ex)) + + # 2nd iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("Ivan") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(0, task_ex) + self.assertEqual(1, self.get_running_action_exs_number(task_ex)) + + # 3rd iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("Mistral") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(1, task_ex) + + self._await( + lambda: self.is_execution_success(wf_ex.id), + ) + + task_ex = db_api.get_task_execution(task_ex.id) + # Since we know that we can receive results in random order, + # check is not depend on order of items. + result = data_flow.get_task_execution_result(task_ex) + + self.assertTrue(isinstance(result, list)) + + self.assertIn('John', result) + self.assertIn('Ivan', result) + self.assertIn('Mistral', result) + + self.assertEqual(states.SUCCESS, task_ex.state) + + def test_with_items_concurrency_2(self): + workflow_with_concurrency_2 = """--- + version: "2.0" + + concurrency_test: + type: direct + + input: + - names: ["John", "Ivan", "Mistral", "Hello"] + + tasks: + task1: + action: std.async_noop + with-items: name in <% $.names %> + concurrency: 2 + + """ + wf_service.create_workflows(workflow_with_concurrency_2) + + # Start workflow. + wf_ex = self.engine.start_workflow('concurrency_test', {}) + wf_ex = db_api.get_execution(wf_ex.id) + task_ex = wf_ex.task_executions[0] + + self.assert_capacity(0, task_ex) + self.assertEqual(2, self.get_running_action_exs_number(task_ex)) + + # 1st iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("John") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(0, task_ex) + self.assertEqual(2, self.get_running_action_exs_number(task_ex)) + + # 2nd iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("Ivan") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(0, task_ex) + self.assertEqual(2, self.get_running_action_exs_number(task_ex)) + + # 3rd iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("Mistral") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(1, task_ex) + + # 4th iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("Hello") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(2, task_ex) + + self._await( + lambda: self.is_execution_success(wf_ex.id), + ) + + task_ex = db_api.get_task_execution(task_ex.id) + # Since we know that we can receive results in random order, + # check is not depend on order of items. + result = data_flow.get_task_execution_result(task_ex) + self.assertTrue(isinstance(result, list)) + + self.assertIn('John', result) + self.assertIn('Ivan', result) + self.assertIn('Mistral', result) + self.assertIn('Hello', result) + + self.assertEqual(states.SUCCESS, task_ex.state) + + def test_with_items_concurrency_2_fail(self): + workflow_with_concurrency_2_fail = """--- + version: "2.0" + + concurrency_test_fail: + type: direct + + tasks: + task1: + with-items: i in [1, 2, 3, 4] + action: std.fail + concurrency: 2 + on-error: task2 + + task2: + action: std.echo output="With-items failed" + + """ + wf_service.create_workflows(workflow_with_concurrency_2_fail) + + # Start workflow. + wf_ex = self.engine.start_workflow('concurrency_test_fail', {}) + + self._await( + lambda: self.is_execution_success(wf_ex.id), + ) + wf_ex = db_api.get_execution(wf_ex.id) + + task_exs = wf_ex.task_executions + + self.assertEqual(2, len(task_exs)) + + task_2 = self._assert_single_item(task_exs, name='task2') + + self.assertEqual( + "With-items failed", + data_flow.get_task_execution_result(task_2) + ) + + def test_with_items_concurrency_3(self): + workflow_with_concurrency_3 = """--- + version: "2.0" + + concurrency_test: + type: direct + + input: + - names: ["John", "Ivan", "Mistral"] + + tasks: + task1: + action: std.async_noop + with-items: name in <% $.names %> + concurrency: 3 + + """ + wf_service.create_workflows(workflow_with_concurrency_3) + + # Start workflow. + wf_ex = self.engine.start_workflow('concurrency_test', {}) + wf_ex = db_api.get_execution(wf_ex.id) + task_ex = wf_ex.task_executions[0] + + self.assert_capacity(0, task_ex) + self.assertEqual(3, self.get_running_action_exs_number(task_ex)) + + # 1st iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("John") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(1, task_ex) + + # 2nd iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("Ivan") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(2, task_ex) + + # 3rd iteration complete. + self.engine.on_action_complete( + self.get_incomplete_action_ex(task_ex).id, + wf_utils.Result("Mistral") + ) + + task_ex = db_api.get_task_execution(task_ex.id) + self.assert_capacity(3, task_ex) + + self._await( + lambda: self.is_execution_success(wf_ex.id), + ) + + task_ex = db_api.get_task_execution(task_ex.id) + # Since we know that we can receive results in random order, + # check is not depend on order of items. + result = data_flow.get_task_execution_result(task_ex) + self.assertTrue(isinstance(result, list)) + + self.assertIn('John', result) + self.assertIn('Ivan', result) + self.assertIn('Mistral', result) + + self.assertEqual(states.SUCCESS, task_ex.state) diff --git a/mistral/tests/unit/workflow/test_with_items.py b/mistral/tests/unit/workflow/test_with_items.py new file mode 100644 index 00000000..770322f2 --- /dev/null +++ b/mistral/tests/unit/workflow/test_with_items.py @@ -0,0 +1,55 @@ +# Copyright 2015 - Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from mistral.db.v2.sqlalchemy import models +from mistral.tests import base +from mistral.workflow import states +from mistral.workflow import with_items + + +class WithItemsTest(base.BaseTest): + @staticmethod + def get_action_ex(accepted, state, index): + return models.ActionExecution( + accepted=accepted, + state=state, + runtime_context={ + 'with_items_index': index + } + ) + + def test_get_indices(self): + # Task execution for running 6 items with concurrency=3. + task_ex = models.TaskExecution( + runtime_context={ + 'with_items_context': { + 'capacity': 3, + 'count': 6 + } + }, + executions=[] + ) + + # Set 3 items: 2 success and 1 error unaccepted. + task_ex.executions += [ + self.get_action_ex(True, states.SUCCESS, 0), + self.get_action_ex(True, states.SUCCESS, 1), + self.get_action_ex(False, states.ERROR, 2) + ] + + # Then call get_indices and expect [2, 3, 4]. + indices = with_items.get_indices_for_loop(task_ex) + + self.assertListEqual([2, 3, 4], indices) diff --git a/mistral/workflow/with_items.py b/mistral/workflow/with_items.py index f3cbe5c2..7797719f 100644 --- a/mistral/workflow/with_items.py +++ b/mistral/workflow/with_items.py @@ -12,20 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. +import six + +from mistral.db.v2 import api as db_api from mistral import exceptions as exc from mistral.workflow import states +_CAPACITY = 'capacity' +_COUNT = 'count' +_WITH_ITEMS = 'with_items_context' + + def _get_context(task_ex): - return task_ex.runtime_context['with_items'] + return task_ex.runtime_context[_WITH_ITEMS] def get_count(task_ex): - return _get_context(task_ex)['count'] + return _get_context(task_ex)[_COUNT] + + +def is_completed(task_ex): + action_exs = db_api.get_action_executions( + task_execution_id=task_ex.id, + accepted=True + ) + count = get_count(task_ex) or 1 + + return count == len(action_exs) def get_index(task_ex): - return _get_context(task_ex)['index'] + return len( + filter( + lambda x: x.accepted or states.RUNNING, task_ex.executions + ) + ) def get_concurrency_spec(task_spec): @@ -34,38 +56,88 @@ def get_concurrency_spec(task_spec): return policies.get_concurrency() if policies else None -def get_indexes_for_loop(task_ex, task_spec): - concurrency_spec = get_concurrency_spec(task_spec) - concurrency = task_ex.runtime_context['concurrency'] +def get_final_state(task_ex): + find_error = lambda x: x.accepted and x.state == states.ERROR + + if filter(find_error, task_ex.executions): + return states.ERROR + else: + return states.SUCCESS + + +def _get_indices_if_rerun(unaccepted_executions): + """Returns a list of indices in case of re-running with-items. + + :param unaccepted_executions: List of executions. + :return: a list of numbers. + """ + + return [ + ex.runtime_context['with_items_index'] + for ex in unaccepted_executions + ] + + +def _get_unaccepted_act_exs(task_ex): + # Choose only if not accepted but completed. + return filter( + lambda x: not x.accepted and states.is_completed(x.state), + task_ex.executions + ) + + +def get_indices_for_loop(task_ex): + capacity = _get_context(task_ex)[_CAPACITY] + count = get_count(task_ex) + + unaccepted = _get_unaccepted_act_exs(task_ex) + + if unaccepted: + indices = _get_indices_if_rerun(unaccepted) + + if max(indices) < count - 1: + indices += list(six.moves.range(max(indices) + 1, count)) + + return indices[:capacity] if capacity else indices + index = get_index(task_ex) - number_to_execute = (get_count(task_ex) - index - if not concurrency_spec else concurrency) + number_to_execute = capacity if capacity else count - index - return index, index + number_to_execute + return list(six.moves.range(index, index + number_to_execute)) -def do_step(task_ex): +def decrease_capacity(task_ex, count): with_items_context = _get_context(task_ex) - if with_items_context['capacity'] > 0: - with_items_context['capacity'] -= 1 + if with_items_context[_CAPACITY] >= count: + with_items_context[_CAPACITY] -= count + elif with_items_context[_CAPACITY]: + raise exc.WorkflowException( + "Impossible to apply current with-items concurrency." + ) - with_items_context['index'] += 1 + task_ex.runtime_context.update({_WITH_ITEMS: with_items_context}) - task_ex.runtime_context.update({'with_items': with_items_context}) + +def increase_capacity(task_ex): + with_items_context = _get_context(task_ex) + max_concurrency = task_ex.runtime_context.get('concurrency') + + if max_concurrency and with_items_context[_CAPACITY] < max_concurrency: + with_items_context[_CAPACITY] += 1 + task_ex.runtime_context.update({_WITH_ITEMS: with_items_context}) def prepare_runtime_context(task_ex, task_spec, input_dicts): runtime_context = task_ex.runtime_context with_items_spec = task_spec.get_with_items() - if with_items_spec: + if with_items_spec and not runtime_context.get(_WITH_ITEMS): # Prepare current indexes and parallel limitation. - runtime_context['with_items'] = { - 'capacity': get_concurrency_spec(task_spec), - 'index': 0, - 'count': len(input_dicts) + runtime_context[_WITH_ITEMS] = { + _CAPACITY: get_concurrency_spec(task_spec) or None, + _COUNT: len(input_dicts) } @@ -88,7 +160,12 @@ def validate_input(with_items_input): ) -def iterations_completed(task_ex): - completed = all([states.is_completed(ex.state) - for ex in task_ex.executions]) - return completed +def has_more_iterations(task_ex): + # See action executions which have been already + # accepted or are still running. + action_exs = filter( + lambda x: x.accepted or x.state == states.RUNNING, + task_ex.executions + ) + + return get_count(task_ex) > len(action_exs)