diff --git a/taskflow/conductors/backends/impl_executor.py b/taskflow/conductors/backends/impl_executor.py index f8d624e9..d5b7b204 100644 --- a/taskflow/conductors/backends/impl_executor.py +++ b/taskflow/conductors/backends/impl_executor.py @@ -31,6 +31,7 @@ from taskflow.conductors import base from taskflow import exceptions as excp from taskflow.listeners import logging as logging_listener from taskflow import logging +from taskflow import states from taskflow.types import timing as tt from taskflow.utils import iter_utils from taskflow.utils import misc @@ -185,11 +186,22 @@ class ExecutorConductor(base.Conductor): 'engine': engine, 'conductor': self, } + + def _run_engine(): + has_suspended = False + for _state in engine.run_iter(): + if not has_suspended and self._wait_timeout.is_stopped(): + self._log.info("Conductor stopped, requesting " + "suspension of engine running " + "job %s", job) + engine.suspend() + has_suspended = True + try: for stage_func, event_name in [(engine.compile, 'compilation'), (engine.prepare, 'preparation'), (engine.validate, 'validation'), - (engine.run, 'running')]: + (_run_engine, 'running')]: self._notifier.notify("%s_start" % event_name, details) stage_func() self._notifier.notify("%s_end" % event_name, details) @@ -218,7 +230,11 @@ class ExecutorConductor(base.Conductor): "Job execution failed (consumption proceeding): %s", job, exc_info=True) else: - self._log.info("Job completed successfully: %s", job) + if engine.storage.get_flow_state() == states.SUSPENDED: + self._log.info("Job execution was suspended: %s", job) + consume = False + else: + self._log.info("Job completed successfully: %s", job) return consume def _try_finish_job(self, job, consume): diff --git a/taskflow/tests/unit/test_conductors.py b/taskflow/tests/unit/test_conductors.py index 6177f262..569e5933 100644 --- a/taskflow/tests/unit/test_conductors.py +++ b/taskflow/tests/unit/test_conductors.py @@ -53,6 +53,13 @@ def test_factory(blowup): return f +def sleep_factory(): + f = lf.Flow("test") + f.add(test_utils.SleepTask('test1')) + f.add(test_utils.ProgressingTask('test2')) + return f + + def test_store_factory(): f = lf.Flow("test") f.add(test_utils.TaskMultiArg('task1')) @@ -366,6 +373,52 @@ class ManyConductorTest(testscenarios.TestWithScenarios, self.assertIsNotNone(fd) self.assertEqual(st.SUCCESS, fd.state) + def test_stop_aborts_engine(self): + components = self.make_components() + components.conductor.connect() + consumed_event = threading.Event() + job_consumed_event = threading.Event() + job_abandoned_event = threading.Event() + running_start_event = threading.Event() + + def on_running_start(event, details): + running_start_event.set() + + def on_consume(state, details): + consumed_event.set() + + def on_job_consumed(event, details): + if event == 'job_consumed': + job_consumed_event.set() + + def on_job_abandoned(event, details): + if event == 'job_abandoned': + job_abandoned_event.set() + + components.board.notifier.register(base.REMOVAL, on_consume) + components.conductor.notifier.register("job_consumed", + on_job_consumed) + components.conductor.notifier.register("job_abandoned", + on_job_abandoned) + components.conductor.notifier.register("running_start", + on_running_start) + with close_many(components.conductor, components.client): + t = threading_utils.daemon_thread(components.conductor.run) + t.start() + lb, fd = pu.temporary_flow_detail(components.persistence) + engines.save_factory_details(fd, sleep_factory, + [], {}, + backend=components.persistence) + components.board.post('poke', lb, + details={'flow_uuid': fd.uuid, + 'store': {'duration': 2}}) + running_start_event.wait(test_utils.WAIT_TIMEOUT) + components.conductor.stop() + job_abandoned_event.wait(test_utils.WAIT_TIMEOUT) + self.assertTrue(job_abandoned_event.is_set()) + self.assertFalse(job_consumed_event.is_set()) + self.assertFalse(consumed_event.is_set()) + class NonBlockingExecutorTest(test.TestCase): def test_bad_wait_timeout(self): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index da26fa3b..8716fa20 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -33,7 +33,7 @@ class TestWorker(test.MockTestCase): self.broker_url = 'test-url' self.exchange = 'test-exchange' self.topic = 'test-topic' - self.endpoint_count = 27 + self.endpoint_count = 28 # patch classes self.executor_mock, self.executor_inst_mock = self.patchClass( diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index 23eeeb6e..0a71c143 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -17,6 +17,7 @@ import contextlib import string import threading +import time import redis import six @@ -340,6 +341,11 @@ class TaskRevertExtraArgs(task.Task): pass +class SleepTask(task.Task): + def execute(self, duration, **kwargs): + time.sleep(duration) + + class EngineTestBase(object): def setUp(self): super(EngineTestBase, self).setUp()