Fix for WBE sporadic timeout of tasks

This fixes the sporadic timeout of tasks that would happen
under certain circumstances. What happened was that
a new worker notification would be sent to a callback
while at the same time a task submission would come in,
and there would be a small race period where the task
would insert itself into the requests cache while the
callback was processing (and so could be missed by it).

So to work around this the whole concept of a requests
cache was revamped and now the WBE executor just maintains
its own local dictionary of ongoing requests and accesses
it safely.

During the on_wait function that is periodically called
by kombu, the previous expiry of work still happens, but now
any requests that are still waiting (not yet published) are
also matched to any new workers that may have appeared.

This avoids the race (and ensures that even if a new
worker is found while a submission is in progress, the
delay before that submission is published lasts only until
the next on_wait call happens).

Related-Bug: #1431097

Change-Id: I98b0caeedc77ab2f7214847763ae1eb0433d4a78
This commit is contained in:
Joshua Harlow 2016-02-04 18:09:24 -08:00
parent 61efc31e96
commit cea71f2799
7 changed files with 83 additions and 247 deletions

View File

@ -12,11 +12,6 @@ Types
into *isolated* libraries (as using these types in this manner is not
the expected and/or desired usage).
Cache
=====
.. automodule:: taskflow.types.cache
Entity
======

View File

@ -15,9 +15,11 @@
# under the License.
import functools
import threading
from futurist import periodics
from oslo_utils import timeutils
import six
from taskflow.engines.action_engine import executor
from taskflow.engines.worker_based import dispatcher
@ -26,7 +28,7 @@ from taskflow.engines.worker_based import proxy
from taskflow.engines.worker_based import types as wt
from taskflow import exceptions as exc
from taskflow import logging
from taskflow import task as task_atom
from taskflow.task import EVENT_UPDATE_PROGRESS # noqa
from taskflow.utils import kombu_utils as ku
from taskflow.utils import misc
from taskflow.utils import threading_utils as tu
@ -42,7 +44,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
url=None, transport=None, transport_options=None,
retry_options=None):
self._uuid = uuid
self._requests_cache = wt.RequestsCache()
self._ongoing_requests = {}
self._ongoing_requests_lock = threading.RLock()
self._transition_timeout = transition_timeout
type_handlers = {
pr.RESPONSE: dispatcher.Handler(self._process_response,
@ -61,8 +64,6 @@ class WorkerTaskExecutor(executor.TaskExecutor):
# pre-existing knowledge of the topics those workers are on to gather
# and update this information).
self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics)
self._finder.notifier.register(wt.WorkerFinder.WORKER_ARRIVED,
self._on_worker)
self._helpers = tu.ThreadBundle()
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
after_start=lambda t: self._proxy.wait(),
@ -74,25 +75,18 @@ class WorkerTaskExecutor(executor.TaskExecutor):
after_join=lambda t: p_worker.reset(),
before_start=lambda t: p_worker.reset())
def _on_worker(self, event_type, details):
"""Process new worker that has arrived (and fire off any work)."""
worker = details['worker']
for request in self._requests_cache.get_waiting_requests(worker):
if request.transition_and_log_error(pr.PENDING, logger=LOG):
self._publish_request(request, worker)
def _process_response(self, response, message):
"""Process response from remote side."""
LOG.debug("Started processing response message '%s'",
ku.DelayedPretty(message))
try:
task_uuid = message.properties['correlation_id']
request_uuid = message.properties['correlation_id']
except KeyError:
LOG.warning("The 'correlation_id' message property is"
" missing in message '%s'",
ku.DelayedPretty(message))
else:
request = self._requests_cache.get(task_uuid)
request = self._ongoing_requests.get(request_uuid)
if request is not None:
response = pr.Response.from_dict(response)
LOG.debug("Extracted response '%s' and matched it to"
@ -105,35 +99,31 @@ class WorkerTaskExecutor(executor.TaskExecutor):
details = response.data['details']
request.notifier.notify(event_type, details)
elif response.state in (pr.FAILURE, pr.SUCCESS):
moved = request.transition_and_log_error(response.state,
logger=LOG)
if moved:
# NOTE(imelnikov): request should not be in the
# cache when another thread can see its result and
# schedule another request with the same uuid; so
# we remove it, then set the result...
del self._requests_cache[request.uuid]
if request.transition_and_log_error(response.state,
logger=LOG):
with self._ongoing_requests_lock:
del self._ongoing_requests[request.uuid]
request.set_result(**response.data)
else:
LOG.warning("Unexpected response status '%s'",
response.state)
else:
LOG.debug("Request with id='%s' not found", task_uuid)
LOG.debug("Request with id='%s' not found", request_uuid)
@staticmethod
def _handle_expired_request(request):
"""Handle expired request.
"""Handle a expired request.
When request has expired it is removed from the requests cache and
the `RequestTimeout` exception is set as a request result.
When a request has expired it is removed from the ongoing requests
dictionary and a ``RequestTimeout`` exception is set as a
request result.
"""
if request.transition_and_log_error(pr.FAILURE, logger=LOG):
# Raise an exception (and then catch it) so we get a nice
# traceback that the request will get instead of it getting
# just an exception with no traceback...
try:
request_age = timeutils.delta_seconds(request.created_on,
timeutils.utcnow())
request_age = timeutils.now() - request.created_on
raise exc.RequestTimeout(
"Request '%s' has expired after waiting for %0.2f"
" seconds for it to transition out of (%s) states"
@ -142,51 +132,74 @@ class WorkerTaskExecutor(executor.TaskExecutor):
with misc.capture_failure() as failure:
LOG.debug(failure.exception_str)
request.set_result(failure)
return True
return False
def _on_wait(self):
"""This function is called cyclically between draining events."""
self._requests_cache.cleanup(self._handle_expired_request)
with self._ongoing_requests_lock:
ongoing_requests_uuids = set(six.iterkeys(self._ongoing_requests))
waiting_requests = {}
expired_requests = {}
for request_uuid in ongoing_requests_uuids:
try:
request = self._ongoing_requests[request_uuid]
except KeyError:
# Guess it got removed before we got to it...
pass
else:
if request.expired:
expired_requests[request_uuid] = request
elif request.state == pr.WAITING:
worker = self._finder.get_worker_for_task(request.task)
if worker is not None:
waiting_requests[request_uuid] = (request, worker)
if expired_requests:
with self._ongoing_requests_lock:
while expired_requests:
request_uuid, request = expired_requests.popitem()
if self._handle_expired_request(request):
del self._ongoing_requests[request_uuid]
if waiting_requests:
while waiting_requests:
request_uuid, (request, worker) = waiting_requests.popitem()
if request.transition_and_log_error(pr.PENDING, logger=LOG):
self._publish_request(request, worker)
def _submit_task(self, task, task_uuid, action, arguments,
progress_callback=None, **kwargs):
"""Submit task request to a worker."""
request = pr.Request(task, task_uuid, action, arguments,
self._transition_timeout, **kwargs)
# Register the callback, so that we can proxy the progress correctly.
if (progress_callback is not None and
request.notifier.can_be_registered(
task_atom.EVENT_UPDATE_PROGRESS)):
request.notifier.register(task_atom.EVENT_UPDATE_PROGRESS,
progress_callback)
request.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
request.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
cleaner = functools.partial(request.notifier.deregister,
task_atom.EVENT_UPDATE_PROGRESS,
EVENT_UPDATE_PROGRESS,
progress_callback)
request.result.add_done_callback(lambda fut: cleaner())
# Get task's worker and publish request if worker was found.
worker = self._finder.get_worker_for_task(task)
if worker is not None:
# NOTE(skudriashev): Make sure request is set to the PENDING state
# before putting it into the requests cache to prevent the notify
# processing thread get list of waiting requests and publish it
# before it is published here, so it wouldn't be published twice.
if request.transition_and_log_error(pr.PENDING, logger=LOG):
self._requests_cache[request.uuid] = request
with self._ongoing_requests_lock:
self._ongoing_requests[request.uuid] = request
self._publish_request(request, worker)
else:
LOG.debug("Delaying submission of '%s', no currently known"
" worker/s available to process it", request)
self._requests_cache[request.uuid] = request
with self._ongoing_requests_lock:
self._ongoing_requests[request.uuid] = request
return request.result
def _publish_request(self, request, worker):
"""Publish request to a given topic."""
LOG.debug("Submitting execution of '%s' to worker '%s' (expecting"
" response identified by reply_to=%s and"
" correlation_id=%s)", request, worker, self._uuid,
request.uuid)
" correlation_id=%s) - waited %0.3f seconds to"
" get published", request, worker, self._uuid,
request.uuid, timeutils.now() - request.created_on)
try:
self._proxy.publish(request, worker.topic,
reply_to=self._uuid,
@ -196,7 +209,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
LOG.critical("Failed to submit '%s' (transitioning it to"
" %s)", request, pr.FAILURE, exc_info=True)
if request.transition_and_log_error(pr.FAILURE, logger=LOG):
del self._requests_cache[request.uuid]
with self._ongoing_requests_lock:
del self._ongoing_requests[request.uuid]
request.set_result(failure)
def execute_task(self, task, task_uuid, arguments,
@ -229,5 +243,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
def stop(self):
"""Stops proxy thread and associated topic notification thread."""
self._helpers.stop()
self._requests_cache.clear(self._handle_expired_request)
with self._ongoing_requests_lock:
while self._ongoing_requests:
_request_uuid, request = self._ongoing_requests.popitem()
self._handle_expired_request(request)
self._finder.clear()

View File

@ -262,7 +262,7 @@ class Request(Message):
self._watch = timeutils.StopWatch(duration=timeout).start()
self._state = WAITING
self._lock = threading.Lock()
self._created_on = timeutils.utcnow()
self._created_on = timeutils.now()
self._result = futurist.Future()
self._result.atom = task
self._notifier = task.notifier

View File

@ -28,27 +28,11 @@ import six
from taskflow.engines.worker_based import dispatcher
from taskflow.engines.worker_based import protocol as pr
from taskflow import logging
from taskflow.types import cache as base
from taskflow.types import notifier
from taskflow.utils import kombu_utils as ku
LOG = logging.getLogger(__name__)
class RequestsCache(base.ExpiringCache):
"""Represents a thread-safe requests cache."""
def get_waiting_requests(self, worker):
"""Get list of waiting requests that the given worker can satisfy."""
waiting_requests = []
with self._lock:
for request in six.itervalues(self._data):
if request.state == pr.WAITING \
and worker.performs(request.task):
waiting_requests.append(request)
return waiting_requests
# TODO(harlowja): this needs to be made better, once
# https://blueprints.launchpad.net/taskflow/+spec/wbe-worker-info is finally
# implemented we can go about using that instead.
@ -101,12 +85,8 @@ class TopicWorker(object):
class WorkerFinder(object):
"""Base class for worker finders..."""
#: Event type emitted when a new worker arrives.
WORKER_ARRIVED = 'worker_arrived'
def __init__(self):
self._cond = threading.Condition()
self.notifier = notifier.RestrictedNotifier([self.WORKER_ARRIVED])
@abc.abstractmethod
def _total_workers(self):
@ -219,8 +199,6 @@ class ProxyWorkerFinder(WorkerFinder):
LOG.debug("Updated worker '%s' (%s total workers are"
" currently known)", worker, self._total_workers())
self._cond.notify_all()
if new_or_updated:
self.notifier.notify(self.WORKER_ARRIVED, {'worker': worker})
def clear(self):
with self._cond:

View File

@ -17,9 +17,6 @@
import threading
import time
from oslo_utils import fixture
from oslo_utils import timeutils
from taskflow.engines.worker_based import executor
from taskflow.engines.worker_based import protocol as pr
from taskflow import task as task_atom
@ -56,6 +53,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop
self.request_inst_mock.uuid = self.task_uuid
self.request_inst_mock.expired = False
self.request_inst_mock.created_on = 0
self.request_inst_mock.task_cls = self.task.name
self.message_mock = mock.MagicMock(name='message')
self.message_mock.properties = {'correlation_id': self.task_uuid,
@ -96,7 +94,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_on_message_response_state_running(self):
response = pr.Response(pr.RUNNING)
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
expected_calls = [
@ -109,7 +107,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
event_type=task_atom.EVENT_UPDATE_PROGRESS,
details={'progress': 1.0})
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
expected_calls = [
@ -123,10 +121,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
failure_dict = a_failure.to_dict()
response = pr.Response(pr.FAILURE, result=failure_dict)
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual(0, len(ex._requests_cache))
self.assertEqual(0, len(ex._ongoing_requests))
expected_calls = [
mock.call.transition_and_log_error(pr.FAILURE, logger=mock.ANY),
mock.call.set_result(result=test_utils.FailureMatcher(a_failure))
@ -137,7 +135,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
response = pr.Response(pr.SUCCESS, result=self.task_result,
event='executed')
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
expected_calls = [
@ -149,7 +147,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_on_message_response_unknown_state(self):
response = pr.Response(state='<unknown>')
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual([], self.request_inst_mock.mock_calls)
@ -158,7 +156,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.message_mock.properties['correlation_id'] = '<unknown>'
response = pr.Response(pr.RUNNING)
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual([], self.request_inst_mock.mock_calls)
@ -167,48 +165,32 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.message_mock.properties = {'type': pr.RESPONSE}
response = pr.Response(pr.RUNNING)
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual([], self.request_inst_mock.mock_calls)
def test_on_wait_task_not_expired(self):
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
self.assertEqual(1, len(ex._requests_cache))
self.assertEqual(1, len(ex._ongoing_requests))
ex._on_wait()
self.assertEqual(1, len(ex._requests_cache))
self.assertEqual(1, len(ex._ongoing_requests))
def test_on_wait_task_expired(self):
now = timeutils.utcnow()
f = self.useFixture(fixture.TimeFixture(override_time=now))
@mock.patch('oslo_utils.timeutils.now')
def test_on_wait_task_expired(self, mock_now):
mock_now.side_effect = [0, 120]
self.request_inst_mock.expired = True
self.request_inst_mock.created_on = now
self.request_inst_mock.created_on = 0
f.advance_time_seconds(120)
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._ongoing_requests[self.task_uuid] = self.request_inst_mock
self.assertEqual(1, len(ex._ongoing_requests))
self.assertEqual(1, len(ex._requests_cache))
ex._on_wait()
self.assertEqual(0, len(ex._requests_cache))
def test_remove_task_non_existent(self):
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
self.assertEqual(1, len(ex._requests_cache))
del ex._requests_cache[self.task_uuid]
self.assertEqual(0, len(ex._requests_cache))
# delete non-existent
try:
del ex._requests_cache[self.task_uuid]
except KeyError:
pass
self.assertEqual(0, len(ex._requests_cache))
self.assertEqual(0, len(ex._ongoing_requests))
def test_execute_task(self):
ex = self.executor()

View File

@ -16,63 +16,12 @@
from oslo_utils import reflection
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import types as worker_types
from taskflow import test
from taskflow.test import mock
from taskflow.tests import utils
class TestRequestCache(test.TestCase):
def setUp(self):
super(TestRequestCache, self).setUp()
self.task = utils.DummyTask()
self.task_uuid = 'task-uuid'
self.task_action = 'execute'
self.task_args = {'a': 'a'}
self.timeout = 60
def request(self, **kwargs):
request_kwargs = dict(task=self.task,
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=self.timeout)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs)
@mock.patch('oslo_utils.timeutils.now')
def test_requests_cache_expiry(self, now):
# Mock out the calls the underlying objects will soon use to return
# times that we can control more easily...
overrides = [
0,
1,
self.timeout + 1,
]
now.side_effect = overrides
cache = worker_types.RequestsCache()
cache[self.task_uuid] = self.request()
cache.cleanup()
self.assertEqual(1, len(cache))
cache.cleanup()
self.assertEqual(0, len(cache))
def test_requests_cache_match(self):
cache = worker_types.RequestsCache()
cache[self.task_uuid] = self.request()
cache['task-uuid-2'] = self.request(task=utils.NastyTask(),
uuid='task-uuid-2')
worker = worker_types.TopicWorker("dummy-topic", [utils.DummyTask],
identity="dummy")
matches = cache.get_waiting_requests(worker)
self.assertEqual(1, len(matches))
self.assertEqual(2, len(cache))
class TestTopicWorker(test.TestCase):
def test_topic_worker(self):
worker = worker_types.TopicWorker("dummy-topic",

View File

@ -1,85 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import threading
from oslo_utils import reflection
import six
class ExpiringCache(object):
"""Represents a thread-safe time-based expiring cache.
NOTE(harlowja): the values in this cache must have a expired attribute that
can be used to determine if the key and associated value has expired or if
it has not.
"""
def __init__(self):
self._data = {}
self._lock = threading.Lock()
def __setitem__(self, key, value):
"""Set a value in the cache."""
with self._lock:
self._data[key] = value
def __len__(self):
"""Returns how many items are in this cache."""
return len(self._data)
def get(self, key, default=None):
"""Retrieve a value from the cache (returns default if not found)."""
return self._data.get(key, default)
def __getitem__(self, key):
"""Retrieve a value from the cache."""
return self._data[key]
def __delitem__(self, key):
"""Delete a key & value from the cache."""
with self._lock:
del self._data[key]
def clear(self, on_cleared_callback=None):
"""Removes all keys & values from the cache."""
cleared_items = []
with self._lock:
if on_cleared_callback is not None:
cleared_items.extend(six.iteritems(self._data))
self._data.clear()
if on_cleared_callback is not None:
arg_c = len(reflection.get_callable_args(on_cleared_callback))
for (k, v) in cleared_items:
if arg_c == 2:
on_cleared_callback(k, v)
else:
on_cleared_callback(v)
def cleanup(self, on_expired_callback=None):
"""Delete out-dated keys & values from the cache."""
with self._lock:
expired_values = [(k, v) for k, v in six.iteritems(self._data)
if v.expired]
for (k, _v) in expired_values:
del self._data[k]
if on_expired_callback is not None:
arg_c = len(reflection.get_callable_args(on_expired_callback))
for (k, v) in expired_values:
if arg_c == 2:
on_expired_callback(k, v)
else:
on_expired_callback(v)