diff --git a/taskflow/engines/worker_based/worker.py b/taskflow/engines/worker_based/worker.py index ee3ea159..de55a8e2 100644 --- a/taskflow/engines/worker_based/worker.py +++ b/taskflow/engines/worker_based/worker.py @@ -143,11 +143,15 @@ class Worker(object): return BANNER_TEMPLATE.substitute(BANNER_TEMPLATE.defaults, **tpl_params) - def run(self, display_banner=True): + def run(self, display_banner=True, banner_writer=None): """Runs the worker.""" if display_banner: - for line in self._generate_banner().splitlines(): - LOG.info(line) + banner = self._generate_banner() + if banner_writer is None: + for line in banner.splitlines(): + LOG.info(line) + else: + banner_writer(banner) self._server.start() def wait(self): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index d37e817f..ff049a64 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -14,6 +14,8 @@ # License for the specific language governing permissions and limitations # under the License. +import six + from taskflow.engines.worker_based import endpoint from taskflow.engines.worker_based import worker from taskflow import test @@ -66,6 +68,14 @@ class TestWorker(test.MockTestCase): ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) + def test_banner_writing(self): + buf = six.StringIO() + w = self.worker() + w.run(banner_writer=buf.write) + w.wait() + w.stop() + self.assertGreater(0, len(buf.getvalue())) + def test_creation_with_custom_threads_count(self): self.worker(threads_count=10)