diff --git a/taskflow/engines/worker_based/worker.py b/taskflow/engines/worker_based/worker.py index 5ac0cf4f2..4273af605 100644 --- a/taskflow/engines/worker_based/worker.py +++ b/taskflow/engines/worker_based/worker.py @@ -20,13 +20,13 @@ import socket import string import sys -from concurrent import futures from oslo_utils import reflection from taskflow.engines.worker_based import endpoint from taskflow.engines.worker_based import server from taskflow import logging from taskflow import task as t_task +from taskflow.types import futures from taskflow.utils import misc from taskflow.utils import threading_utils as tu from taskflow import version @@ -77,23 +77,26 @@ class Worker(object): will be used to create tasks from. :param executor: custom executor object that can used for processing requests in separate threads (if not provided one will be created) - :param threads_count: threads count to be passed to the default executor + :param threads_count: threads count to be passed to the + default executor (used only if an executor is not + passed in) :param transport: transport to be used (e.g. amqp, memory, etc.) :param transport_options: transport specific options :param retry_options: retry specific options (used to configure how kombu handles retrying under tolerable/transient failures). """ - def __init__(self, exchange, topic, tasks, executor=None, **kwargs): + def __init__(self, exchange, topic, tasks, + executor=None, threads_count=None, url=None, + transport=None, transport_options=None, + retry_options=None): self._topic = topic self._executor = executor self._owns_executor = False self._threads_count = -1 if self._executor is None: - if 'threads_count' in kwargs: - self._threads_count = int(kwargs.pop('threads_count')) - if self._threads_count <= 0: - raise ValueError("threads_count provided must be > 0") + if threads_count is not None: + self._threads_count = int(threads_count) else: self._threads_count = tu.get_optimal_thread_count() self._executor = futures.ThreadPoolExecutor(self._threads_count) @@ -101,7 +104,10 @@ class Worker(object): self._endpoints = self._derive_endpoints(tasks) self._exchange = exchange self._server = server.Server(topic, exchange, self._executor, - self._endpoints, **kwargs) + self._endpoints, url=url, + transport=transport, + transport_options=transport_options, + retry_options=retry_options) @staticmethod def _derive_endpoints(tasks): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index a572beb47..cc4578c05 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -64,7 +64,11 @@ class TestWorker(test.MockTestCase): master_mock_calls = [ mock.call.executor_class(self.threads_count), mock.call.Server(self.topic, self.exchange, - self.executor_inst_mock, [], url=self.broker_url) + self.executor_inst_mock, [], + url=self.broker_url, + transport_options=mock.ANY, + transport=mock.ANY, + retry_options=mock.ANY) ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) @@ -83,20 +87,23 @@ class TestWorker(test.MockTestCase): mock.call.executor_class(10), mock.call.Server(self.topic, self.exchange, self.executor_inst_mock, [], - url=self.broker_url) + url=self.broker_url, + transport_options=mock.ANY, + transport=mock.ANY, + retry_options=mock.ANY) ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) - def test_creation_with_negative_threads_count(self): - self.assertRaises(ValueError, self.worker, threads_count=-10) - def test_creation_with_custom_executor(self): executor_mock = mock.MagicMock(name='executor') self.worker(executor=executor_mock) master_mock_calls = [ mock.call.Server(self.topic, self.exchange, executor_mock, [], - url=self.broker_url) + url=self.broker_url, + transport_options=mock.ANY, + transport=mock.ANY, + retry_options=mock.ANY) ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls)