From f368c5822ee876c8359cbe1cfbfaa6ddf1c6283c Mon Sep 17 00:00:00 2001
From: Stanislav Kudriashev
Date: Tue, 11 Mar 2014 18:04:33 +0200
Subject: [PATCH] [WBE] Collect information from workers

* Added the `Notify` message type, used to notify workers that the
  executor has started and to receive worker information back (topic
  and task list);
* Implemented `WorkersCache` to interact with worker information easily;
* Got rid of the engine `workers_info` parameter.

Change-Id: I4a810b1ddb0b04c11d12b47afc5f9cdf77d070be
---
 taskflow/engines/worker_based/cache.py       | 28 ++++++
 taskflow/engines/worker_based/engine.py      |  2 +-
 taskflow/engines/worker_based/executor.py    | 98 ++++++++++++-------
 taskflow/engines/worker_based/protocol.py    | 37 +++++--
 taskflow/engines/worker_based/server.py      | 30 ++++--
 taskflow/examples/worker_based/flow.py       |  7 +-
 taskflow/examples/worker_based/worker.py     |  2 +-
 taskflow/tests/unit/test_action_engine.py    | 20 +---
 .../tests/unit/worker_based/test_engine.py   |  8 +-
 .../tests/unit/worker_based/test_executor.py | 29 ++++--
 10 files changed, 175 insertions(+), 86 deletions(-)

diff --git a/taskflow/engines/worker_based/cache.py b/taskflow/engines/worker_based/cache.py
index de19095e..f92bf23e 100644
--- a/taskflow/engines/worker_based/cache.py
+++ b/taskflow/engines/worker_based/cache.py
@@ -15,9 +15,11 @@
 # under the License.
 
 import logging
+import random
 
 import six
 
+from taskflow.engines.worker_based import protocol as pr
 from taskflow.utils import lock_utils as lu
 
 LOG = logging.getLogger(__name__)
@@ -56,3 +58,29 @@ class Cache(object):
         if on_expired_callback:
             for (_k, v) in expired_values:
                 on_expired_callback(v)
+
+
+class RequestsCache(Cache):
+    """Represents thread-safe requests cache."""
+
+    def get_waiting_requests(self, tasks):
+        """Get list of waiting requests by tasks."""
+        waiting_requests = []
+        with self._lock.read_lock():
+            for request in six.itervalues(self._data):
+                if request.state == pr.WAITING and request.task_cls in tasks:
+                    waiting_requests.append(request)
+        return waiting_requests
+
+
+class WorkersCache(Cache):
+    """Represents thread-safe workers cache."""
+
+    def get_topic_by_task(self, task):
+        """Get topic for a given task."""
+        available_topics = []
+        with self._lock.read_lock():
+            for topic, tasks in six.iteritems(self._data):
+                if task in tasks:
+                    available_topics.append(topic)
+        return random.choice(available_topics) if available_topics else None
diff --git a/taskflow/engines/worker_based/engine.py b/taskflow/engines/worker_based/engine.py
index 44c705c6..af178cb6 100644
--- a/taskflow/engines/worker_based/engine.py
+++ b/taskflow/engines/worker_based/engine.py
@@ -30,7 +30,7 @@ class WorkerBasedActionEngine(engine.ActionEngine):
             'uuid': flow_detail.uuid,
             'url': conf.get('url'),
             'exchange': conf.get('exchange', 'default'),
-            'workers_info': conf.get('workers_info', {}),
+            'topics': conf.get('topics', []),
             'transport': conf.get('transport'),
             'transport_options': conf.get('transport_options')
         }
diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py
index e1a68665..69f659a6 100644
--- a/taskflow/engines/worker_based/executor.py
+++ b/taskflow/engines/worker_based/executor.py
@@ -17,8 +17,6 @@
 import logging
 import threading
 
-import six
-
 from kombu import exceptions as kombu_exc
 
 from taskflow.engines.action_engine import executor
@@ -35,27 +33,24 @@ LOG = logging.getLogger(__name__)
 class WorkerTaskExecutor(executor.TaskExecutorBase):
     """Executes tasks on remote workers."""
 
-    def __init__(self, uuid, 
exchange, workers_info, **kwargs): + def __init__(self, uuid, exchange, topics, **kwargs): self._uuid = uuid + self._topics = topics + self._requests_cache = cache.RequestsCache() + self._workers_cache = cache.WorkersCache() self._proxy = proxy.Proxy(uuid, exchange, self._on_message, self._on_wait, **kwargs) self._proxy_thread = None - self._requests_cache = cache.Cache() + self._notify_thread = None + self._notify_event = threading.Event() - # TODO(skudriashev): This data should be collected from workers - # using broadcast messages directly. - self._workers_info = {} - for topic, tasks in six.iteritems(workers_info): - for task in tasks: - self._workers_info[task] = topic - - def _get_proxy_thread(self): - proxy_thread = threading.Thread(target=self._proxy.start) + def _make_thread(self, target): + thread = threading.Thread(target=target) # NOTE(skudriashev): When the main thread is terminated unexpectedly - # and proxy thread is still alive - it will prevent main thread from - # exiting unless the daemon property is set to True. - proxy_thread.daemon = True - return proxy_thread + # and thread is still alive - it will prevent main thread from exiting + # unless the daemon property is set to True. + thread.daemon = True + return thread def _on_message(self, data, message): """This method is called on incoming message.""" @@ -72,11 +67,26 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): except KeyError: LOG.warning("The 'type' message property is missing.") else: - if msg_type == pr.RESPONSE: + if msg_type == pr.NOTIFY: + self._process_notify(data) + elif msg_type == pr.RESPONSE: self._process_response(data, message) else: LOG.warning("Unexpected message type: %s", msg_type) + def _process_notify(self, notify): + """Process notify message from remote side.""" + LOG.debug("Start processing notify message.") + topic = notify['topic'] + tasks = notify['tasks'] + + # add worker info to the cache + self._workers_cache.set(topic, tasks) + + # publish waiting requests + for request in self._requests_cache.get_waiting_requests(tasks): + self._publish_request(request, topic) + def _process_response(self, response, message): """Process response from remote side.""" LOG.debug("Start processing response message.") @@ -110,8 +120,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): the `Timeout` exception is set as a request result. """ LOG.debug("Request '%r' has expired.", request) + LOG.debug("The '%r' request has expired.", request) request.set_result(misc.Failure.from_exception( - exc.Timeout("Request '%r' has expired" % request))) + exc.Timeout("The '%r' request has expired" % request))) def _on_wait(self): """This function is called cyclically between draining events.""" @@ -123,26 +134,38 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): request = pr.Request(task, task_uuid, action, arguments, progress_callback, timeout, **kwargs) self._requests_cache.set(request.uuid, request) + + # Get task's topic and publish request if topic was found. 
+ topic = self._workers_cache.get_topic_by_task(request.task_cls) + if topic is not None: + self._publish_request(request, topic) + + return request.result + + def _publish_request(self, request, topic): + """Publish request to a given topic.""" + LOG.debug("Sending request: %s" % request) try: - # get task's workers topic to send request to - try: - topic = self._workers_info[request.task_cls] - except KeyError: - raise exc.NotFound("Workers topic not found for the '%s'" - " task" % request.task_cls) - else: - # publish request - LOG.debug("Sending request: %s", request) - self._proxy.publish(request, - routing_key=topic, - reply_to=self._uuid, - correlation_id=request.uuid) + self._proxy.publish(msg=request, + routing_key=topic, + reply_to=self._uuid, + correlation_id=request.uuid) except Exception: with misc.capture_failure() as failure: - LOG.exception("Failed to submit the '%s' task", request) + LOG.exception("Failed to submit the '%s' request." % + request) self._requests_cache.delete(request.uuid) request.set_result(failure) - return request.result + else: + request.set_pending() + + def _notify_topics(self): + """Cyclically publish notify message to each topic.""" + LOG.debug("Notify thread started.") + while not self._notify_event.is_set(): + for topic in self._topics: + self._proxy.publish(pr.Notify(), topic, reply_to=self._uuid) + self._notify_event.wait(pr.NOTIFY_PERIOD) def execute_task(self, task, task_uuid, arguments, progress_callback=None): @@ -162,14 +185,19 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): def start(self): """Start proxy thread.""" if self._proxy_thread is None: - self._proxy_thread = self._get_proxy_thread() + self._proxy_thread = self._make_thread(self._proxy.start) self._proxy_thread.start() self._proxy.wait() + self._notify_thread = self._make_thread(self._notify_topics) + self._notify_thread.start() def stop(self): """Stop proxy, so its thread would be gracefully terminated.""" if self._proxy_thread is not None: if self._proxy_thread.is_alive(): + self._notify_event.set() + self._notify_thread.join() self._proxy.stop() self._proxy_thread.join() + self._notify_thread = None self._proxy_thread = None diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index 1e0868f1..3c353cd3 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -26,6 +26,7 @@ from taskflow.utils import persistence_utils as pu from taskflow.utils import reflection # NOTE(skudriashev): This is protocol events, not related to the task states. +WAITING = 'WAITING' PENDING = 'PENDING' RUNNING = 'RUNNING' SUCCESS = 'SUCCESS' @@ -53,7 +54,11 @@ REQUEST_TIMEOUT = 60 # no longer needed. QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT +# Workers notify period. +NOTIFY_PERIOD = 5 + # Message types. +NOTIFY = 'NOTIFY' REQUEST = 'REQUEST' RESPONSE = 'RESPONSE' @@ -70,9 +75,20 @@ class Message(object): """Return json-serializable message representation.""" +class Notify(Message): + """Represents notify message type.""" + TYPE = NOTIFY + + def __init__(self, **data): + self._data = data + + def to_dict(self): + return self._data + + class Request(Message): """Represents request with execution results. Every request is created in - the PENDING state and is expired within the given timeout. + the WAITING state and is expired within the given timeout. 
""" TYPE = REQUEST @@ -87,7 +103,7 @@ class Request(Message): self._progress_callback = progress_callback self._kwargs = kwargs self._watch = misc.StopWatch(duration=timeout).start() - self._state = PENDING + self._state = WAITING self.result = futures.Future() def __repr__(self): @@ -101,18 +117,22 @@ class Request(Message): def task_cls(self): return self._task_cls + @property + def state(self): + return self._state + @property def expired(self): """Check if request has expired. - When new request is created its state is set to the PENDING, creation + When new request is created its state is set to the WAITING, creation time is stored and timeout is given via constructor arguments. - Request is considered to be expired when it is in the PENDING state - for more then the given timeout (it is not considered to be expired - in any other state). + Request is considered to be expired when it is in the WAITING/PENDING + state for more then the given timeout (it is not considered to be + expired in any other state). """ - if self._state == PENDING: + if self._state in (WAITING, PENDING): return self._watch.expired() return False @@ -139,6 +159,9 @@ class Request(Message): def set_result(self, result): self.result.set_result((self._task, self._event, result)) + def set_pending(self): + self._state = PENDING + def set_running(self): self._state = RUNNING self._watch.stop() diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index a018f71c..ca858a49 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -32,6 +32,7 @@ class Server(object): def __init__(self, topic, exchange, executor, endpoints, **kwargs): self._proxy = proxy.Proxy(topic, exchange, self._on_message, **kwargs) + self._topic = topic self._executor = executor self._endpoints = dict([(endpoint.name, endpoint) for endpoint in endpoints]) @@ -56,12 +57,15 @@ class Server(object): except KeyError: LOG.warning("The 'type' message property is missing.") else: - if msg_type == pr.REQUEST: - # spawn new thread to process request - self._executor.submit(self._process_request, data, - message) + if msg_type == pr.NOTIFY: + handler = self._process_notify + elif msg_type == pr.REQUEST: + handler = self._process_request else: LOG.warning("Unexpected message type: %s", msg_type) + return + # spawn new thread to process request + self._executor.submit(handler, data, message) else: try: # requeue message @@ -122,10 +126,24 @@ class Server(object): self._reply(reply_to, task_uuid, pr.PROGRESS, event_data=event_data, progress=progress) + def _process_notify(self, notify, message): + """Process notify message and reply back.""" + LOG.debug("Start processing notify message.") + try: + reply_to = message.properties['reply_to'] + except Exception: + LOG.exception("The 'reply_to' message property is missing.") + else: + self._proxy.publish( + msg=pr.Notify(topic=self._topic, tasks=self._endpoints.keys()), + routing_key=reply_to + ) + def _process_request(self, request, message): - """Process request in separate thread and reply back.""" - # NOTE(skudriashev): Parse broker message first to get the `reply_to` + """Process request message and reply back.""" + # NOTE(skudriashev): parse broker message first to get the `reply_to` # and the `task_uuid` parameters to have possibility to reply back. 
+ LOG.debug("Start processing request message.") try: reply_to, task_uuid = self._parse_message(message) except ValueError: diff --git a/taskflow/examples/worker_based/flow.py b/taskflow/examples/worker_based/flow.py index 0c6778d3..50529a81 100644 --- a/taskflow/examples/worker_based/flow.py +++ b/taskflow/examples/worker_based/flow.py @@ -30,12 +30,7 @@ if __name__ == "__main__": engine_conf = { 'engine': 'worker-based', 'exchange': 'taskflow', - 'workers_info': { - 'topic': [ - 'taskflow.tests.utils.TaskOneArgOneReturn', - 'taskflow.tests.utils.TaskMultiArgOneReturn' - ] - } + 'topics': ['test-topic'], } # parse command line diff --git a/taskflow/examples/worker_based/worker.py b/taskflow/examples/worker_based/worker.py index aee3246c..405813c7 100644 --- a/taskflow/examples/worker_based/worker.py +++ b/taskflow/examples/worker_based/worker.py @@ -27,7 +27,7 @@ if __name__ == "__main__": logging.basicConfig(level=logging.ERROR) worker_conf = { 'exchange': 'taskflow', - 'topic': 'topic', + 'topic': 'test-topic', 'tasks': [ 'taskflow.tests.utils:TaskOneArgOneReturn', 'taskflow.tests.utils:TaskMultiArgOneReturn' diff --git a/taskflow/tests/unit/test_action_engine.py b/taskflow/tests/unit/test_action_engine.py index fc937098..e86ae62e 100644 --- a/taskflow/tests/unit/test_action_engine.py +++ b/taskflow/tests/unit/test_action_engine.py @@ -591,9 +591,9 @@ class WorkerBasedEngineTest(EngineTaskTest, 'exchange': self.exchange, 'topic': self.topic, 'tasks': [ - 'taskflow.tests.utils' + 'taskflow.tests.utils', ], - 'transport': self.transport + 'transport': self.transport, } self.worker = wkr.Worker(**worker_conf) self.worker_thread = threading.Thread(target=self.worker.run) @@ -611,20 +611,8 @@ class WorkerBasedEngineTest(EngineTaskTest, engine_conf = { 'engine': 'worker-based', 'exchange': self.exchange, - 'workers_info': { - self.topic: [ - 'taskflow.tests.utils.SaveOrderTask', - 'taskflow.tests.utils.FailingTask', - 'taskflow.tests.utils.TaskOneReturn', - 'taskflow.tests.utils.TaskMultiReturn', - 'taskflow.tests.utils.TaskMultiArgOneReturn', - 'taskflow.tests.utils.NastyTask', - 'taskflow.tests.utils.NastyFailingTask', - 'taskflow.tests.utils.NeverRunningTask', - 'taskflow.tests.utils.TaskNoRequiresNoReturns' - ] - }, - 'transport': self.transport + 'topics': [self.topic], + 'transport': self.transport, } return taskflow.engines.load(flow, flow_detail=flow_detail, engine_conf=engine_conf, diff --git a/taskflow/tests/unit/worker_based/test_engine.py b/taskflow/tests/unit/worker_based/test_engine.py index f8acb035..c966be5a 100644 --- a/taskflow/tests/unit/worker_based/test_engine.py +++ b/taskflow/tests/unit/worker_based/test_engine.py @@ -29,7 +29,7 @@ class TestWorkerBasedActionEngine(test.MockTestCase): super(TestWorkerBasedActionEngine, self).setUp() self.broker_url = 'test-url' self.exchange = 'test-exchange' - self.workers_info = {'test-topic': ['task1', 'task2']} + self.topics = ['test-topic1', 'test-topic2'] # patch classes self.executor_mock, self.executor_inst_mock = self._patch_class( @@ -44,7 +44,7 @@ class TestWorkerBasedActionEngine(test.MockTestCase): mock.call.executor_class(uuid=flow_detail.uuid, url=None, exchange='default', - workers_info={}, + topics=[], transport=None, transport_options=None) ] @@ -54,7 +54,7 @@ class TestWorkerBasedActionEngine(test.MockTestCase): flow = lf.Flow('test-flow').add(utils.DummyTask()) _, flow_detail = pu.temporary_flow_detail() config = {'url': self.broker_url, 'exchange': self.exchange, - 'workers_info': self.workers_info, 'transport': 
'memory', + 'topics': self.topics, 'transport': 'memory', 'transport_options': {}} engine.WorkerBasedActionEngine( flow, flow_detail, None, config).compile() @@ -63,7 +63,7 @@ class TestWorkerBasedActionEngine(test.MockTestCase): mock.call.executor_class(uuid=flow_detail.uuid, url=self.broker_url, exchange=self.exchange, - workers_info=self.workers_info, + topics=self.topics, transport='memory', transport_options={}) ] diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index 63f987dd..b763f673 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -42,8 +42,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.broker_url = 'broker-url' self.executor_uuid = 'executor-uuid' self.executor_exchange = 'executor-exchange' - self.executor_topic = 'executor-topic' - self.executor_workers_info = {self.executor_topic: [self.task.name]} + self.executor_topic = 'test-topic1' self.proxy_started_event = threading.Event() # patch classes @@ -75,7 +74,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): def executor(self, reset_master_mock=True, **kwargs): executor_kwargs = dict(uuid=self.executor_uuid, exchange=self.executor_exchange, - workers_info=self.executor_workers_info, + topics=[self.executor_topic], url=self.broker_url) executor_kwargs.update(kwargs) ex = executor.WorkerTaskExecutor(**executor_kwargs) @@ -218,21 +217,28 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertEqual(len(ex._requests_cache._data), 0) def test_execute_task(self): + self.message_mock.properties['type'] = pr.NOTIFY + notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name]) ex = self.executor() + ex._on_message(notify.to_dict(), self.message_mock) ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ mock.call.Request(self.task, self.task_uuid, 'execute', self.task_args, None, self.timeout), - mock.call.proxy.publish(self.request_inst_mock, + mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, - correlation_id=self.task_uuid) + correlation_id=self.task_uuid), + mock.call.request.set_pending() ] self.assertEqual(self.master_mock.mock_calls, expected_calls) def test_revert_task(self): + self.message_mock.properties['type'] = pr.NOTIFY + notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name]) ex = self.executor() + ex._on_message(notify.to_dict(), self.message_mock) ex.revert_task(self.task, self.task_uuid, self.task_args, self.task_result, self.task_failures) @@ -241,10 +247,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.task_args, None, self.timeout, failures=self.task_failures, result=self.task_result), - mock.call.proxy.publish(self.request_inst_mock, + mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, - correlation_id=self.task_uuid) + correlation_id=self.task_uuid), + mock.call.request.set_pending() ] self.assertEqual(self.master_mock.mock_calls, expected_calls) @@ -255,20 +262,22 @@ class TestWorkerTaskExecutor(test.MockTestCase): expected_calls = [ mock.call.Request(self.task, self.task_uuid, 'execute', - self.task_args, None, self.timeout), - mock.call.request.set_result(mock.ANY) + self.task_args, None, self.timeout) ] self.assertEqual(self.master_mock.mock_calls, expected_calls) def test_execute_task_publish_error(self): + self.message_mock.properties['type'] = pr.NOTIFY 
self.proxy_inst_mock.publish.side_effect = Exception('Woot!') + notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name]) ex = self.executor() + ex._on_message(notify.to_dict(), self.message_mock) ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ mock.call.Request(self.task, self.task_uuid, 'execute', self.task_args, None, self.timeout), - mock.call.proxy.publish(self.request_inst_mock, + mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid),
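
Usage sketch (editorial addition, not part of the patch): with this change an
engine no longer needs a static `workers_info` map; it is configured only with
the worker `topics` to poll, and each worker's task list is collected at
runtime through the periodic `Notify` broadcast and stored in the new
`WorkersCache`. The snippet below mirrors the example and test files touched
above; the `taskflow.engines.worker_based.worker` import path, the kombu
in-memory transport and the `time.sleep()` readiness wait are illustrative
assumptions rather than anything this patch defines.

import threading
import time

from taskflow import engines
from taskflow.engines.worker_based import worker as wkr
from taskflow.patterns import linear_flow as lf
from taskflow.tests import utils

# Worker side: declares its topic and the tasks (or task modules) it exposes.
# On receiving a NOTIFY broadcast it replies with its topic and task list,
# which the executor caches in WorkersCache.
worker = wkr.Worker(exchange='taskflow',
                    topic='test-topic',
                    tasks=['taskflow.tests.utils'],
                    transport='memory')
worker_thread = threading.Thread(target=worker.run)
worker_thread.daemon = True
worker_thread.start()
time.sleep(1.0)  # crude stand-in for waiting until the worker is consuming

# Engine side: only the topics are configured; the executor broadcasts a
# NOTIFY message to each of them every NOTIFY_PERIOD seconds, and any request
# still in the WAITING state is published once a matching worker replies.
engine_conf = {
    'engine': 'worker-based',
    'exchange': 'taskflow',
    'topics': ['test-topic'],
    'transport': 'memory',
}
flow = lf.Flow('wbe-demo').add(utils.TaskNoRequiresNoReturns())
engines.load(flow, engine_conf=engine_conf).run()

When several worker topics advertise the same task, `WorkersCache.get_topic_by_task`
picks one of them at random, which gives a simple form of load spreading across
workers.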