From 5d74d72257f17fa8ed0f6b3202544dc35819d787 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Fri, 6 Jun 2014 20:55:32 -0700 Subject: [PATCH] Make the expiring cache a top level cache type Create a cache module type and adjust a few of its methods to be more pythonic, then switch the worker_based engine's usage over to it and adjust its test methods to use the new cache type's functionality. Part of blueprint top-level-types Change-Id: I75c4b7db6dd989ef328e9e14d4b00266b1c97a9f --- taskflow/engines/worker_based/cache.py | 48 ++---------- taskflow/engines/worker_based/executor.py | 13 ++-- .../tests/unit/worker_based/test_executor.py | 43 ++++++----- taskflow/types/cache.py | 73 +++++++++++++++++++ 4 files changed, 107 insertions(+), 70 deletions(-) create mode 100644 taskflow/types/cache.py diff --git a/taskflow/engines/worker_based/cache.py b/taskflow/engines/worker_based/cache.py index f92bf23e..9da7f12c 100644 --- a/taskflow/engines/worker_based/cache.py +++ b/taskflow/engines/worker_based/cache.py @@ -14,54 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. -import logging import random import six from taskflow.engines.worker_based import protocol as pr -from taskflow.utils import lock_utils as lu - -LOG = logging.getLogger(__name__) +from taskflow.types import cache as base -class Cache(object): - """Represents thread-safe cache.""" - - def __init__(self): - self._data = {} - self._lock = lu.ReaderWriterLock() - - def get(self, key): - """Retrieve a value from the cache.""" - with self._lock.read_lock(): - return self._data.get(key) - - def set(self, key, value): - """Set a value in the cache.""" - with self._lock.write_lock(): - self._data[key] = value - LOG.debug("Cache updated. 
Capacity: %s", len(self._data)) - - def delete(self, key): - """Delete a value from the cache.""" - with self._lock.write_lock(): - self._data.pop(key, None) - - def cleanup(self, on_expired_callback=None): - """Delete out-dated values from the cache.""" - with self._lock.write_lock(): - expired_values = [(k, v) for k, v in six.iteritems(self._data) - if v.expired] - for (k, _v) in expired_values: - self._data.pop(k, None) - if on_expired_callback: - for (_k, v) in expired_values: - on_expired_callback(v) - - -class RequestsCache(Cache): - """Represents thread-safe requests cache.""" +class RequestsCache(base.ExpiringCache): + """Represents a thread-safe requests cache.""" def get_waiting_requests(self, tasks): """Get list of waiting requests by tasks.""" @@ -73,8 +35,8 @@ class RequestsCache(Cache): return waiting_requests -class WorkersCache(Cache): - """Represents thread-safe workers cache.""" +class WorkersCache(base.ExpiringCache): + """Represents a thread-safe workers cache.""" def get_topic_by_task(self, task): """Get topic for a given task.""" diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index 37ea8bd7..35febd40 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -110,7 +110,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): tasks = notify['tasks'] # add worker info to the cache - self._workers_cache.set(topic, tasks) + self._workers_cache[topic] = tasks # publish waiting requests for request in self._requests_cache.get_waiting_requests(tasks): @@ -137,7 +137,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): # NOTE(imelnikov): request should not be in cache when # another thread can see its result and schedule another # request with same uuid; so we remove it, then set result - self._requests_cache.delete(request.uuid) + del self._requests_cache[request.uuid] request.set_result(**response.data) else: LOG.warning("Unexpected response 
status: '%s'", @@ -175,10 +175,10 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): # processing thread get list of waiting requests and publish it # before it is published here, so it wouldn't be published twice. request.set_pending() - self._requests_cache.set(request.uuid, request) + self._requests_cache[request.uuid] = request self._publish_request(request, topic) else: - self._requests_cache.set(request.uuid, request) + self._requests_cache[request.uuid] = request return request.result @@ -191,9 +191,8 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): correlation_id=request.uuid) except Exception: with misc.capture_failure() as failure: - LOG.exception("Failed to submit the '%s' request." % - request) - self._requests_cache.delete(request.uuid) + LOG.exception("Failed to submit the '%s' request.", request) + del self._requests_cache[request.uuid] request.set_result(failure) def _notify_topics(self): diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index 75092003..7faee1b7 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -93,7 +93,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_response_state_running(self): response = pr.Response(pr.RUNNING) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._on_message(response.to_dict(), self.message_mock) self.assertEqual(self.request_inst_mock.mock_calls, @@ -103,7 +103,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_response_state_progress(self): response = pr.Response(pr.PROGRESS, progress=1.0) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._on_message(response.to_dict(), self.message_mock) 
self.assertEqual(self.request_inst_mock.mock_calls, @@ -115,10 +115,10 @@ class TestWorkerTaskExecutor(test.MockTestCase): failure_dict = failure.to_dict() response = pr.Response(pr.FAILURE, result=failure_dict) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._on_message(response.to_dict(), self.message_mock) - self.assertEqual(len(ex._requests_cache._data), 0) + self.assertEqual(len(ex._requests_cache), 0) self.assertEqual(self.request_inst_mock.mock_calls, [ mock.call.set_result(result=utils.FailureMatcher(failure)) ]) @@ -128,7 +128,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): response = pr.Response(pr.SUCCESS, result=self.task_result, event='executed') ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._on_message(response.to_dict(), self.message_mock) self.assertEqual(self.request_inst_mock.mock_calls, @@ -139,7 +139,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_response_unknown_state(self): response = pr.Response(state='') ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._on_message(response.to_dict(), self.message_mock) self.assertEqual(self.request_inst_mock.mock_calls, []) @@ -149,7 +149,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.message_mock.properties['correlation_id'] = '' response = pr.Response(pr.RUNNING) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._on_message(response.to_dict(), self.message_mock) self.assertEqual(self.request_inst_mock.mock_calls, []) @@ -159,7 +159,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.message_mock.properties = {'type': pr.RESPONSE} response = 
pr.Response(pr.RUNNING) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._on_message(response.to_dict(), self.message_mock) self.assertEqual(self.request_inst_mock.mock_calls, []) @@ -188,32 +188,35 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_wait_task_not_expired(self): ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock - self.assertEqual(len(ex._requests_cache._data), 1) + self.assertEqual(len(ex._requests_cache), 1) ex._on_wait() - self.assertEqual(len(ex._requests_cache._data), 1) + self.assertEqual(len(ex._requests_cache), 1) def test_on_wait_task_expired(self): self.request_inst_mock.expired = True ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock - self.assertEqual(len(ex._requests_cache._data), 1) + self.assertEqual(len(ex._requests_cache), 1) ex._on_wait() - self.assertEqual(len(ex._requests_cache._data), 0) + self.assertEqual(len(ex._requests_cache), 0) def test_remove_task_non_existent(self): ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._requests_cache[self.task_uuid] = self.request_inst_mock - self.assertEqual(len(ex._requests_cache._data), 1) - ex._requests_cache.delete(self.task_uuid) - self.assertEqual(len(ex._requests_cache._data), 0) + self.assertEqual(len(ex._requests_cache), 1) + del ex._requests_cache[self.task_uuid] + self.assertEqual(len(ex._requests_cache), 0) # delete non-existent - ex._requests_cache.delete(self.task_uuid) - self.assertEqual(len(ex._requests_cache._data), 0) + try: + del ex._requests_cache[self.task_uuid] + except KeyError: + pass + self.assertEqual(len(ex._requests_cache), 0) def test_execute_task(self): self.message_mock.properties['type'] = pr.NOTIFY diff --git 
class ExpiringCache(object):
    """Represents a thread-safe expiring cache.

    The backing dictionary is only read while holding the read side of a
    reader/writer lock and only mutated while holding the write side.

    NOTE(harlowja): the values in this cache must have an ``expired``
    attribute that can be used to determine if the key and associated value
    has expired or if it has not.
    """

    def __init__(self):
        self._data = {}
        self._lock = lu.ReaderWriterLock()

    def __setitem__(self, key, value):
        """Set a value in the cache."""
        with self._lock.write_lock():
            self._data[key] = value

    def __len__(self):
        """Returns how many items are in this cache."""
        with self._lock.read_lock():
            return len(self._data)

    def __contains__(self, key):
        """Returns whether the given key is currently in this cache."""
        # NOTE: without an explicit ``__contains__`` the ``in`` operator falls
        # back to the legacy sequence protocol and calls ``__getitem__(0)``,
        # which raises ``KeyError`` instead of answering the question.
        with self._lock.read_lock():
            return key in self._data

    def __iter__(self):
        """Iterates over a snapshot of the keys currently in this cache."""
        # Copy the keys while holding the lock so that the returned iterator
        # is not affected by mutations that happen after this call returns.
        with self._lock.read_lock():
            keys = list(self._data)
        return iter(keys)

    def get(self, key, default=None):
        """Retrieve a value from the cache (returns default if not found)."""
        with self._lock.read_lock():
            return self._data.get(key, default)

    def __getitem__(self, key):
        """Retrieve a value from the cache (raises KeyError if not found)."""
        with self._lock.read_lock():
            return self._data[key]

    def __delitem__(self, key):
        """Delete a key & value from the cache (raises KeyError if missing)."""
        with self._lock.write_lock():
            del self._data[key]

    def cleanup(self, on_expired_callback=None):
        """Delete out-dated keys & values from the cache.

        Every entry whose value has a truthy ``expired`` attribute is removed.

        :param on_expired_callback: optional callable invoked once for each
            removed entry; it is called as ``callback(key, value)`` when
            ``reflection.get_callable_args`` reports exactly two arguments
            for it and as ``callback(value)`` otherwise.
        """
        with self._lock.write_lock():
            expired_values = [(k, v) for k, v in six.iteritems(self._data)
                              if v.expired]
            for (k, _v) in expired_values:
                del self._data[k]
        # NOTE(review): the callbacks are run after the write lock has been
        # released; confirm this matches the intended locking scope.
        if on_expired_callback:
            arg_c = len(reflection.get_callable_args(on_expired_callback))
            # The callback's arity does not change between entries, so decide
            # the calling convention once instead of once per entry.
            wants_key = (arg_c == 2)
            for (k, v) in expired_values:
                if wants_key:
                    on_expired_callback(k, v)
                else:
                    on_expired_callback(v)