diff --git a/mistral/db/v2/api.py b/mistral/db/v2/api.py index 58f7d3ed2..ce386e159 100644 --- a/mistral/db/v2/api.py +++ b/mistral/db/v2/api.py @@ -312,6 +312,10 @@ def get_completed_task_executions(**kwargs): return IMPL.get_completed_task_executions(**kwargs) +def get_completed_task_executions_as_batches(**kwargs): + return IMPL.get_completed_task_executions_as_batches(**kwargs) + + def get_incomplete_task_executions(**kwargs): return IMPL.get_incomplete_task_executions(**kwargs) diff --git a/mistral/db/v2/sqlalchemy/api.py b/mistral/db/v2/sqlalchemy/api.py index 5d39579d3..1f6e0f776 100644 --- a/mistral/db/v2/sqlalchemy/api.py +++ b/mistral/db/v2/sqlalchemy/api.py @@ -876,6 +876,30 @@ def get_completed_task_executions(session=None, **kwargs): return query.all() +@b.session_aware() +def get_completed_task_executions_as_batches(session=None, **kwargs): + # NOTE: Using batch querying seriously allows to optimize memory + # consumption on operations when we need to iterate through + # a list of task executions and do some processing like merging + # their inbound contexts. If we don't use batches Mistral has to + # hold all the collection (that can be large) in memory. + # Using a generator that returns batches lets GC to collect a + # batch of task executions that has already been processed. + query = _get_completed_task_executions_query(kwargs) + + # Batch size 20 may be arguable but still seems reasonable: it's big + # enough to keep the total number of DB hops small (say for 100 tasks + # we'll need only 5) and small enough not to drastically increase + # memory footprint if the number of tasks is big like several hundreds. + batch_size = 20 + idx = 0 + + while idx < query.count(): + yield query.slice(idx, idx + batch_size).all() + + idx += batch_size + + def _get_incomplete_task_executions_query(kwargs): query = b.model_query(models.TaskExecution) diff --git a/mistral/tests/unit/engine/test_direct_workflow.py b/mistral/tests/unit/engine/test_direct_workflow.py index 0836fc1ef..2aed90fb7 100644 --- a/mistral/tests/unit/engine/test_direct_workflow.py +++ b/mistral/tests/unit/engine/test_direct_workflow.py @@ -945,9 +945,9 @@ class DirectWorkflowEngineTest(base.EngineTestCase): with db_api.transaction(): wf_ex = db_api.get_workflow_execution(wf_ex.id) - tasks = wf_ex.task_executions + task_execs = wf_ex.task_executions - self.assertEqual(task_cnt + 3, len(tasks)) + self.assertEqual(task_cnt + 3, len(task_execs)) - self._assert_single_item(tasks, name='task0') - self._assert_single_item(tasks, name='task{}'.format(task_cnt)) + self._assert_single_item(task_execs, name='task0') + self._assert_single_item(task_execs, name='task{}'.format(task_cnt)) diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index 98de02d24..7d9a16724 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -168,11 +168,15 @@ class DirectWorkflowController(base.WorkflowController): def evaluate_workflow_final_context(self): ctx = {} - for t_ex in self._find_end_task_executions(): - ctx = utils.merge_dicts( - ctx, - data_flow.evaluate_task_outbound_context(t_ex) - ) + for batch in self._find_end_task_executions_as_batches(): + if not batch: + break + + for t_ex in batch: + ctx = utils.merge_dicts( + ctx, + data_flow.evaluate_task_outbound_context(t_ex) + ) return ctx @@ -208,7 +212,7 @@ class DirectWorkflowController(base.WorkflowController): return True - def _find_end_task_executions(self): + def _find_end_task_executions_as_batches(self): def is_end_task(t_ex): try: return not self._has_outbound_tasks(t_ex) @@ -219,13 +223,15 @@ class DirectWorkflowController(base.WorkflowController): # of given task also. return True - return list( - filter( - is_end_task, - lookup_utils.find_completed_task_executions(self.wf_ex.id) - ) + batches = lookup_utils.find_completed_task_executions_as_batches( + self.wf_ex.id ) + for batch in batches: + yield list( + filter(is_end_task, batch) + ) + def _has_outbound_tasks(self, task_ex): # In order to determine if there are outbound tasks we just need # to calculate next task names (based on task outbound context) diff --git a/mistral/workflow/lookup_utils.py b/mistral/workflow/lookup_utils.py index 68de2c10b..ea145ee5e 100644 --- a/mistral/workflow/lookup_utils.py +++ b/mistral/workflow/lookup_utils.py @@ -151,6 +151,12 @@ def find_completed_task_executions(wf_ex_id): return db_api.get_completed_task_executions(workflow_execution_id=wf_ex_id) +def find_completed_task_executions_as_batches(wf_ex_id): + return db_api.get_completed_task_executions_as_batches( + workflow_execution_id=wf_ex_id + ) + + def get_task_execution_cache_size(): return len(_TASK_EX_CACHE)