Make the expiring cache a top level cache type
Create a cache module type and adjust a few of its methods to be more pythonic and then switch out the work_based engines usage of it and adjust its tests methods with adjusted methods using the new cache types functionality. Part of blueprint top-level-types Change-Id: I75c4b7db6dd989ef328e9e14d4b00266b1c97a9f
This commit is contained in:
committed by
Joshua Harlow
parent
95cb0625f4
commit
5d74d72257
@@ -14,54 +14,16 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from taskflow.engines.worker_based import protocol as pr
|
from taskflow.engines.worker_based import protocol as pr
|
||||||
from taskflow.utils import lock_utils as lu
|
from taskflow.types import cache as base
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Cache(object):
|
class RequestsCache(base.ExpiringCache):
|
||||||
"""Represents thread-safe cache."""
|
"""Represents a thread-safe requests cache."""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._data = {}
|
|
||||||
self._lock = lu.ReaderWriterLock()
|
|
||||||
|
|
||||||
def get(self, key):
|
|
||||||
"""Retrieve a value from the cache."""
|
|
||||||
with self._lock.read_lock():
|
|
||||||
return self._data.get(key)
|
|
||||||
|
|
||||||
def set(self, key, value):
|
|
||||||
"""Set a value in the cache."""
|
|
||||||
with self._lock.write_lock():
|
|
||||||
self._data[key] = value
|
|
||||||
LOG.debug("Cache updated. Capacity: %s", len(self._data))
|
|
||||||
|
|
||||||
def delete(self, key):
|
|
||||||
"""Delete a value from the cache."""
|
|
||||||
with self._lock.write_lock():
|
|
||||||
self._data.pop(key, None)
|
|
||||||
|
|
||||||
def cleanup(self, on_expired_callback=None):
|
|
||||||
"""Delete out-dated values from the cache."""
|
|
||||||
with self._lock.write_lock():
|
|
||||||
expired_values = [(k, v) for k, v in six.iteritems(self._data)
|
|
||||||
if v.expired]
|
|
||||||
for (k, _v) in expired_values:
|
|
||||||
self._data.pop(k, None)
|
|
||||||
if on_expired_callback:
|
|
||||||
for (_k, v) in expired_values:
|
|
||||||
on_expired_callback(v)
|
|
||||||
|
|
||||||
|
|
||||||
class RequestsCache(Cache):
|
|
||||||
"""Represents thread-safe requests cache."""
|
|
||||||
|
|
||||||
def get_waiting_requests(self, tasks):
|
def get_waiting_requests(self, tasks):
|
||||||
"""Get list of waiting requests by tasks."""
|
"""Get list of waiting requests by tasks."""
|
||||||
@@ -73,8 +35,8 @@ class RequestsCache(Cache):
|
|||||||
return waiting_requests
|
return waiting_requests
|
||||||
|
|
||||||
|
|
||||||
class WorkersCache(Cache):
|
class WorkersCache(base.ExpiringCache):
|
||||||
"""Represents thread-safe workers cache."""
|
"""Represents a thread-safe workers cache."""
|
||||||
|
|
||||||
def get_topic_by_task(self, task):
|
def get_topic_by_task(self, task):
|
||||||
"""Get topic for a given task."""
|
"""Get topic for a given task."""
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
|||||||
tasks = notify['tasks']
|
tasks = notify['tasks']
|
||||||
|
|
||||||
# add worker info to the cache
|
# add worker info to the cache
|
||||||
self._workers_cache.set(topic, tasks)
|
self._workers_cache[topic] = tasks
|
||||||
|
|
||||||
# publish waiting requests
|
# publish waiting requests
|
||||||
for request in self._requests_cache.get_waiting_requests(tasks):
|
for request in self._requests_cache.get_waiting_requests(tasks):
|
||||||
@@ -137,7 +137,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
|||||||
# NOTE(imelnikov): request should not be in cache when
|
# NOTE(imelnikov): request should not be in cache when
|
||||||
# another thread can see its result and schedule another
|
# another thread can see its result and schedule another
|
||||||
# request with same uuid; so we remove it, then set result
|
# request with same uuid; so we remove it, then set result
|
||||||
self._requests_cache.delete(request.uuid)
|
del self._requests_cache[request.uuid]
|
||||||
request.set_result(**response.data)
|
request.set_result(**response.data)
|
||||||
else:
|
else:
|
||||||
LOG.warning("Unexpected response status: '%s'",
|
LOG.warning("Unexpected response status: '%s'",
|
||||||
@@ -175,10 +175,10 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
|||||||
# processing thread get list of waiting requests and publish it
|
# processing thread get list of waiting requests and publish it
|
||||||
# before it is published here, so it wouldn't be published twice.
|
# before it is published here, so it wouldn't be published twice.
|
||||||
request.set_pending()
|
request.set_pending()
|
||||||
self._requests_cache.set(request.uuid, request)
|
self._requests_cache[request.uuid] = request
|
||||||
self._publish_request(request, topic)
|
self._publish_request(request, topic)
|
||||||
else:
|
else:
|
||||||
self._requests_cache.set(request.uuid, request)
|
self._requests_cache[request.uuid] = request
|
||||||
|
|
||||||
return request.result
|
return request.result
|
||||||
|
|
||||||
@@ -191,9 +191,8 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
|||||||
correlation_id=request.uuid)
|
correlation_id=request.uuid)
|
||||||
except Exception:
|
except Exception:
|
||||||
with misc.capture_failure() as failure:
|
with misc.capture_failure() as failure:
|
||||||
LOG.exception("Failed to submit the '%s' request." %
|
LOG.exception("Failed to submit the '%s' request.", request)
|
||||||
request)
|
del self._requests_cache[request.uuid]
|
||||||
self._requests_cache.delete(request.uuid)
|
|
||||||
request.set_result(failure)
|
request.set_result(failure)
|
||||||
|
|
||||||
def _notify_topics(self):
|
def _notify_topics(self):
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
def test_on_message_response_state_running(self):
|
def test_on_message_response_state_running(self):
|
||||||
response = pr.Response(pr.RUNNING)
|
response = pr.Response(pr.RUNNING)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
ex._on_message(response.to_dict(), self.message_mock)
|
ex._on_message(response.to_dict(), self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||||
@@ -103,7 +103,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
def test_on_message_response_state_progress(self):
|
def test_on_message_response_state_progress(self):
|
||||||
response = pr.Response(pr.PROGRESS, progress=1.0)
|
response = pr.Response(pr.PROGRESS, progress=1.0)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
ex._on_message(response.to_dict(), self.message_mock)
|
ex._on_message(response.to_dict(), self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||||
@@ -115,10 +115,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
failure_dict = failure.to_dict()
|
failure_dict = failure.to_dict()
|
||||||
response = pr.Response(pr.FAILURE, result=failure_dict)
|
response = pr.Response(pr.FAILURE, result=failure_dict)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
ex._on_message(response.to_dict(), self.message_mock)
|
ex._on_message(response.to_dict(), self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
self.assertEqual(len(ex._requests_cache), 0)
|
||||||
self.assertEqual(self.request_inst_mock.mock_calls, [
|
self.assertEqual(self.request_inst_mock.mock_calls, [
|
||||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||||
])
|
])
|
||||||
@@ -128,7 +128,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
||||||
event='executed')
|
event='executed')
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
ex._on_message(response.to_dict(), self.message_mock)
|
ex._on_message(response.to_dict(), self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||||
@@ -139,7 +139,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
def test_on_message_response_unknown_state(self):
|
def test_on_message_response_unknown_state(self):
|
||||||
response = pr.Response(state='<unknown>')
|
response = pr.Response(state='<unknown>')
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
ex._on_message(response.to_dict(), self.message_mock)
|
ex._on_message(response.to_dict(), self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||||
@@ -149,7 +149,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.message_mock.properties['correlation_id'] = '<unknown>'
|
self.message_mock.properties['correlation_id'] = '<unknown>'
|
||||||
response = pr.Response(pr.RUNNING)
|
response = pr.Response(pr.RUNNING)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
ex._on_message(response.to_dict(), self.message_mock)
|
ex._on_message(response.to_dict(), self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||||
@@ -159,7 +159,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.message_mock.properties = {'type': pr.RESPONSE}
|
self.message_mock.properties = {'type': pr.RESPONSE}
|
||||||
response = pr.Response(pr.RUNNING)
|
response = pr.Response(pr.RUNNING)
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
ex._on_message(response.to_dict(), self.message_mock)
|
ex._on_message(response.to_dict(), self.message_mock)
|
||||||
|
|
||||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||||
@@ -188,32 +188,35 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
|
|
||||||
def test_on_wait_task_not_expired(self):
|
def test_on_wait_task_not_expired(self):
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
|
|
||||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
self.assertEqual(len(ex._requests_cache), 1)
|
||||||
ex._on_wait()
|
ex._on_wait()
|
||||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
self.assertEqual(len(ex._requests_cache), 1)
|
||||||
|
|
||||||
def test_on_wait_task_expired(self):
|
def test_on_wait_task_expired(self):
|
||||||
self.request_inst_mock.expired = True
|
self.request_inst_mock.expired = True
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
|
|
||||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
self.assertEqual(len(ex._requests_cache), 1)
|
||||||
ex._on_wait()
|
ex._on_wait()
|
||||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
self.assertEqual(len(ex._requests_cache), 0)
|
||||||
|
|
||||||
def test_remove_task_non_existent(self):
|
def test_remove_task_non_existent(self):
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||||
|
|
||||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
self.assertEqual(len(ex._requests_cache), 1)
|
||||||
ex._requests_cache.delete(self.task_uuid)
|
del ex._requests_cache[self.task_uuid]
|
||||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
self.assertEqual(len(ex._requests_cache), 0)
|
||||||
|
|
||||||
# delete non-existent
|
# delete non-existent
|
||||||
ex._requests_cache.delete(self.task_uuid)
|
try:
|
||||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
del ex._requests_cache[self.task_uuid]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
self.assertEqual(len(ex._requests_cache), 0)
|
||||||
|
|
||||||
def test_execute_task(self):
|
def test_execute_task(self):
|
||||||
self.message_mock.properties['type'] = pr.NOTIFY
|
self.message_mock.properties['type'] = pr.NOTIFY
|
||||||
|
|||||||
73
taskflow/types/cache.py
Normal file
73
taskflow/types/cache.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from taskflow.utils import lock_utils as lu
|
||||||
|
from taskflow.utils import reflection
|
||||||
|
|
||||||
|
|
||||||
|
class ExpiringCache(object):
|
||||||
|
"""Represents a thread-safe time-based expiring cache.
|
||||||
|
|
||||||
|
NOTE(harlowja): the values in this cache must have a expired attribute that
|
||||||
|
can be used to determine if the key and associated value has expired or if
|
||||||
|
it has not.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._data = {}
|
||||||
|
self._lock = lu.ReaderWriterLock()
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
"""Set a value in the cache."""
|
||||||
|
with self._lock.write_lock():
|
||||||
|
self._data[key] = value
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""Returns how many items are in this cache."""
|
||||||
|
with self._lock.read_lock():
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
"""Retrieve a value from the cache (returns default if not found)."""
|
||||||
|
with self._lock.read_lock():
|
||||||
|
return self._data.get(key, default)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""Retrieve a value from the cache."""
|
||||||
|
with self._lock.read_lock():
|
||||||
|
return self._data[key]
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
"""Delete a key & value from the cache."""
|
||||||
|
with self._lock.write_lock():
|
||||||
|
del self._data[key]
|
||||||
|
|
||||||
|
def cleanup(self, on_expired_callback=None):
|
||||||
|
"""Delete out-dated keys & values from the cache."""
|
||||||
|
with self._lock.write_lock():
|
||||||
|
expired_values = [(k, v) for k, v in six.iteritems(self._data)
|
||||||
|
if v.expired]
|
||||||
|
for (k, _v) in expired_values:
|
||||||
|
del self._data[k]
|
||||||
|
if on_expired_callback:
|
||||||
|
arg_c = len(reflection.get_callable_args(on_expired_callback))
|
||||||
|
for (k, v) in expired_values:
|
||||||
|
if arg_c == 2:
|
||||||
|
on_expired_callback(k, v)
|
||||||
|
else:
|
||||||
|
on_expired_callback(v)
|
||||||
Reference in New Issue
Block a user