From ae28f0ba4b267bc3addc8978a11f7b686d1fe573 Mon Sep 17 00:00:00 2001 From: Stanislav Kudriashev Date: Thu, 6 Mar 2014 18:47:35 +0200 Subject: [PATCH] Introduce remote tasks cache for worker-executor Updated WorkerTaskExecutor to use cache for remote tasks. Change-Id: I4572052f63647472367cb69fc02911bbec2bd4cc --- taskflow/engines/worker_based/cache.py | 57 +++++++++++++++++ taskflow/engines/worker_based/executor.py | 64 ++++++++----------- .../tests/unit/worker_based/test_executor.py | 48 +++++++------- 3 files changed, 107 insertions(+), 62 deletions(-) create mode 100644 taskflow/engines/worker_based/cache.py diff --git a/taskflow/engines/worker_based/cache.py b/taskflow/engines/worker_based/cache.py new file mode 100644 index 00000000..efacad09 --- /dev/null +++ b/taskflow/engines/worker_based/cache.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging + +import six + +from taskflow.utils import lock_utils as lu + +LOG = logging.getLogger(__name__) + + +class Cache(object): + """Represents thread-safe cache.""" + + def __init__(self): + self._data = {} + self._lock = lu.ReaderWriterLock() + + def get(self, key): + """Retrieve a value from the cache.""" + with self._lock.read_lock(): + return self._data.get(key) + + def set(self, key, value): + """Set a value in the cache.""" + with self._lock.write_lock(): + self._data[key] = value + LOG.debug("Cache updated. Capacity: %s", len(self._data)) + + def delete(self, key): + """Delete a value from the cache.""" + with self._lock.write_lock(): + self._data.pop(key, None) + + def cleanup(self, on_expired_callback=None): + """Delete out-dated values from the cache.""" + with self._lock.write_lock(): + expired_values = [(k, v) for k, v in six.iteritems(self._data) + if v.expired] + for k, v in expired_values: + if on_expired_callback: + on_expired_callback(v) + self._data.pop(k, None) diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index 5c8c3a66..cf747f2d 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -15,12 +15,14 @@ # under the License. import logging -import six import threading +import six + from kombu import exceptions as kombu_exc from taskflow.engines.action_engine import executor +from taskflow.engines.worker_based import cache from taskflow.engines.worker_based import protocol as pr from taskflow.engines.worker_based import proxy from taskflow.engines.worker_based import remote_task as rt @@ -40,7 +42,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): self._proxy = proxy.Proxy(uuid, exchange, self._on_message, self._on_wait, **kwargs) self._proxy_thread = None - self._remote_tasks = {} + self._remote_tasks_cache = cache.Cache() # TODO(skudriashev): This data should be collected from workers # using broadcast messages directly. @@ -78,55 +80,41 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): def _process_response(self, task_uuid, response): """Process response from remote side.""" - try: - task = self._remote_tasks[task_uuid] - except KeyError: - LOG.debug("Task with id='%s' not found.", task_uuid) - else: + remote_task = self._remote_tasks_cache.get(task_uuid) + if remote_task is not None: state = response.pop('state') if state == pr.RUNNING: - task.set_running() + remote_task.set_running() elif state == pr.PROGRESS: - task.on_progress(**response) + remote_task.on_progress(**response) elif state == pr.FAILURE: response['result'] = pu.failure_from_dict(response['result']) - task.set_result(**response) - self._remove_remote_task(task) + remote_task.set_result(**response) + self._remote_tasks_cache.delete(remote_task.uuid) elif state == pr.SUCCESS: - task.set_result(**response) - self._remove_remote_task(task) + remote_task.set_result(**response) + self._remote_tasks_cache.delete(remote_task.uuid) else: LOG.warning("Unexpected response status: '%s'", state) + else: + LOG.debug("Remote task with id='%s' not found.", task_uuid) + + @staticmethod + def _handle_expired_remote_task(task): + LOG.debug("Remote task '%r' has expired.", task) + task.set_result(misc.Failure.from_exception( + exc.Timeout("Remote task '%r' has expired" % task))) def _on_wait(self): - """This function is called cyclically between draining events - iterations to clean-up expired task requests. - """ - expired_tasks = [task for task in six.itervalues(self._remote_tasks) - if task.expired] - for task in expired_tasks: - LOG.debug("Task request '%s' has expired.", task) - task.set_result(misc.Failure.from_exception( - exc.Timeout("Task request '%s' has expired" % task))) - del self._remote_tasks[task.uuid] - - def _store_remote_task(self, task): - """Store task in the remote tasks map.""" - self._remote_tasks[task.uuid] = task - return task - - def _remove_remote_task(self, task): - """Remove remote task from the tasks map.""" - if task.uuid in self._remote_tasks: - del self._remote_tasks[task.uuid] + """This function is called cyclically between draining events.""" + self._remote_tasks_cache.cleanup(self._handle_expired_remote_task) def _submit_task(self, task, task_uuid, action, arguments, progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs): """Submit task request to workers.""" - remote_task = self._store_remote_task( - rt.RemoteTask(task, task_uuid, action, arguments, - progress_callback, timeout, **kwargs) - ) + remote_task = rt.RemoteTask(task, task_uuid, action, arguments, + progress_callback, timeout, **kwargs) + self._remote_tasks_cache.set(remote_task.uuid, remote_task) try: # get task's workers topic to send request to try: @@ -143,7 +131,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): except Exception: with misc.capture_failure() as failure: LOG.exception("Failed to submit the '%s' task", remote_task) - self._remove_remote_task(remote_task) + self._remote_tasks_cache.delete(remote_task.uuid) remote_task.set_result(failure) return remote_task.result diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index 25449846..11dd66e8 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -105,7 +105,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_state_running(self): response = dict(state=pr.RUNNING) ex = self.executor() - ex._store_remote_task(self.remote_task_mock) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) ex._on_message(response, self.message_mock) self.assertEqual(self.remote_task_mock.mock_calls, @@ -115,7 +115,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_state_progress(self): response = dict(state=pr.PROGRESS, progress=1.0) ex = self.executor() - ex._store_remote_task(self.remote_task_mock) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) ex._on_message(response, self.message_mock) self.assertEqual(self.remote_task_mock.mock_calls, @@ -127,10 +127,10 @@ class TestWorkerTaskExecutor(test.MockTestCase): failure_dict = pu.failure_to_dict(failure) response = dict(state=pr.FAILURE, result=failure_dict) ex = self.executor() - ex._store_remote_task(self.remote_task_mock) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) ex._on_message(response, self.message_mock) - self.assertEqual(len(ex._remote_tasks), 0) + self.assertEqual(len(ex._remote_tasks_cache._data), 0) self.assertEqual(self.remote_task_mock.mock_calls, [ mock.call.set_result(result=utils.FailureMatcher(failure)) ]) @@ -140,7 +140,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): response = dict(state=pr.SUCCESS, result=self.task_result, event='executed') ex = self.executor() - ex._store_remote_task(self.remote_task_mock) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) ex._on_message(response, self.message_mock) self.assertEqual(self.remote_task_mock.mock_calls, @@ -151,7 +151,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_unknown_state(self): response = dict(state='unknown') ex = self.executor() - ex._store_remote_task(self.remote_task_mock) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) ex._on_message(response, self.message_mock) self.assertEqual(self.remote_task_mock.mock_calls, []) @@ -161,7 +161,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.message_mock.properties = {'correlation_id': 'non-existent'} response = dict(state=pr.RUNNING) ex = self.executor() - ex._store_remote_task(self.remote_task_mock) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) ex._on_message(response, self.message_mock) self.assertEqual(self.remote_task_mock.mock_calls, []) @@ -171,7 +171,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.message_mock.properties = {} response = dict(state=pr.RUNNING) ex = self.executor() - ex._store_remote_task(self.remote_task_mock) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) ex._on_message(response, self.message_mock) self.assertEqual(self.remote_task_mock.mock_calls, []) @@ -184,37 +184,37 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertTrue(mocked_exception.called) @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock') - def test_on_wait_task_not_expired(self, mock_time): - mock_time.side_effect = [1, self.timeout] + def test_on_wait_task_not_expired(self, mocked_time): + mocked_time.side_effect = [1, self.timeout] ex = self.executor() - ex._store_remote_task(self.remote_task()) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task()) - self.assertEqual(len(ex._remote_tasks), 1) + self.assertEqual(len(ex._remote_tasks_cache._data), 1) ex._on_wait() - self.assertEqual(len(ex._remote_tasks), 1) + self.assertEqual(len(ex._remote_tasks_cache._data), 1) @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock') - def test_on_wait_task_expired(self, mock_time): - mock_time.side_effect = [1, self.timeout + 2, self.timeout * 2] + def test_on_wait_task_expired(self, mocked_time): + mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2] ex = self.executor() - ex._store_remote_task(self.remote_task()) + ex._remote_tasks_cache.set(self.task_uuid, self.remote_task()) - self.assertEqual(len(ex._remote_tasks), 1) + self.assertEqual(len(ex._remote_tasks_cache._data), 1) ex._on_wait() - self.assertEqual(len(ex._remote_tasks), 0) + self.assertEqual(len(ex._remote_tasks_cache._data), 0) def test_remove_task_non_existent(self): task = self.remote_task() ex = self.executor() - ex._store_remote_task(task) + ex._remote_tasks_cache.set(self.task_uuid, task) - self.assertEqual(len(ex._remote_tasks), 1) - ex._remove_remote_task(task) - self.assertEqual(len(ex._remote_tasks), 0) + self.assertEqual(len(ex._remote_tasks_cache._data), 1) + ex._remote_tasks_cache.delete(self.task_uuid) + self.assertEqual(len(ex._remote_tasks_cache._data), 0) # remove non-existent - ex._remove_remote_task(task) - self.assertEqual(len(ex._remote_tasks), 0) + ex._remote_tasks_cache.delete(self.task_uuid) + self.assertEqual(len(ex._remote_tasks_cache._data), 0) def test_execute_task(self): request = self.request(action='execute')