Make the expiring cache a top level cache type
Create a cache module type and adjust a few of its methods to be more pythonic and then switch out the work_based engines usage of it and adjust its tests methods with adjusted methods using the new cache types functionality. Part of blueprint top-level-types Change-Id: I75c4b7db6dd989ef328e9e14d4b00266b1c97a9f
This commit is contained in:

committed by
Joshua Harlow

parent
95cb0625f4
commit
5d74d72257
@@ -14,54 +14,16 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
import six
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.utils import lock_utils as lu
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
from taskflow.types import cache as base
|
||||
|
||||
|
||||
class Cache(object):
|
||||
"""Represents thread-safe cache."""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
self._lock = lu.ReaderWriterLock()
|
||||
|
||||
def get(self, key):
|
||||
"""Retrieve a value from the cache."""
|
||||
with self._lock.read_lock():
|
||||
return self._data.get(key)
|
||||
|
||||
def set(self, key, value):
|
||||
"""Set a value in the cache."""
|
||||
with self._lock.write_lock():
|
||||
self._data[key] = value
|
||||
LOG.debug("Cache updated. Capacity: %s", len(self._data))
|
||||
|
||||
def delete(self, key):
|
||||
"""Delete a value from the cache."""
|
||||
with self._lock.write_lock():
|
||||
self._data.pop(key, None)
|
||||
|
||||
def cleanup(self, on_expired_callback=None):
|
||||
"""Delete out-dated values from the cache."""
|
||||
with self._lock.write_lock():
|
||||
expired_values = [(k, v) for k, v in six.iteritems(self._data)
|
||||
if v.expired]
|
||||
for (k, _v) in expired_values:
|
||||
self._data.pop(k, None)
|
||||
if on_expired_callback:
|
||||
for (_k, v) in expired_values:
|
||||
on_expired_callback(v)
|
||||
|
||||
|
||||
class RequestsCache(Cache):
|
||||
"""Represents thread-safe requests cache."""
|
||||
class RequestsCache(base.ExpiringCache):
|
||||
"""Represents a thread-safe requests cache."""
|
||||
|
||||
def get_waiting_requests(self, tasks):
|
||||
"""Get list of waiting requests by tasks."""
|
||||
@@ -73,8 +35,8 @@ class RequestsCache(Cache):
|
||||
return waiting_requests
|
||||
|
||||
|
||||
class WorkersCache(Cache):
|
||||
"""Represents thread-safe workers cache."""
|
||||
class WorkersCache(base.ExpiringCache):
|
||||
"""Represents a thread-safe workers cache."""
|
||||
|
||||
def get_topic_by_task(self, task):
|
||||
"""Get topic for a given task."""
|
||||
|
@@ -110,7 +110,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
tasks = notify['tasks']
|
||||
|
||||
# add worker info to the cache
|
||||
self._workers_cache.set(topic, tasks)
|
||||
self._workers_cache[topic] = tasks
|
||||
|
||||
# publish waiting requests
|
||||
for request in self._requests_cache.get_waiting_requests(tasks):
|
||||
@@ -137,7 +137,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
# NOTE(imelnikov): request should not be in cache when
|
||||
# another thread can see its result and schedule another
|
||||
# request with same uuid; so we remove it, then set result
|
||||
self._requests_cache.delete(request.uuid)
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(**response.data)
|
||||
else:
|
||||
LOG.warning("Unexpected response status: '%s'",
|
||||
@@ -175,10 +175,10 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
# processing thread get list of waiting requests and publish it
|
||||
# before it is published here, so it wouldn't be published twice.
|
||||
request.set_pending()
|
||||
self._requests_cache.set(request.uuid, request)
|
||||
self._requests_cache[request.uuid] = request
|
||||
self._publish_request(request, topic)
|
||||
else:
|
||||
self._requests_cache.set(request.uuid, request)
|
||||
self._requests_cache[request.uuid] = request
|
||||
|
||||
return request.result
|
||||
|
||||
@@ -191,9 +191,8 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
correlation_id=request.uuid)
|
||||
except Exception:
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.exception("Failed to submit the '%s' request." %
|
||||
request)
|
||||
self._requests_cache.delete(request.uuid)
|
||||
LOG.exception("Failed to submit the '%s' request.", request)
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(failure)
|
||||
|
||||
def _notify_topics(self):
|
||||
|
@@ -93,7 +93,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_response_state_running(self):
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
@@ -103,7 +103,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_response_state_progress(self):
|
||||
response = pr.Response(pr.PROGRESS, progress=1.0)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
@@ -115,10 +115,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
failure_dict = failure.to_dict()
|
||||
response = pr.Response(pr.FAILURE, result=failure_dict)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [
|
||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||
])
|
||||
@@ -128,7 +128,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
||||
event='executed')
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
@@ -139,7 +139,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_response_unknown_state(self):
|
||||
response = pr.Response(state='<unknown>')
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
@@ -149,7 +149,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties['correlation_id'] = '<unknown>'
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
@@ -159,7 +159,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties = {'type': pr.RESPONSE}
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
@@ -188,32 +188,35 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
|
||||
def test_on_wait_task_not_expired(self):
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
self.assertEqual(len(ex._requests_cache), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
self.assertEqual(len(ex._requests_cache), 1)
|
||||
|
||||
def test_on_wait_task_expired(self):
|
||||
self.request_inst_mock.expired = True
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
self.assertEqual(len(ex._requests_cache), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
|
||||
def test_remove_task_non_existent(self):
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
ex._requests_cache.delete(self.task_uuid)
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
self.assertEqual(len(ex._requests_cache), 1)
|
||||
del ex._requests_cache[self.task_uuid]
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
|
||||
# delete non-existent
|
||||
ex._requests_cache.delete(self.task_uuid)
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
try:
|
||||
del ex._requests_cache[self.task_uuid]
|
||||
except KeyError:
|
||||
pass
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
|
||||
def test_execute_task(self):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
|
73
taskflow/types/cache.py
Normal file
73
taskflow/types/cache.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import six
|
||||
|
||||
from taskflow.utils import lock_utils as lu
|
||||
from taskflow.utils import reflection
|
||||
|
||||
|
||||
class ExpiringCache(object):
|
||||
"""Represents a thread-safe time-based expiring cache.
|
||||
|
||||
NOTE(harlowja): the values in this cache must have a expired attribute that
|
||||
can be used to determine if the key and associated value has expired or if
|
||||
it has not.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
self._lock = lu.ReaderWriterLock()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Set a value in the cache."""
|
||||
with self._lock.write_lock():
|
||||
self._data[key] = value
|
||||
|
||||
def __len__(self):
|
||||
"""Returns how many items are in this cache."""
|
||||
with self._lock.read_lock():
|
||||
return len(self._data)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Retrieve a value from the cache (returns default if not found)."""
|
||||
with self._lock.read_lock():
|
||||
return self._data.get(key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Retrieve a value from the cache."""
|
||||
with self._lock.read_lock():
|
||||
return self._data[key]
|
||||
|
||||
def __delitem__(self, key):
|
||||
"""Delete a key & value from the cache."""
|
||||
with self._lock.write_lock():
|
||||
del self._data[key]
|
||||
|
||||
def cleanup(self, on_expired_callback=None):
|
||||
"""Delete out-dated keys & values from the cache."""
|
||||
with self._lock.write_lock():
|
||||
expired_values = [(k, v) for k, v in six.iteritems(self._data)
|
||||
if v.expired]
|
||||
for (k, _v) in expired_values:
|
||||
del self._data[k]
|
||||
if on_expired_callback:
|
||||
arg_c = len(reflection.get_callable_args(on_expired_callback))
|
||||
for (k, v) in expired_values:
|
||||
if arg_c == 2:
|
||||
on_expired_callback(k, v)
|
||||
else:
|
||||
on_expired_callback(v)
|
Reference in New Issue
Block a user