Fix for WBE sporadic timeout of tasks

This fixes the sporadic timeout of tasks that would happen
under certain circumstances. A new-worker notification could
be delivered to a callback at the same time as a task
submission came in, leaving a small race window in which the
task could insert itself into the requests cache while the
callback was still processing.

To work around this, the whole concept of a requests cache
has been removed; the WBE executor now maintains its own
local dictionary of ongoing requests and accesses it safely
under a lock.

In the on_wait function that kombu calls periodically, the
existing expiry of requests still happens, but any requests
that are still pending are now also matched to any new
workers that may have appeared.

This avoids the race (and ensures that even if a new worker
is found while a submission is in progress, the submission
is delayed only until the next on_wait call).

Related-Bug: #1431097

Change-Id: I98b0caeedc77ab2f7214847763ae1eb0433d4a78
Joshua Harlow 2016-02-04 18:09:24 -08:00
parent 61efc31e96
commit cea71f2799
7 changed files with 83 additions and 247 deletions
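
Before the per-file diffs, here is a minimal, self-contained sketch of
the pattern the new executor code implements: a plain dictionary of
ongoing requests guarded by a lock, a submit path that always records
the request, and a periodic on_wait() sweep that expires old requests
and publishes any still-waiting ones to newly appeared workers. All
names below (MiniExecutor, Request, find_worker, publish) are
illustrative stand-ins, not taskflow's real API:

    import threading
    import time


    class Request(object):
        """Stand-in for taskflow's request objects (hypothetical)."""

        def __init__(self, uuid, timeout=60.0):
            self.uuid = uuid
            self.created_on = time.time()
            self.timeout = timeout
            self.state = 'WAITING'

        @property
        def expired(self):
            return (time.time() - self.created_on) > self.timeout


    class MiniExecutor(object):
        """Lock-guarded ongoing-requests dictionary (simplified)."""

        def __init__(self, find_worker, publish):
            self._find_worker = find_worker  # request -> worker or None
            self._publish = publish          # (request, worker) -> None
            self._ongoing = {}
            self._lock = threading.RLock()

        def submit(self, request):
            # Always record the request first; if no worker is known
            # yet, the periodic on_wait() sweep will publish it when
            # one appears.
            with self._lock:
                self._ongoing[request.uuid] = request
            worker = self._find_worker(request)
            if worker is not None:
                request.state = 'PENDING'
                self._publish(request, worker)

        def on_wait(self):
            # Snapshot the uuids under the lock and sweep the snapshot,
            # so the sweep can not race with concurrent submissions or
            # response processing.
            with self._lock:
                uuids = list(self._ongoing)
            for uuid in uuids:
                request = self._ongoing.get(uuid)
                if request is None:
                    continue  # removed (e.g. by a response) mid-sweep
                if request.expired:
                    with self._lock:
                        self._ongoing.pop(uuid, None)
                elif request.state == 'WAITING':
                    worker = self._find_worker(request)
                    if worker is not None:
                        request.state = 'PENDING'
                        self._publish(request, worker)

Even if a worker arrives at the exact moment a submission is being
recorded, the worst case is that the request sits in the dictionary in
the WAITING state until the next on_wait() pass publishes it - the
bounded delay the commit message describes.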

--- a/doc/source/types.rst
+++ b/doc/source/types.rst

@@ -12,11 +12,6 @@ Types
 into *isolated* libraries (as using these types in this manner is not
 the expected and/or desired usage).
 
-Cache
-=====
-
-.. automodule:: taskflow.types.cache
-
 Entity
 ======

--- a/taskflow/engines/worker_based/executor.py
+++ b/taskflow/engines/worker_based/executor.py

@@ -15,9 +15,11 @@
 #    under the License.
 
 import functools
+import threading
 
 from futurist import periodics
 from oslo_utils import timeutils
+import six
 
 from taskflow.engines.action_engine import executor
 from taskflow.engines.worker_based import dispatcher
@@ -26,7 +28,7 @@ from taskflow.engines.worker_based import proxy
 from taskflow.engines.worker_based import types as wt
 from taskflow import exceptions as exc
 from taskflow import logging
-from taskflow import task as task_atom
+from taskflow.task import EVENT_UPDATE_PROGRESS  # noqa
 from taskflow.utils import kombu_utils as ku
 from taskflow.utils import misc
 from taskflow.utils import threading_utils as tu
@@ -42,7 +44,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
                  url=None, transport=None, transport_options=None,
                  retry_options=None):
         self._uuid = uuid
-        self._requests_cache = wt.RequestsCache()
+        self._ongoing_requests = {}
+        self._ongoing_requests_lock = threading.RLock()
         self._transition_timeout = transition_timeout
         type_handlers = {
             pr.RESPONSE: dispatcher.Handler(self._process_response,
@@ -61,8 +64,6 @@ class WorkerTaskExecutor(executor.TaskExecutor):
         # pre-existing knowledge of the topics those workers are on to gather
         # and update this information).
         self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics)
-        self._finder.notifier.register(wt.WorkerFinder.WORKER_ARRIVED,
-                                       self._on_worker)
         self._helpers = tu.ThreadBundle()
         self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
                            after_start=lambda t: self._proxy.wait(),
@@ -74,25 +75,18 @@ class WorkerTaskExecutor(executor.TaskExecutor):
                            after_join=lambda t: p_worker.reset(),
                            before_start=lambda t: p_worker.reset())
 
-    def _on_worker(self, event_type, details):
-        """Process new worker that has arrived (and fire off any work)."""
-        worker = details['worker']
-        for request in self._requests_cache.get_waiting_requests(worker):
-            if request.transition_and_log_error(pr.PENDING, logger=LOG):
-                self._publish_request(request, worker)
-
     def _process_response(self, response, message):
         """Process response from remote side."""
         LOG.debug("Started processing response message '%s'",
                   ku.DelayedPretty(message))
         try:
-            task_uuid = message.properties['correlation_id']
+            request_uuid = message.properties['correlation_id']
         except KeyError:
             LOG.warning("The 'correlation_id' message property is"
                         " missing in message '%s'",
                         ku.DelayedPretty(message))
         else:
-            request = self._requests_cache.get(task_uuid)
+            request = self._ongoing_requests.get(request_uuid)
             if request is not None:
                 response = pr.Response.from_dict(response)
                 LOG.debug("Extracted response '%s' and matched it to"
@@ -105,35 +99,31 @@ class WorkerTaskExecutor(executor.TaskExecutor):
                     details = response.data['details']
                     request.notifier.notify(event_type, details)
                 elif response.state in (pr.FAILURE, pr.SUCCESS):
-                    moved = request.transition_and_log_error(response.state,
-                                                             logger=LOG)
-                    if moved:
-                        # NOTE(imelnikov): request should not be in the
-                        # cache when another thread can see its result and
-                        # schedule another request with the same uuid; so
-                        # we remove it, then set the result...
-                        del self._requests_cache[request.uuid]
+                    if request.transition_and_log_error(response.state,
+                                                        logger=LOG):
+                        with self._ongoing_requests_lock:
+                            del self._ongoing_requests[request.uuid]
                         request.set_result(**response.data)
                 else:
                     LOG.warning("Unexpected response status '%s'",
                                 response.state)
         else:
-            LOG.debug("Request with id='%s' not found", task_uuid)
+            LOG.debug("Request with id='%s' not found", request_uuid)
 
     @staticmethod
     def _handle_expired_request(request):
-        """Handle expired request.
+        """Handle an expired request.
 
-        When request has expired it is removed from the requests cache and
-        the `RequestTimeout` exception is set as a request result.
+        When a request has expired it is removed from the ongoing requests
+        dictionary and a ``RequestTimeout`` exception is set as a
+        request result.
         """
         if request.transition_and_log_error(pr.FAILURE, logger=LOG):
             # Raise an exception (and then catch it) so we get a nice
             # traceback that the request will get instead of it getting
             # just an exception with no traceback...
             try:
-                request_age = timeutils.delta_seconds(request.created_on,
-                                                      timeutils.utcnow())
+                request_age = timeutils.now() - request.created_on
                 raise exc.RequestTimeout(
                     "Request '%s' has expired after waiting for %0.2f"
                     " seconds for it to transition out of (%s) states"
@@ -142,51 +132,74 @@ class WorkerTaskExecutor(executor.TaskExecutor):
                 with misc.capture_failure() as failure:
                     LOG.debug(failure.exception_str)
                     request.set_result(failure)
+            return True
+        return False
 
     def _on_wait(self):
         """This function is called cyclically between draining events."""
-        self._requests_cache.cleanup(self._handle_expired_request)
+        with self._ongoing_requests_lock:
+            ongoing_requests_uuids = set(six.iterkeys(self._ongoing_requests))
+        waiting_requests = {}
+        expired_requests = {}
+        for request_uuid in ongoing_requests_uuids:
+            try:
+                request = self._ongoing_requests[request_uuid]
+            except KeyError:
+                # Guess it got removed before we got to it...
+                pass
+            else:
+                if request.expired:
+                    expired_requests[request_uuid] = request
+                elif request.state == pr.WAITING:
+                    worker = self._finder.get_worker_for_task(request.task)
+                    if worker is not None:
+                        waiting_requests[request_uuid] = (request, worker)
+        if expired_requests:
+            with self._ongoing_requests_lock:
+                while expired_requests:
+                    request_uuid, request = expired_requests.popitem()
+                    if self._handle_expired_request(request):
+                        del self._ongoing_requests[request_uuid]
+        if waiting_requests:
+            while waiting_requests:
+                request_uuid, (request, worker) = waiting_requests.popitem()
+                if request.transition_and_log_error(pr.PENDING, logger=LOG):
+                    self._publish_request(request, worker)
 
     def _submit_task(self, task, task_uuid, action, arguments,
                      progress_callback=None, **kwargs):
         """Submit task request to a worker."""
         request = pr.Request(task, task_uuid, action, arguments,
                              self._transition_timeout, **kwargs)
         # Register the callback, so that we can proxy the progress correctly.
         if (progress_callback is not None and
-                request.notifier.can_be_registered(
-                    task_atom.EVENT_UPDATE_PROGRESS)):
-            request.notifier.register(task_atom.EVENT_UPDATE_PROGRESS,
-                                      progress_callback)
+                request.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
+            request.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
             cleaner = functools.partial(request.notifier.deregister,
-                                        task_atom.EVENT_UPDATE_PROGRESS,
+                                        EVENT_UPDATE_PROGRESS,
                                         progress_callback)
             request.result.add_done_callback(lambda fut: cleaner())
 
         # Get task's worker and publish request if worker was found.
         worker = self._finder.get_worker_for_task(task)
         if worker is not None:
-            # NOTE(skudriashev): Make sure request is set to the PENDING state
-            # before putting it into the requests cache to prevent the notify
-            # processing thread get list of waiting requests and publish it
-            # before it is published here, so it wouldn't be published twice.
             if request.transition_and_log_error(pr.PENDING, logger=LOG):
-                self._requests_cache[request.uuid] = request
+                with self._ongoing_requests_lock:
+                    self._ongoing_requests[request.uuid] = request
                 self._publish_request(request, worker)
         else:
             LOG.debug("Delaying submission of '%s', no currently known"
                       " worker/s available to process it", request)
-            self._requests_cache[request.uuid] = request
+            with self._ongoing_requests_lock:
+                self._ongoing_requests[request.uuid] = request
 
         return request.result
 
     def _publish_request(self, request, worker):
         """Publish request to a given topic."""
         LOG.debug("Submitting execution of '%s' to worker '%s' (expecting"
                   " response identified by reply_to=%s and"
-                  " correlation_id=%s)", request, worker, self._uuid,
-                  request.uuid)
+                  " correlation_id=%s) - waited %0.3f seconds to"
+                  " get published", request, worker, self._uuid,
+                  request.uuid, timeutils.now() - request.created_on)
         try:
             self._proxy.publish(request, worker.topic,
                                 reply_to=self._uuid,
@@ -196,7 +209,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
             LOG.critical("Failed to submit '%s' (transitioning it to"
                          " %s)", request, pr.FAILURE, exc_info=True)
             if request.transition_and_log_error(pr.FAILURE, logger=LOG):
-                del self._requests_cache[request.uuid]
+                with self._ongoing_requests_lock:
+                    del self._ongoing_requests[request.uuid]
                 request.set_result(failure)
 
     def execute_task(self, task, task_uuid, arguments,
@@ -229,5 +243,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
     def stop(self):
         """Stops proxy thread and associated topic notification thread."""
         self._helpers.stop()
-        self._requests_cache.clear(self._handle_expired_request)
+        with self._ongoing_requests_lock:
+            while self._ongoing_requests:
+                _request_uuid, request = self._ongoing_requests.popitem()
+                self._handle_expired_request(request)
         self._finder.clear()

--- a/taskflow/engines/worker_based/protocol.py
+++ b/taskflow/engines/worker_based/protocol.py

@@ -262,7 +262,7 @@ class Request(Message):
         self._watch = timeutils.StopWatch(duration=timeout).start()
         self._state = WAITING
         self._lock = threading.Lock()
-        self._created_on = timeutils.utcnow()
+        self._created_on = timeutils.now()
         self._result = futurist.Future()
         self._result.atom = task
         self._notifier = task.notifier
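
Worth noting in the hunk above: oslo.utils' timeutils.now() returns
elapsed seconds as a plain float (from a monotonic clock where
available), so computing a request's age becomes simple subtraction
instead of the timeutils.utcnow() plus delta_seconds() datetime
arithmetic used before. A quick illustration (assuming oslo.utils is
installed):

    from oslo_utils import timeutils

    start = timeutils.now()
    # ... do some work ...
    age = timeutils.now() - start  # elapsed seconds, as a float
    print("waited %0.2f seconds" % age)

This is also why the updated tests below can simply mock
oslo_utils.timeutils.now with numeric side effects like [0, 120].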

--- a/taskflow/engines/worker_based/types.py
+++ b/taskflow/engines/worker_based/types.py

@@ -28,27 +28,11 @@ import six
 from taskflow.engines.worker_based import dispatcher
 from taskflow.engines.worker_based import protocol as pr
 from taskflow import logging
-from taskflow.types import cache as base
-from taskflow.types import notifier
 from taskflow.utils import kombu_utils as ku
 
 LOG = logging.getLogger(__name__)
 
-class RequestsCache(base.ExpiringCache):
-    """Represents a thread-safe requests cache."""
-
-    def get_waiting_requests(self, worker):
-        """Get list of waiting requests that the given worker can satisfy."""
-        waiting_requests = []
-        with self._lock:
-            for request in six.itervalues(self._data):
-                if request.state == pr.WAITING \
-                        and worker.performs(request.task):
-                    waiting_requests.append(request)
-        return waiting_requests
-
 # TODO(harlowja): this needs to be made better, once
 # https://blueprints.launchpad.net/taskflow/+spec/wbe-worker-info is finally
 # implemented we can go about using that instead.
@@ -101,12 +85,8 @@ class TopicWorker(object):
 class WorkerFinder(object):
     """Base class for worker finders..."""
 
-    #: Event type emitted when a new worker arrives.
-    WORKER_ARRIVED = 'worker_arrived'
-
     def __init__(self):
         self._cond = threading.Condition()
-        self.notifier = notifier.RestrictedNotifier([self.WORKER_ARRIVED])
 
     @abc.abstractmethod
     def _total_workers(self):
@@ -219,8 +199,6 @@ class ProxyWorkerFinder(WorkerFinder):
             LOG.debug("Updated worker '%s' (%s total workers are"
                       " currently known)", worker, self._total_workers())
             self._cond.notify_all()
-            if new_or_updated:
-                self.notifier.notify(self.WORKER_ARRIVED, {'worker': worker})
 
     def clear(self):
         with self._cond:

--- a/taskflow/tests/unit/worker_based/test_executor.py
+++ b/taskflow/tests/unit/worker_based/test_executor.py

@@ -17,9 +17,6 @@
 import threading
 import time
 
-from oslo_utils import fixture
-from oslo_utils import timeutils
-
 from taskflow.engines.worker_based import executor
 from taskflow.engines.worker_based import protocol as pr
 from taskflow import task as task_atom
@@ -56,6 +53,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop
         self.request_inst_mock.uuid = self.task_uuid
         self.request_inst_mock.expired = False
+        self.request_inst_mock.created_on = 0
         self.request_inst_mock.task_cls = self.task.name
         self.message_mock = mock.MagicMock(name='message')
         self.message_mock.properties = {'correlation_id': self.task_uuid,
@@ -96,7 +94,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
     def test_on_message_response_state_running(self):
         response = pr.Response(pr.RUNNING)
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
         ex._process_response(response.to_dict(), self.message_mock)
 
         expected_calls = [
@@ -109,7 +107,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
             event_type=task_atom.EVENT_UPDATE_PROGRESS,
             details={'progress': 1.0})
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
         ex._process_response(response.to_dict(), self.message_mock)
 
         expected_calls = [
@@ -123,10 +121,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         failure_dict = a_failure.to_dict()
         response = pr.Response(pr.FAILURE, result=failure_dict)
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
         ex._process_response(response.to_dict(), self.message_mock)
 
-        self.assertEqual(0, len(ex._requests_cache))
+        self.assertEqual(0, len(ex._ongoing_requests))
         expected_calls = [
             mock.call.transition_and_log_error(pr.FAILURE, logger=mock.ANY),
             mock.call.set_result(result=test_utils.FailureMatcher(a_failure))
@@ -137,7 +135,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         response = pr.Response(pr.SUCCESS, result=self.task_result,
                                event='executed')
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
         ex._process_response(response.to_dict(), self.message_mock)
 
         expected_calls = [
@@ -149,7 +147,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
     def test_on_message_response_unknown_state(self):
         response = pr.Response(state='<unknown>')
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
         ex._process_response(response.to_dict(), self.message_mock)
 
         self.assertEqual([], self.request_inst_mock.mock_calls)
@@ -158,7 +156,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         self.message_mock.properties['correlation_id'] = '<unknown>'
         response = pr.Response(pr.RUNNING)
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
         ex._process_response(response.to_dict(), self.message_mock)
 
         self.assertEqual([], self.request_inst_mock.mock_calls)
@@ -167,48 +165,32 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         self.message_mock.properties = {'type': pr.RESPONSE}
         response = pr.Response(pr.RUNNING)
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
         ex._process_response(response.to_dict(), self.message_mock)
 
         self.assertEqual([], self.request_inst_mock.mock_calls)
 
     def test_on_wait_task_not_expired(self):
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
 
-        self.assertEqual(1, len(ex._requests_cache))
+        self.assertEqual(1, len(ex._ongoing_requests))
         ex._on_wait()
-        self.assertEqual(1, len(ex._requests_cache))
+        self.assertEqual(1, len(ex._ongoing_requests))
 
-    def test_on_wait_task_expired(self):
-        now = timeutils.utcnow()
-        f = self.useFixture(fixture.TimeFixture(override_time=now))
+    @mock.patch('oslo_utils.timeutils.now')
+    def test_on_wait_task_expired(self, mock_now):
+        mock_now.side_effect = [0, 120]
         self.request_inst_mock.expired = True
-        self.request_inst_mock.created_on = now
+        self.request_inst_mock.created_on = 0
 
-        f.advance_time_seconds(120)
         ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
-
-        self.assertEqual(1, len(ex._requests_cache))
+        ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
+        self.assertEqual(1, len(ex._ongoing_requests))
 
         ex._on_wait()
-        self.assertEqual(0, len(ex._requests_cache))
-
-    def test_remove_task_non_existent(self):
-        ex = self.executor()
-        ex._requests_cache[self.task_uuid] = self.request_inst_mock
-
-        self.assertEqual(1, len(ex._requests_cache))
-        del ex._requests_cache[self.task_uuid]
-        self.assertEqual(0, len(ex._requests_cache))
-
-        # delete non-existent
-        try:
-            del ex._requests_cache[self.task_uuid]
-        except KeyError:
-            pass
-        self.assertEqual(0, len(ex._requests_cache))
+        self.assertEqual(0, len(ex._ongoing_requests))
 
     def test_execute_task(self):
         ex = self.executor()

--- a/taskflow/tests/unit/worker_based/test_types.py
+++ b/taskflow/tests/unit/worker_based/test_types.py

@@ -16,63 +16,12 @@
 
 from oslo_utils import reflection
 
-from taskflow.engines.worker_based import protocol as pr
 from taskflow.engines.worker_based import types as worker_types
 from taskflow import test
 from taskflow.test import mock
 from taskflow.tests import utils
 
-class TestRequestCache(test.TestCase):
-
-    def setUp(self):
-        super(TestRequestCache, self).setUp()
-        self.task = utils.DummyTask()
-        self.task_uuid = 'task-uuid'
-        self.task_action = 'execute'
-        self.task_args = {'a': 'a'}
-        self.timeout = 60
-
-    def request(self, **kwargs):
-        request_kwargs = dict(task=self.task,
-                              uuid=self.task_uuid,
-                              action=self.task_action,
-                              arguments=self.task_args,
-                              progress_callback=None,
-                              timeout=self.timeout)
-        request_kwargs.update(kwargs)
-        return pr.Request(**request_kwargs)
-
-    @mock.patch('oslo_utils.timeutils.now')
-    def test_requests_cache_expiry(self, now):
-        # Mock out the calls the underlying objects will soon use to return
-        # times that we can control more easily...
-        overrides = [
-            0,
-            1,
-            self.timeout + 1,
-        ]
-        now.side_effect = overrides
-
-        cache = worker_types.RequestsCache()
-        cache[self.task_uuid] = self.request()
-        cache.cleanup()
-        self.assertEqual(1, len(cache))
-        cache.cleanup()
-        self.assertEqual(0, len(cache))
-
-    def test_requests_cache_match(self):
-        cache = worker_types.RequestsCache()
-        cache[self.task_uuid] = self.request()
-        cache['task-uuid-2'] = self.request(task=utils.NastyTask(),
-                                            uuid='task-uuid-2')
-        worker = worker_types.TopicWorker("dummy-topic", [utils.DummyTask],
-                                          identity="dummy")
-        matches = cache.get_waiting_requests(worker)
-        self.assertEqual(1, len(matches))
-        self.assertEqual(2, len(cache))
-
 class TestTopicWorker(test.TestCase):
 
     def test_topic_worker(self):
         worker = worker_types.TopicWorker("dummy-topic",

--- a/taskflow/types/cache.py
+++ /dev/null

@@ -1,85 +0,0 @@
-# -*- coding: utf-8 -*-
-
-#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
-#
-#    Licensed under the Apache License, Version 2.0 (the "License"); you may
-#    not use this file except in compliance with the License. You may obtain
-#    a copy of the License at
-#
-#         http://www.apache.org/licenses/LICENSE-2.0
-#
-#    Unless required by applicable law or agreed to in writing, software
-#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-#    License for the specific language governing permissions and limitations
-#    under the License.
-
-import threading
-
-from oslo_utils import reflection
-import six
-
-
-class ExpiringCache(object):
-    """Represents a thread-safe time-based expiring cache.
-
-    NOTE(harlowja): the values in this cache must have a expired attribute
-    that can be used to determine if the key and associated value has
-    expired or if it has not.
-    """
-
-    def __init__(self):
-        self._data = {}
-        self._lock = threading.Lock()
-
-    def __setitem__(self, key, value):
-        """Set a value in the cache."""
-        with self._lock:
-            self._data[key] = value
-
-    def __len__(self):
-        """Returns how many items are in this cache."""
-        return len(self._data)
-
-    def get(self, key, default=None):
-        """Retrieve a value from the cache (returns default if not found)."""
-        return self._data.get(key, default)
-
-    def __getitem__(self, key):
-        """Retrieve a value from the cache."""
-        return self._data[key]
-
-    def __delitem__(self, key):
-        """Delete a key & value from the cache."""
-        with self._lock:
-            del self._data[key]
-
-    def clear(self, on_cleared_callback=None):
-        """Removes all keys & values from the cache."""
-        cleared_items = []
-        with self._lock:
-            if on_cleared_callback is not None:
-                cleared_items.extend(six.iteritems(self._data))
-            self._data.clear()
-        if on_cleared_callback is not None:
-            arg_c = len(reflection.get_callable_args(on_cleared_callback))
-            for (k, v) in cleared_items:
-                if arg_c == 2:
-                    on_cleared_callback(k, v)
-                else:
-                    on_cleared_callback(v)
-
-    def cleanup(self, on_expired_callback=None):
-        """Delete out-dated keys & values from the cache."""
-        with self._lock:
-            expired_values = [(k, v) for k, v in six.iteritems(self._data)
-                              if v.expired]
-            for (k, _v) in expired_values:
-                del self._data[k]
-        if on_expired_callback is not None:
-            arg_c = len(reflection.get_callable_args(on_expired_callback))
-            for (k, v) in expired_values:
-                if arg_c == 2:
-                    on_expired_callback(k, v)
-                else:
-                    on_expired_callback(v)