Tidy up the WBE cache (now WBE types) module

Instead of using the expiring cache type as a means to store
worker information, avoid that type entirely (worker information
does not support expiry in the first place) and use a worker
container and a worker object that we can later extend as needed.

Also add clear methods to the cache type that will be used when
the WBE executor stops. This ensures we clear out the worker
information and any unfinished requests.

Change-Id: I6a520376eff1e8a6edcef0a59f2d8b9c0eb15752
Author: Joshua Harlow
Date: 2014-06-26 16:15:05 -07:00
Committed by: Joshua Harlow
Parent: 292adc5a62
Commit: 45ef595fde
7 changed files with 404 additions and 122 deletions
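At a glance, the change replaces the topic -> tasks mapping with explicit worker objects. A minimal sketch of how the new container is meant to be used (names come from the taskflow.engines.worker_based.types module added below; DummyTask is the existing test utility task):

```python
from taskflow.engines.worker_based import types as wt
from taskflow.tests import utils

workers = wt.TopicWorkers()
# A notification said this topic can run these tasks; record a worker for it.
worker = workers.add('worker-topic', [utils.DummyTask])
# Later, look up a worker able to run a given task (returns None if unknown).
found = workers.get_worker_for_task(utils.DummyTask)
assert found is not None and found.topic == 'worker-topic'
# On executor stop, forget all known workers.
workers.clear()
```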

taskflow/engines/worker_based/cache.py (deleted)

@@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import random

import six

from taskflow.engines.worker_based import protocol as pr
from taskflow.types import cache as base


class RequestsCache(base.ExpiringCache):
    """Represents a thread-safe requests cache."""

    def get_waiting_requests(self, tasks):
        """Get list of waiting requests by tasks."""
        waiting_requests = []
        with self._lock:
            for request in six.itervalues(self._data):
                if request.state == pr.WAITING and request.task_cls in tasks:
                    waiting_requests.append(request)
        return waiting_requests


class WorkersCache(base.ExpiringCache):
    """Represents a thread-safe workers cache."""

    def get_topic_by_task(self, task):
        """Get topic for a given task."""
        available_topics = []
        with self._lock:
            for topic, tasks in six.iteritems(self._data):
                if task in tasks:
                    available_topics.append(topic)
        return random.choice(available_topics) if available_topics else None

taskflow/engines/worker_based/executor.py

@@ -15,15 +15,13 @@
# under the License.

import functools
import threading

-from oslo_utils import reflection
-from oslo_utils import timeutils
from taskflow.engines.action_engine import executor
-from taskflow.engines.worker_based import cache
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
+from taskflow.engines.worker_based import types as wt
from taskflow import exceptions as exc
from taskflow import logging
from taskflow import task as task_atom
@@ -34,35 +32,6 @@ from taskflow.utils import threading_utils as tu
LOG = logging.getLogger(__name__)
-class PeriodicWorker(object):
-    """Calls a set of functions when activated periodically.
-
-    NOTE(harlowja): the provided timeout object determines the periodicity.
-    """
-
-    def __init__(self, timeout, functors):
-        self._timeout = timeout
-        self._functors = []
-        for f in functors:
-            self._functors.append((f, reflection.get_callable_name(f)))
-
-    def start(self):
-        while not self._timeout.is_stopped():
-            for (f, f_name) in self._functors:
-                LOG.debug("Calling periodic function '%s'", f_name)
-                try:
-                    f()
-                except Exception:
-                    LOG.warn("Failed to call periodic function '%s'", f_name,
-                             exc_info=True)
-            self._timeout.wait()
-
-    def stop(self):
-        self._timeout.interrupt()
-
-    def reset(self):
-        self._timeout.reset()


class WorkerTaskExecutor(executor.TaskExecutor):
    """Executes tasks on remote workers."""
@@ -72,10 +41,9 @@ class WorkerTaskExecutor(executor.TaskExecutor):
                 retry_options=None):
        self._uuid = uuid
        self._topics = topics
-        self._requests_cache = cache.RequestsCache()
+        self._requests_cache = wt.RequestsCache()
+        self._workers = wt.TopicWorkers()
        self._transition_timeout = transition_timeout
-        self._workers_cache = cache.WorkersCache()
-        self._workers_arrival = threading.Condition()
        type_handlers = {
            pr.NOTIFY: [
                self._process_notify,
@@ -92,7 +60,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
            transport_options=transport_options,
            retry_options=retry_options)
        self._proxy_thread = None
-        self._periodic = PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
+        self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
                                        [self._notify_topics])
        self._periodic_thread = None
@@ -104,16 +72,15 @@ class WorkerTaskExecutor(executor.TaskExecutor):
        tasks = notify['tasks']
        # Add worker info to the cache
-        LOG.debug("Received that tasks %s can be processed by topic '%s'",
-                  tasks, topic)
-        with self._workers_arrival:
-            self._workers_cache[topic] = tasks
-            self._workers_arrival.notify_all()
+        worker = self._workers.add(topic, tasks)
+        LOG.debug("Received notification about worker '%s' (%s"
+                  " total workers are currently known)", worker,
+                  len(self._workers))
        # Publish waiting requests
-        for request in self._requests_cache.get_waiting_requests(tasks):
+        for request in self._requests_cache.get_waiting_requests(worker):
            if request.transition_and_log_error(pr.PENDING, logger=LOG):
-                self._publish_request(request, topic)
+                self._publish_request(request, worker)

    def _process_response(self, response, message):
        """Process response from remote side."""
@@ -147,7 +114,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
                del self._requests_cache[request.uuid]
                request.set_result(**response.data)
            else:
-                LOG.warning("Unexpected response status: '%s'",
+                LOG.warning("Unexpected response status '%s'",
                            response.state)
        else:
            LOG.debug("Request with id='%s' not found", task_uuid)
@@ -196,16 +163,16 @@ class WorkerTaskExecutor(executor.TaskExecutor):
                                          progress_callback)
        request.result.add_done_callback(lambda fut: cleaner())
-        # Get task's topic and publish request if topic was found.
-        topic = self._workers_cache.get_topic_by_task(request.task_cls)
-        if topic is not None:
+        # Get task's worker and publish request if worker was found.
+        worker = self._workers.get_worker_for_task(task)
+        if worker is not None:
            # NOTE(skudriashev): Make sure the request is set to the PENDING
            # state before putting it into the requests cache, to prevent the
            # notify processing thread from fetching the waiting requests and
            # publishing this request before it is published here (which
            # would publish it twice).
            if request.transition_and_log_error(pr.PENDING, logger=LOG):
                self._requests_cache[request.uuid] = request
-                self._publish_request(request, topic)
+                self._publish_request(request, worker)
        else:
            LOG.debug("Delaying submission of '%s', no currently known"
                      " worker/s available to process it", request)
@@ -213,14 +180,14 @@ class WorkerTaskExecutor(executor.TaskExecutor):
        return request.result

-    def _publish_request(self, request, topic):
+    def _publish_request(self, request, worker):
        """Publish request to a given topic."""
-        LOG.debug("Submitting execution of '%s' to topic '%s' (expecting"
+        LOG.debug("Submitting execution of '%s' to worker '%s' (expecting"
                  " response identified by reply_to=%s and"
-                  " correlation_id=%s)", request, topic, self._uuid,
+                  " correlation_id=%s)", request, worker, self._uuid,
                  request.uuid)
        try:
-            self._proxy.publish(request, topic,
+            self._proxy.publish(request, worker.topic,
                                reply_to=self._uuid,
                                correlation_id=request.uuid)
        except Exception:
@@ -255,20 +222,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
        return how many workers are still needed, otherwise it will
        return zero.
        """
-        if workers <= 0:
-            raise ValueError("Worker amount must be greater than zero")
-        w = None
-        if timeout is not None:
-            w = tt.StopWatch(timeout).start()
-        with self._workers_arrival:
-            while len(self._workers_cache) < workers:
-                if w is not None and w.expired():
-                    return workers - len(self._workers_cache)
-                timeout = None
-                if w is not None:
-                    timeout = w.leftover()
-                self._workers_arrival.wait(timeout)
-        return 0
+        return self._workers.wait_for_workers(workers=workers,
+                                              timeout=timeout)

    def start(self):
        """Starts proxy thread and associated topic notification thread."""
@@ -291,3 +245,5 @@ class WorkerTaskExecutor(executor.TaskExecutor):
        self._proxy.stop()
        self._proxy_thread.join()
        self._proxy_thread = None
+        self._requests_cache.clear(self._handle_expired_request)
+        self._workers.clear()

taskflow/engines/worker_based/protocol.py

@@ -221,7 +221,6 @@ class Request(Message):
    def __init__(self, task, uuid, action, arguments, timeout, **kwargs):
        self._task = task
-        self._task_cls = reflection.get_class_name(task)
        self._uuid = uuid
        self._action = action
        self._event = ACTION_TO_EVENT[action]
@@ -248,8 +247,8 @@ class Request(Message):
        return self._uuid

    @property
-    def task_cls(self):
-        return self._task_cls
+    def task(self):
+        return self._task

    @property
    def state(self):
@@ -281,9 +280,13 @@ class Request(Message):
        convert all `failure.Failure` objects into dictionaries (which will
        then be reconstituted by the receiver).
        """
-        request = dict(task_cls=self._task_cls, task_name=self._task.name,
-                       task_version=self._task.version, action=self._action,
-                       arguments=self._arguments)
+        request = {
+            'task_cls': reflection.get_class_name(self._task),
+            'task_name': self._task.name,
+            'task_version': self._task.version,
+            'action': self._action,
+            'arguments': self._arguments,
+        }
        if 'result' in self._kwargs:
            result = self._kwargs['result']
            if isinstance(result, ft.Failure):
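With this change the wire form is derived from the live task object instead of a stored class-name string. A hedged example of what the resulting dictionary might look like (all values here are hypothetical):

```python
# Hypothetical serialized request, mirroring the keys built above.
request = {
    'task_cls': 'taskflow.tests.utils.DummyTask',
    'task_name': 'dummy',
    'task_version': (1, 0),
    'action': 'execute',
    'arguments': {'a': 'a'},
}
```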

taskflow/engines/worker_based/types.py (new)

@@ -0,0 +1,217 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import itertools
import logging
import random
import threading

from oslo.utils import reflection
import six

from taskflow.engines.worker_based import protocol as pr
from taskflow.types import cache as base
from taskflow.types import timing as tt

LOG = logging.getLogger(__name__)


class RequestsCache(base.ExpiringCache):
    """Represents a thread-safe requests cache."""

    def get_waiting_requests(self, worker):
        """Get list of waiting requests that the given worker can satisfy."""
        waiting_requests = []
        with self._lock:
            for request in six.itervalues(self._data):
                if (request.state == pr.WAITING
                        and worker.performs(request.task)):
                    waiting_requests.append(request)
        return waiting_requests
# TODO(harlowja): this needs to be made better, once
# https://blueprints.launchpad.net/taskflow/+spec/wbe-worker-info is finally
# implemented we can go about using that instead.
class TopicWorker(object):
    """A (read-only) worker and its relevant information + useful methods."""

    _NO_IDENTITY = object()

    def __init__(self, topic, tasks, identity=_NO_IDENTITY):
        self.tasks = []
        for task in tasks:
            if not isinstance(task, six.string_types):
                task = reflection.get_class_name(task)
            self.tasks.append(task)
        self.topic = topic
        self.identity = identity

    def performs(self, task):
        if not isinstance(task, six.string_types):
            task = reflection.get_class_name(task)
        return task in self.tasks

    def __eq__(self, other):
        if not isinstance(other, TopicWorker):
            return NotImplemented
        if len(other.tasks) != len(self.tasks):
            return False
        if other.topic != self.topic:
            return False
        for task in other.tasks:
            if not self.performs(task):
                return False
        # If either identity is _NO_IDENTITY, then allow the match...
        if self._NO_IDENTITY in (self.identity, other.identity):
            return True
        else:
            return other.identity == self.identity

    def __repr__(self):
        r = reflection.get_class_name(self, fully_qualified=False)
        if self.identity is not self._NO_IDENTITY:
            r += "(identity=%s, tasks=%s, topic=%s)" % (self.identity,
                                                        self.tasks,
                                                        self.topic)
        else:
            r += "(identity=*, tasks=%s, topic=%s)" % (self.tasks, self.topic)
        return r
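A short demonstration of the matching and equality semantics above (DummyTask/NastyTask are the existing test utility tasks):

```python
from oslo.utils import reflection

from taskflow.engines.worker_based import types as wt
from taskflow.tests import utils

w = wt.TopicWorker('dummy-topic', [utils.DummyTask], identity='w-1')
# performs() accepts either a task class or its fully-qualified name.
assert w.performs(utils.DummyTask)
assert w.performs(reflection.get_class_name(utils.DummyTask))
assert not w.performs(utils.NastyTask)
# A worker without an identity acts as a wildcard in comparisons, so it
# compares equal to an identified worker with the same topic and tasks.
anon = wt.TopicWorker('dummy-topic', [utils.DummyTask])
assert anon == w
```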
class TopicWorkers(object):
    """A collection of topic based workers."""

    @staticmethod
    def _match_worker(task, available_workers):
        """Select a worker (from one or more workers) to best perform a task.

        NOTE(harlowja): this method is activated when one or more potential
        workers that can perform a task have been located; the arguments
        provided are those potential workers and the task being requested,
        and the result should be one of those workers, chosen by whatever
        best-fit algorithm is possible (or random at the least).
        """
        if len(available_workers) == 1:
            return available_workers[0]
        else:
            return random.choice(available_workers)
    def __init__(self):
        self._workers = {}
        self._cond = threading.Condition()
        # Used to name workers with more useful identities...
        self._counter = itertools.count()

    def __len__(self):
        return len(self._workers)

    def _next_worker(self, topic, tasks, temporary=False):
        if not temporary:
            return TopicWorker(topic, tasks,
                               identity=six.next(self._counter))
        else:
            return TopicWorker(topic, tasks)
    def add(self, topic, tasks):
        """Adds/updates a worker for the topic for the given tasks."""
        with self._cond:
            try:
                worker = self._workers[topic]
                # Check if we already have an equivalent worker, if so just
                # return it...
                if worker == self._next_worker(topic, tasks, temporary=True):
                    return worker
                # This *fall through* is intentional: if someone is holding
                # an existing worker object we create a new one rather than
                # mutating it, so the existing object is not affected
                # (worker objects are supposed to be immutable).
            except KeyError:
                pass
            worker = self._next_worker(topic, tasks)
            self._workers[topic] = worker
            self._cond.notify_all()
        return worker
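The fall-through above makes add() effectively idempotent for an unchanged worker, while an updated task list mints a fresh object so holders of the old one keep an unmodified snapshot; a sketch:

```python
from taskflow.engines.worker_based import types as wt
from taskflow.tests import utils

workers = wt.TopicWorkers()
w1 = workers.add('topic-a', [utils.DummyTask])
# Same topic, same tasks: the already-registered worker object comes back.
assert workers.add('topic-a', [utils.DummyTask]) is w1
# Same topic, different tasks: a new (immutable) worker object replaces it.
w2 = workers.add('topic-a', [utils.DummyTask, utils.NastyTask])
assert w2 is not w1
```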
    def wait_for_workers(self, workers=1, timeout=None):
        """Waits until at least the given number of workers are ready.

        NOTE(harlowja): if a timeout is provided this function will wait
        until that timeout expires; if the desired number of workers has
        not been reached by then, the number of workers still needed is
        returned, otherwise zero is returned.
        """
        if workers <= 0:
            raise ValueError("Worker amount must be greater than zero")
        w = None
        if timeout is not None:
            w = tt.StopWatch(timeout).start()
        with self._cond:
            while len(self._workers) < workers:
                if w is not None and w.expired():
                    return max(0, workers - len(self._workers))
                timeout = None
                if w is not None:
                    timeout = w.leftover()
                self._cond.wait(timeout)
        return 0
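Usage follows the docstring: the call blocks until enough workers are known or the timeout runs out, returning how many are still missing (zero on success). A sketch with a worker that arrives late on another thread:

```python
import threading
import time

from taskflow.engines.worker_based import types as wt
from taskflow.tests import utils

workers = wt.TopicWorkers()

def arrive_later():
    time.sleep(0.1)
    workers.add('topic-a', [utils.DummyTask])  # wakes any waiters

t = threading.Thread(target=arrive_later)
t.start()
# Returns 0 once at least one worker is known; with a timeout shorter
# than the arrival it would return 1 (the number still needed).
assert workers.wait_for_workers(workers=1, timeout=5.0) == 0
t.join()
```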
    def get_worker_for_task(self, task):
        """Gets a worker that can perform a given task."""
        available_workers = []
        with self._cond:
            for worker in six.itervalues(self._workers):
                if worker.performs(task):
                    available_workers.append(worker)
        if available_workers:
            return self._match_worker(task, available_workers)
        else:
            return None

    def clear(self):
        with self._cond:
            self._workers.clear()
            self._cond.notify_all()
class PeriodicWorker(object):
    """Calls a set of functions when activated periodically.

    NOTE(harlowja): the provided timeout object determines the periodicity.
    """

    def __init__(self, timeout, functors):
        self._timeout = timeout
        self._functors = []
        for f in functors:
            self._functors.append((f, reflection.get_callable_name(f)))

    def start(self):
        while not self._timeout.is_stopped():
            for (f, f_name) in self._functors:
                LOG.debug("Calling periodic function '%s'", f_name)
                try:
                    f()
                except Exception:
                    LOG.warn("Failed to call periodic function '%s'", f_name,
                             exc_info=True)
            self._timeout.wait()

    def stop(self):
        self._timeout.interrupt()

    def reset(self):
        self._timeout.reset()
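A small usage sketch for the relocated PeriodicWorker (start() loops until its timeout object is interrupted, so it normally runs on its own thread; the Timeout type here is taskflow.types.timing.Timeout, as used by the test below):

```python
import threading

from taskflow.engines.worker_based import types as wt
from taskflow.types import timing as tt

calls = []
to = tt.Timeout(0.01)  # the period between activations, in seconds

def tick():
    calls.append(1)
    if len(calls) >= 3:
        to.interrupt()  # stops the start() loop

w = wt.PeriodicWorker(to, [tick])
t = threading.Thread(target=w.start)
t.start()
t.join()
assert len(calls) >= 3
```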

taskflow/tests/unit/worker_based/test_protocol.py

@@ -136,7 +136,7 @@ class TestProtocol(test.TestCase):
    def test_creation(self):
        request = self.request()
        self.assertEqual(request.uuid, self.task_uuid)
-        self.assertEqual(request.task_cls, self.task.name)
+        self.assertEqual(request.task, self.task)
        self.assertIsInstance(request.result, futures.Future)
        self.assertFalse(request.result.done())

taskflow/tests/unit/worker_based/test_types.py (new)

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import threading
import time

from oslo.utils import reflection
from oslo.utils import timeutils

from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import types as worker_types
from taskflow import test
from taskflow.tests import utils
from taskflow.types import latch
from taskflow.types import timing


class TestWorkerTypes(test.TestCase):

    def setUp(self):
        super(TestWorkerTypes, self).setUp()
        self.task = utils.DummyTask()
        self.task_uuid = 'task-uuid'
        self.task_action = 'execute'
        self.task_args = {'a': 'a'}
        self.timeout = 60

    def request(self, **kwargs):
        request_kwargs = dict(task=self.task,
                              uuid=self.task_uuid,
                              action=self.task_action,
                              arguments=self.task_args,
                              progress_callback=None,
                              timeout=self.timeout)
        request_kwargs.update(kwargs)
        return pr.Request(**request_kwargs)
    def test_requests_cache_expiry(self):
        # Mock out the calls the underlying objects will soon use to return
        # times that we can control more easily...
        now = timeutils.utcnow()
        overrides = [
            now,
            now,
            now + datetime.timedelta(seconds=1),
            now + datetime.timedelta(seconds=self.timeout + 1),
        ]
        timeutils.set_time_override(overrides)
        self.addCleanup(timeutils.clear_time_override)

        cache = worker_types.RequestsCache()
        cache[self.task_uuid] = self.request()
        cache.cleanup()
        self.assertEqual(1, len(cache))
        cache.cleanup()
        self.assertEqual(0, len(cache))

    def test_requests_cache_match(self):
        cache = worker_types.RequestsCache()
        cache[self.task_uuid] = self.request()
        cache['task-uuid-2'] = self.request(task=utils.NastyTask(),
                                            uuid='task-uuid-2')
        worker = worker_types.TopicWorker("dummy-topic", [utils.DummyTask],
                                          identity="dummy")
        matches = cache.get_waiting_requests(worker)
        self.assertEqual(1, len(matches))
        self.assertEqual(2, len(cache))
    def test_topic_worker(self):
        worker = worker_types.TopicWorker("dummy-topic",
                                          [utils.DummyTask],
                                          identity="dummy")
        self.assertTrue(worker.performs(utils.DummyTask))
        self.assertFalse(worker.performs(utils.NastyTask))
        self.assertEqual('dummy', worker.identity)
        self.assertEqual('dummy-topic', worker.topic)

    def test_single_topic_workers(self):
        workers = worker_types.TopicWorkers()
        w = workers.add('dummy-topic', [utils.DummyTask])
        self.assertIsNotNone(w)
        self.assertEqual(1, len(workers))
        w2 = workers.get_worker_for_task(utils.DummyTask)
        self.assertEqual(w.identity, w2.identity)

    def test_multi_same_topic_workers(self):
        workers = worker_types.TopicWorkers()
        w = workers.add('dummy-topic', [utils.DummyTask])
        self.assertIsNotNone(w)
        w2 = workers.add('dummy-topic-2', [utils.DummyTask])
        self.assertIsNotNone(w2)
        w3 = workers.get_worker_for_task(
            reflection.get_class_name(utils.DummyTask))
        self.assertIn(w3.identity, [w.identity, w2.identity])

    def test_multi_different_topic_workers(self):
        workers = worker_types.TopicWorkers()
        added = []
        added.append(workers.add('dummy-topic', [utils.DummyTask]))
        added.append(workers.add('dummy-topic-2', [utils.DummyTask]))
        added.append(workers.add('dummy-topic-3', [utils.NastyTask]))
        self.assertEqual(3, len(workers))
        w = workers.get_worker_for_task(utils.NastyTask)
        self.assertEqual(added[-1].identity, w.identity)
        w = workers.get_worker_for_task(utils.DummyTask)
        self.assertIn(w.identity, [w_a.identity for w_a in added[0:2]])
    def test_periodic_worker(self):
        barrier = latch.Latch(5)
        to = timing.Timeout(0.01)
        called_at = []

        def callee():
            barrier.countdown()
            if barrier.needed == 0:
                to.interrupt()
            called_at.append(time.time())

        w = worker_types.PeriodicWorker(to, [callee])
        t = threading.Thread(target=w.start)
        t.start()
        t.join()
        self.assertEqual(0, barrier.needed)
        self.assertEqual(5, len(called_at))
        self.assertTrue(to.is_stopped())

taskflow/types/cache.py

@@ -54,6 +54,21 @@ class ExpiringCache(object):
        with self._lock:
            del self._data[key]

+    def clear(self, on_cleared_callback=None):
+        """Removes all keys & values from the cache."""
+        cleared_items = []
+        with self._lock:
+            if on_cleared_callback is not None:
+                cleared_items.extend(six.iteritems(self._data))
+            self._data.clear()
+        if on_cleared_callback is not None:
+            arg_c = len(reflection.get_callable_args(on_cleared_callback))
+            for (k, v) in cleared_items:
+                if arg_c == 2:
+                    on_cleared_callback(k, v)
+                else:
+                    on_cleared_callback(v)
+
    def cleanup(self, on_expired_callback=None):
        """Delete out-dated keys & values from the cache."""
        with self._lock:
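The new clear() inspects the callback's arity so that existing one-argument expiry-style callbacks keep working while callers that want the key can pass a two-argument callable; a sketch against the mapping interface this cache already exposes:

```python
from taskflow.types import cache

c = cache.ExpiringCache()
c['a'] = 1
c['b'] = 2

dropped = []
# A one-argument callback receives only the value...
c.clear(lambda value: dropped.append(value))
assert sorted(dropped) == [1, 2]

c['a'] = 1
pairs = []
# ...while a two-argument callback receives (key, value).
c.clear(lambda key, value: pairs.append((key, value)))
assert pairs == [('a', 1)]
```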