Merge "Fix for WBE sporadic timeout of tasks"
This commit is contained in:
@@ -12,11 +12,6 @@ Types
|
||||
into *isolated* libraries (as using these types in this manner is not
|
||||
the expected and/or desired usage).
|
||||
|
||||
Cache
|
||||
=====
|
||||
|
||||
.. automodule:: taskflow.types.cache
|
||||
|
||||
Entity
|
||||
======
|
||||
|
||||
|
@@ -15,9 +15,11 @@
|
||||
# under the License.
|
||||
|
||||
import functools
|
||||
import threading
|
||||
|
||||
from futurist import periodics
|
||||
from oslo_utils import timeutils
|
||||
import six
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow.engines.worker_based import dispatcher
|
||||
@@ -26,7 +28,7 @@ from taskflow.engines.worker_based import proxy
|
||||
from taskflow.engines.worker_based import types as wt
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import logging
|
||||
from taskflow import task as task_atom
|
||||
from taskflow.task import EVENT_UPDATE_PROGRESS # noqa
|
||||
from taskflow.utils import kombu_utils as ku
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import threading_utils as tu
|
||||
@@ -42,7 +44,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
url=None, transport=None, transport_options=None,
|
||||
retry_options=None):
|
||||
self._uuid = uuid
|
||||
self._requests_cache = wt.RequestsCache()
|
||||
self._ongoing_requests = {}
|
||||
self._ongoing_requests_lock = threading.RLock()
|
||||
self._transition_timeout = transition_timeout
|
||||
type_handlers = {
|
||||
pr.RESPONSE: dispatcher.Handler(self._process_response,
|
||||
@@ -61,8 +64,6 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
# pre-existing knowledge of the topics those workers are on to gather
|
||||
# and update this information).
|
||||
self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics)
|
||||
self._finder.notifier.register(wt.WorkerFinder.WORKER_ARRIVED,
|
||||
self._on_worker)
|
||||
self._helpers = tu.ThreadBundle()
|
||||
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
|
||||
after_start=lambda t: self._proxy.wait(),
|
||||
@@ -74,25 +75,18 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
after_join=lambda t: p_worker.reset(),
|
||||
before_start=lambda t: p_worker.reset())
|
||||
|
||||
def _on_worker(self, event_type, details):
|
||||
"""Process new worker that has arrived (and fire off any work)."""
|
||||
worker = details['worker']
|
||||
for request in self._requests_cache.get_waiting_requests(worker):
|
||||
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
||||
self._publish_request(request, worker)
|
||||
|
||||
def _process_response(self, response, message):
|
||||
"""Process response from remote side."""
|
||||
LOG.debug("Started processing response message '%s'",
|
||||
ku.DelayedPretty(message))
|
||||
try:
|
||||
task_uuid = message.properties['correlation_id']
|
||||
request_uuid = message.properties['correlation_id']
|
||||
except KeyError:
|
||||
LOG.warning("The 'correlation_id' message property is"
|
||||
" missing in message '%s'",
|
||||
ku.DelayedPretty(message))
|
||||
else:
|
||||
request = self._requests_cache.get(task_uuid)
|
||||
request = self._ongoing_requests.get(request_uuid)
|
||||
if request is not None:
|
||||
response = pr.Response.from_dict(response)
|
||||
LOG.debug("Extracted response '%s' and matched it to"
|
||||
@@ -105,35 +99,31 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
details = response.data['details']
|
||||
request.notifier.notify(event_type, details)
|
||||
elif response.state in (pr.FAILURE, pr.SUCCESS):
|
||||
moved = request.transition_and_log_error(response.state,
|
||||
logger=LOG)
|
||||
if moved:
|
||||
# NOTE(imelnikov): request should not be in the
|
||||
# cache when another thread can see its result and
|
||||
# schedule another request with the same uuid; so
|
||||
# we remove it, then set the result...
|
||||
del self._requests_cache[request.uuid]
|
||||
if request.transition_and_log_error(response.state,
|
||||
logger=LOG):
|
||||
with self._ongoing_requests_lock:
|
||||
del self._ongoing_requests[request.uuid]
|
||||
request.set_result(**response.data)
|
||||
else:
|
||||
LOG.warning("Unexpected response status '%s'",
|
||||
response.state)
|
||||
else:
|
||||
LOG.debug("Request with id='%s' not found", task_uuid)
|
||||
LOG.debug("Request with id='%s' not found", request_uuid)
|
||||
|
||||
@staticmethod
|
||||
def _handle_expired_request(request):
|
||||
"""Handle expired request.
|
||||
"""Handle a expired request.
|
||||
|
||||
When request has expired it is removed from the requests cache and
|
||||
the `RequestTimeout` exception is set as a request result.
|
||||
When a request has expired it is removed from the ongoing requests
|
||||
dictionary and a ``RequestTimeout`` exception is set as a
|
||||
request result.
|
||||
"""
|
||||
if request.transition_and_log_error(pr.FAILURE, logger=LOG):
|
||||
# Raise an exception (and then catch it) so we get a nice
|
||||
# traceback that the request will get instead of it getting
|
||||
# just an exception with no traceback...
|
||||
try:
|
||||
request_age = timeutils.delta_seconds(request.created_on,
|
||||
timeutils.utcnow())
|
||||
request_age = timeutils.now() - request.created_on
|
||||
raise exc.RequestTimeout(
|
||||
"Request '%s' has expired after waiting for %0.2f"
|
||||
" seconds for it to transition out of (%s) states"
|
||||
@@ -142,51 +132,74 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.debug(failure.exception_str)
|
||||
request.set_result(failure)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _on_wait(self):
|
||||
"""This function is called cyclically between draining events."""
|
||||
self._requests_cache.cleanup(self._handle_expired_request)
|
||||
with self._ongoing_requests_lock:
|
||||
ongoing_requests_uuids = set(six.iterkeys(self._ongoing_requests))
|
||||
waiting_requests = {}
|
||||
expired_requests = {}
|
||||
for request_uuid in ongoing_requests_uuids:
|
||||
try:
|
||||
request = self._ongoing_requests[request_uuid]
|
||||
except KeyError:
|
||||
# Guess it got removed before we got to it...
|
||||
pass
|
||||
else:
|
||||
if request.expired:
|
||||
expired_requests[request_uuid] = request
|
||||
elif request.state == pr.WAITING:
|
||||
worker = self._finder.get_worker_for_task(request.task)
|
||||
if worker is not None:
|
||||
waiting_requests[request_uuid] = (request, worker)
|
||||
if expired_requests:
|
||||
with self._ongoing_requests_lock:
|
||||
while expired_requests:
|
||||
request_uuid, request = expired_requests.popitem()
|
||||
if self._handle_expired_request(request):
|
||||
del self._ongoing_requests[request_uuid]
|
||||
if waiting_requests:
|
||||
while waiting_requests:
|
||||
request_uuid, (request, worker) = waiting_requests.popitem()
|
||||
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
||||
self._publish_request(request, worker)
|
||||
|
||||
def _submit_task(self, task, task_uuid, action, arguments,
|
||||
progress_callback=None, **kwargs):
|
||||
"""Submit task request to a worker."""
|
||||
request = pr.Request(task, task_uuid, action, arguments,
|
||||
self._transition_timeout, **kwargs)
|
||||
|
||||
# Register the callback, so that we can proxy the progress correctly.
|
||||
if (progress_callback is not None and
|
||||
request.notifier.can_be_registered(
|
||||
task_atom.EVENT_UPDATE_PROGRESS)):
|
||||
request.notifier.register(task_atom.EVENT_UPDATE_PROGRESS,
|
||||
progress_callback)
|
||||
request.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
|
||||
request.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
|
||||
cleaner = functools.partial(request.notifier.deregister,
|
||||
task_atom.EVENT_UPDATE_PROGRESS,
|
||||
EVENT_UPDATE_PROGRESS,
|
||||
progress_callback)
|
||||
request.result.add_done_callback(lambda fut: cleaner())
|
||||
|
||||
# Get task's worker and publish request if worker was found.
|
||||
worker = self._finder.get_worker_for_task(task)
|
||||
if worker is not None:
|
||||
# NOTE(skudriashev): Make sure request is set to the PENDING state
|
||||
# before putting it into the requests cache to prevent the notify
|
||||
# processing thread get list of waiting requests and publish it
|
||||
# before it is published here, so it wouldn't be published twice.
|
||||
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
||||
self._requests_cache[request.uuid] = request
|
||||
with self._ongoing_requests_lock:
|
||||
self._ongoing_requests[request.uuid] = request
|
||||
self._publish_request(request, worker)
|
||||
else:
|
||||
LOG.debug("Delaying submission of '%s', no currently known"
|
||||
" worker/s available to process it", request)
|
||||
self._requests_cache[request.uuid] = request
|
||||
|
||||
with self._ongoing_requests_lock:
|
||||
self._ongoing_requests[request.uuid] = request
|
||||
return request.result
|
||||
|
||||
def _publish_request(self, request, worker):
|
||||
"""Publish request to a given topic."""
|
||||
LOG.debug("Submitting execution of '%s' to worker '%s' (expecting"
|
||||
" response identified by reply_to=%s and"
|
||||
" correlation_id=%s)", request, worker, self._uuid,
|
||||
request.uuid)
|
||||
" correlation_id=%s) - waited %0.3f seconds to"
|
||||
" get published", request, worker, self._uuid,
|
||||
request.uuid, timeutils.now() - request.created_on)
|
||||
try:
|
||||
self._proxy.publish(request, worker.topic,
|
||||
reply_to=self._uuid,
|
||||
@@ -196,7 +209,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
LOG.critical("Failed to submit '%s' (transitioning it to"
|
||||
" %s)", request, pr.FAILURE, exc_info=True)
|
||||
if request.transition_and_log_error(pr.FAILURE, logger=LOG):
|
||||
del self._requests_cache[request.uuid]
|
||||
with self._ongoing_requests_lock:
|
||||
del self._ongoing_requests[request.uuid]
|
||||
request.set_result(failure)
|
||||
|
||||
def execute_task(self, task, task_uuid, arguments,
|
||||
@@ -229,5 +243,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
def stop(self):
|
||||
"""Stops proxy thread and associated topic notification thread."""
|
||||
self._helpers.stop()
|
||||
self._requests_cache.clear(self._handle_expired_request)
|
||||
with self._ongoing_requests_lock:
|
||||
while self._ongoing_requests:
|
||||
_request_uuid, request = self._ongoing_requests.popitem()
|
||||
self._handle_expired_request(request)
|
||||
self._finder.clear()
|
||||
|
@@ -262,7 +262,7 @@ class Request(Message):
|
||||
self._watch = timeutils.StopWatch(duration=timeout).start()
|
||||
self._state = WAITING
|
||||
self._lock = threading.Lock()
|
||||
self._created_on = timeutils.utcnow()
|
||||
self._created_on = timeutils.now()
|
||||
self._result = futurist.Future()
|
||||
self._result.atom = task
|
||||
self._notifier = task.notifier
|
||||
|
@@ -28,27 +28,11 @@ import six
|
||||
from taskflow.engines.worker_based import dispatcher
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow import logging
|
||||
from taskflow.types import cache as base
|
||||
from taskflow.types import notifier
|
||||
from taskflow.utils import kombu_utils as ku
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RequestsCache(base.ExpiringCache):
|
||||
"""Represents a thread-safe requests cache."""
|
||||
|
||||
def get_waiting_requests(self, worker):
|
||||
"""Get list of waiting requests that the given worker can satisfy."""
|
||||
waiting_requests = []
|
||||
with self._lock:
|
||||
for request in six.itervalues(self._data):
|
||||
if request.state == pr.WAITING \
|
||||
and worker.performs(request.task):
|
||||
waiting_requests.append(request)
|
||||
return waiting_requests
|
||||
|
||||
|
||||
# TODO(harlowja): this needs to be made better, once
|
||||
# https://blueprints.launchpad.net/taskflow/+spec/wbe-worker-info is finally
|
||||
# implemented we can go about using that instead.
|
||||
@@ -101,12 +85,8 @@ class TopicWorker(object):
|
||||
class WorkerFinder(object):
|
||||
"""Base class for worker finders..."""
|
||||
|
||||
#: Event type emitted when a new worker arrives.
|
||||
WORKER_ARRIVED = 'worker_arrived'
|
||||
|
||||
def __init__(self):
|
||||
self._cond = threading.Condition()
|
||||
self.notifier = notifier.RestrictedNotifier([self.WORKER_ARRIVED])
|
||||
|
||||
@abc.abstractmethod
|
||||
def _total_workers(self):
|
||||
@@ -219,8 +199,6 @@ class ProxyWorkerFinder(WorkerFinder):
|
||||
LOG.debug("Updated worker '%s' (%s total workers are"
|
||||
" currently known)", worker, self._total_workers())
|
||||
self._cond.notify_all()
|
||||
if new_or_updated:
|
||||
self.notifier.notify(self.WORKER_ARRIVED, {'worker': worker})
|
||||
|
||||
def clear(self):
|
||||
with self._cond:
|
||||
|
@@ -17,9 +17,6 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
from oslo_utils import fixture
|
||||
from oslo_utils import timeutils
|
||||
|
||||
from taskflow.engines.worker_based import executor
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow import task as task_atom
|
||||
@@ -56,6 +53,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop
|
||||
self.request_inst_mock.uuid = self.task_uuid
|
||||
self.request_inst_mock.expired = False
|
||||
self.request_inst_mock.created_on = 0
|
||||
self.request_inst_mock.task_cls = self.task.name
|
||||
self.message_mock = mock.MagicMock(name='message')
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid,
|
||||
@@ -96,7 +94,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_response_state_running(self):
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
expected_calls = [
|
||||
@@ -109,7 +107,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
event_type=task_atom.EVENT_UPDATE_PROGRESS,
|
||||
details={'progress': 1.0})
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
expected_calls = [
|
||||
@@ -123,10 +121,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
failure_dict = a_failure.to_dict()
|
||||
response = pr.Response(pr.FAILURE, result=failure_dict)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(0, len(ex._requests_cache))
|
||||
self.assertEqual(0, len(ex._ongoing_requests))
|
||||
expected_calls = [
|
||||
mock.call.transition_and_log_error(pr.FAILURE, logger=mock.ANY),
|
||||
mock.call.set_result(result=test_utils.FailureMatcher(a_failure))
|
||||
@@ -137,7 +135,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
||||
event='executed')
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
expected_calls = [
|
||||
@@ -149,7 +147,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
def test_on_message_response_unknown_state(self):
|
||||
response = pr.Response(state='<unknown>')
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual([], self.request_inst_mock.mock_calls)
|
||||
@@ -158,7 +156,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties['correlation_id'] = '<unknown>'
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual([], self.request_inst_mock.mock_calls)
|
||||
@@ -167,48 +165,32 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties = {'type': pr.RESPONSE}
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual([], self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_wait_task_not_expired(self):
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
|
||||
self.assertEqual(1, len(ex._requests_cache))
|
||||
self.assertEqual(1, len(ex._ongoing_requests))
|
||||
ex._on_wait()
|
||||
self.assertEqual(1, len(ex._requests_cache))
|
||||
self.assertEqual(1, len(ex._ongoing_requests))
|
||||
|
||||
def test_on_wait_task_expired(self):
|
||||
now = timeutils.utcnow()
|
||||
f = self.useFixture(fixture.TimeFixture(override_time=now))
|
||||
@mock.patch('oslo_utils.timeutils.now')
|
||||
def test_on_wait_task_expired(self, mock_now):
|
||||
mock_now.side_effect = [0, 120]
|
||||
|
||||
self.request_inst_mock.expired = True
|
||||
self.request_inst_mock.created_on = now
|
||||
self.request_inst_mock.created_on = 0
|
||||
|
||||
f.advance_time_seconds(120)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
|
||||
self.assertEqual(1, len(ex._ongoing_requests))
|
||||
|
||||
self.assertEqual(1, len(ex._requests_cache))
|
||||
ex._on_wait()
|
||||
self.assertEqual(0, len(ex._requests_cache))
|
||||
|
||||
def test_remove_task_non_existent(self):
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
|
||||
self.assertEqual(1, len(ex._requests_cache))
|
||||
del ex._requests_cache[self.task_uuid]
|
||||
self.assertEqual(0, len(ex._requests_cache))
|
||||
|
||||
# delete non-existent
|
||||
try:
|
||||
del ex._requests_cache[self.task_uuid]
|
||||
except KeyError:
|
||||
pass
|
||||
self.assertEqual(0, len(ex._requests_cache))
|
||||
self.assertEqual(0, len(ex._ongoing_requests))
|
||||
|
||||
def test_execute_task(self):
|
||||
ex = self.executor()
|
||||
|
@@ -16,63 +16,12 @@
|
||||
|
||||
from oslo_utils import reflection
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import types as worker_types
|
||||
from taskflow import test
|
||||
from taskflow.test import mock
|
||||
from taskflow.tests import utils
|
||||
|
||||
|
||||
class TestRequestCache(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestRequestCache, self).setUp()
|
||||
self.task = utils.DummyTask()
|
||||
self.task_uuid = 'task-uuid'
|
||||
self.task_action = 'execute'
|
||||
self.task_args = {'a': 'a'}
|
||||
self.timeout = 60
|
||||
|
||||
def request(self, **kwargs):
|
||||
request_kwargs = dict(task=self.task,
|
||||
uuid=self.task_uuid,
|
||||
action=self.task_action,
|
||||
arguments=self.task_args,
|
||||
progress_callback=None,
|
||||
timeout=self.timeout)
|
||||
request_kwargs.update(kwargs)
|
||||
return pr.Request(**request_kwargs)
|
||||
|
||||
@mock.patch('oslo_utils.timeutils.now')
|
||||
def test_requests_cache_expiry(self, now):
|
||||
# Mock out the calls the underlying objects will soon use to return
|
||||
# times that we can control more easily...
|
||||
overrides = [
|
||||
0,
|
||||
1,
|
||||
self.timeout + 1,
|
||||
]
|
||||
now.side_effect = overrides
|
||||
|
||||
cache = worker_types.RequestsCache()
|
||||
cache[self.task_uuid] = self.request()
|
||||
cache.cleanup()
|
||||
self.assertEqual(1, len(cache))
|
||||
cache.cleanup()
|
||||
self.assertEqual(0, len(cache))
|
||||
|
||||
def test_requests_cache_match(self):
|
||||
cache = worker_types.RequestsCache()
|
||||
cache[self.task_uuid] = self.request()
|
||||
cache['task-uuid-2'] = self.request(task=utils.NastyTask(),
|
||||
uuid='task-uuid-2')
|
||||
worker = worker_types.TopicWorker("dummy-topic", [utils.DummyTask],
|
||||
identity="dummy")
|
||||
matches = cache.get_waiting_requests(worker)
|
||||
self.assertEqual(1, len(matches))
|
||||
self.assertEqual(2, len(cache))
|
||||
|
||||
|
||||
class TestTopicWorker(test.TestCase):
|
||||
def test_topic_worker(self):
|
||||
worker = worker_types.TopicWorker("dummy-topic",
|
||||
|
@@ -1,85 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import threading
|
||||
|
||||
from oslo_utils import reflection
|
||||
import six
|
||||
|
||||
|
||||
class ExpiringCache(object):
|
||||
"""Represents a thread-safe time-based expiring cache.
|
||||
|
||||
NOTE(harlowja): the values in this cache must have a expired attribute that
|
||||
can be used to determine if the key and associated value has expired or if
|
||||
it has not.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Set a value in the cache."""
|
||||
with self._lock:
|
||||
self._data[key] = value
|
||||
|
||||
def __len__(self):
|
||||
"""Returns how many items are in this cache."""
|
||||
return len(self._data)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Retrieve a value from the cache (returns default if not found)."""
|
||||
return self._data.get(key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Retrieve a value from the cache."""
|
||||
return self._data[key]
|
||||
|
||||
def __delitem__(self, key):
|
||||
"""Delete a key & value from the cache."""
|
||||
with self._lock:
|
||||
del self._data[key]
|
||||
|
||||
def clear(self, on_cleared_callback=None):
|
||||
"""Removes all keys & values from the cache."""
|
||||
cleared_items = []
|
||||
with self._lock:
|
||||
if on_cleared_callback is not None:
|
||||
cleared_items.extend(six.iteritems(self._data))
|
||||
self._data.clear()
|
||||
if on_cleared_callback is not None:
|
||||
arg_c = len(reflection.get_callable_args(on_cleared_callback))
|
||||
for (k, v) in cleared_items:
|
||||
if arg_c == 2:
|
||||
on_cleared_callback(k, v)
|
||||
else:
|
||||
on_cleared_callback(v)
|
||||
|
||||
def cleanup(self, on_expired_callback=None):
|
||||
"""Delete out-dated keys & values from the cache."""
|
||||
with self._lock:
|
||||
expired_values = [(k, v) for k, v in six.iteritems(self._data)
|
||||
if v.expired]
|
||||
for (k, _v) in expired_values:
|
||||
del self._data[k]
|
||||
if on_expired_callback is not None:
|
||||
arg_c = len(reflection.get_callable_args(on_expired_callback))
|
||||
for (k, v) in expired_values:
|
||||
if arg_c == 2:
|
||||
on_expired_callback(k, v)
|
||||
else:
|
||||
on_expired_callback(v)
|
Reference in New Issue
Block a user