Abstract out the worker finding from the WBE engine
To be able to easily plug-in future types of ways to get which topics (and tasks) workers exist on (and can perform) and to identify and keep this information up-to date refactor the functionality that currently does this using periodic messages into a finder type and a periodic function that exists on it (that will be periodically activated by an updated and improved periodic worker). Part of blueprint wbe-worker-info Change-Id: Ib3ae29758af3d244b4ac4624ac380caf88b159fd
This commit is contained in:

committed by
Joshua Harlow

parent
934e15a029
commit
19f9674877
@@ -44,6 +44,11 @@ Notifier
|
|||||||
|
|
||||||
.. automodule:: taskflow.types.notifier
|
.. automodule:: taskflow.types.notifier
|
||||||
|
|
||||||
|
Periodic
|
||||||
|
========
|
||||||
|
|
||||||
|
.. automodule:: taskflow.types.periodic
|
||||||
|
|
||||||
Table
|
Table
|
||||||
=====
|
=====
|
||||||
|
|
||||||
|
@@ -15,7 +15,6 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from kombu import exceptions as kombu_exc
|
from kombu import exceptions as kombu_exc
|
||||||
import six
|
|
||||||
|
|
||||||
from taskflow import exceptions as excp
|
from taskflow import exceptions as excp
|
||||||
from taskflow import logging
|
from taskflow import logging
|
||||||
@@ -27,14 +26,35 @@ LOG = logging.getLogger(__name__)
|
|||||||
class TypeDispatcher(object):
|
class TypeDispatcher(object):
|
||||||
"""Receives messages and dispatches to type specific handlers."""
|
"""Receives messages and dispatches to type specific handlers."""
|
||||||
|
|
||||||
def __init__(self, type_handlers):
|
def __init__(self, type_handlers=None, requeue_filters=None):
|
||||||
self._handlers = dict(type_handlers)
|
if type_handlers is not None:
|
||||||
self._requeue_filters = []
|
self._type_handlers = dict(type_handlers)
|
||||||
|
else:
|
||||||
|
self._type_handlers = {}
|
||||||
|
if requeue_filters is not None:
|
||||||
|
self._requeue_filters = list(requeue_filters)
|
||||||
|
else:
|
||||||
|
self._requeue_filters = []
|
||||||
|
|
||||||
def add_requeue_filter(self, callback):
|
@property
|
||||||
"""Add a callback that can *request* message requeuing.
|
def type_handlers(self):
|
||||||
|
"""Dictionary of message type -> callback to handle that message.
|
||||||
|
|
||||||
The callback will be activated before the message has been acked and
|
The callback(s) will be activated by looking for a message
|
||||||
|
property 'type' and locating a callback in this dictionary that maps
|
||||||
|
to that type; if one is found it is expected to be a callback that
|
||||||
|
accepts two positional parameters; the first being the message data
|
||||||
|
and the second being the message object. If a callback is not found
|
||||||
|
then the message is rejected and it will be up to the underlying
|
||||||
|
message transport to determine what this means/implies...
|
||||||
|
"""
|
||||||
|
return self._type_handlers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requeue_filters(self):
|
||||||
|
"""List of filters (callbacks) to request a message to be requeued.
|
||||||
|
|
||||||
|
The callback(s) will be activated before the message has been acked and
|
||||||
it can be used to instruct the dispatcher to requeue the message
|
it can be used to instruct the dispatcher to requeue the message
|
||||||
instead of processing it. The callback, when called, will be provided
|
instead of processing it. The callback, when called, will be provided
|
||||||
two positional parameters; the first being the message data and the
|
two positional parameters; the first being the message data and the
|
||||||
@@ -42,9 +62,7 @@ class TypeDispatcher(object):
|
|||||||
filter should return a truthy object if the message should be requeued
|
filter should return a truthy object if the message should be requeued
|
||||||
and a falsey object if it should not.
|
and a falsey object if it should not.
|
||||||
"""
|
"""
|
||||||
if not six.callable(callback):
|
return self._requeue_filters
|
||||||
raise ValueError("Requeue filter callback must be callable")
|
|
||||||
self._requeue_filters.append(callback)
|
|
||||||
|
|
||||||
def _collect_requeue_votes(self, data, message):
|
def _collect_requeue_votes(self, data, message):
|
||||||
# Returns how many of the filters asked for the message to be requeued.
|
# Returns how many of the filters asked for the message to be requeued.
|
||||||
@@ -74,7 +92,7 @@ class TypeDispatcher(object):
|
|||||||
LOG.debug("Message '%s' was requeued.", ku.DelayedPretty(message))
|
LOG.debug("Message '%s' was requeued.", ku.DelayedPretty(message))
|
||||||
|
|
||||||
def _process_message(self, data, message, message_type):
|
def _process_message(self, data, message, message_type):
|
||||||
handler = self._handlers.get(message_type)
|
handler = self._type_handlers.get(message_type)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
message.reject_log_error(logger=LOG,
|
message.reject_log_error(logger=LOG,
|
||||||
errors=(kombu_exc.MessageStateError,))
|
errors=(kombu_exc.MessageStateError,))
|
||||||
|
@@ -25,7 +25,7 @@ from taskflow.engines.worker_based import types as wt
|
|||||||
from taskflow import exceptions as exc
|
from taskflow import exceptions as exc
|
||||||
from taskflow import logging
|
from taskflow import logging
|
||||||
from taskflow import task as task_atom
|
from taskflow import task as task_atom
|
||||||
from taskflow.types import timing as tt
|
from taskflow.types import periodic
|
||||||
from taskflow.utils import kombu_utils as ku
|
from taskflow.utils import kombu_utils as ku
|
||||||
from taskflow.utils import misc
|
from taskflow.utils import misc
|
||||||
from taskflow.utils import threading_utils as tu
|
from taskflow.utils import threading_utils as tu
|
||||||
@@ -41,51 +41,41 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
|||||||
url=None, transport=None, transport_options=None,
|
url=None, transport=None, transport_options=None,
|
||||||
retry_options=None):
|
retry_options=None):
|
||||||
self._uuid = uuid
|
self._uuid = uuid
|
||||||
self._topics = topics
|
|
||||||
self._requests_cache = wt.RequestsCache()
|
self._requests_cache = wt.RequestsCache()
|
||||||
self._workers = wt.TopicWorkers()
|
|
||||||
self._transition_timeout = transition_timeout
|
self._transition_timeout = transition_timeout
|
||||||
type_handlers = {
|
type_handlers = {
|
||||||
pr.NOTIFY: [
|
|
||||||
self._process_notify,
|
|
||||||
functools.partial(pr.Notify.validate, response=True),
|
|
||||||
],
|
|
||||||
pr.RESPONSE: [
|
pr.RESPONSE: [
|
||||||
self._process_response,
|
self._process_response,
|
||||||
pr.Response.validate,
|
pr.Response.validate,
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
self._proxy = proxy.Proxy(uuid, exchange, type_handlers,
|
self._proxy = proxy.Proxy(uuid, exchange,
|
||||||
|
type_handlers=type_handlers,
|
||||||
on_wait=self._on_wait, url=url,
|
on_wait=self._on_wait, url=url,
|
||||||
transport=transport,
|
transport=transport,
|
||||||
transport_options=transport_options,
|
transport_options=transport_options,
|
||||||
retry_options=retry_options)
|
retry_options=retry_options)
|
||||||
self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
|
# NOTE(harlowja): This is the most simplest finder impl. that
|
||||||
[self._notify_topics])
|
# doesn't have external dependencies (outside of what this engine
|
||||||
|
# already requires); it though does create periodic 'polling' traffic
|
||||||
|
# to workers to 'learn' of the tasks they can perform (and requires
|
||||||
|
# pre-existing knowledge of the topics those workers are on to gather
|
||||||
|
# and update this information).
|
||||||
|
self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics)
|
||||||
|
self._finder.on_worker = self._on_worker
|
||||||
self._helpers = tu.ThreadBundle()
|
self._helpers = tu.ThreadBundle()
|
||||||
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
|
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
|
||||||
after_start=lambda t: self._proxy.wait(),
|
after_start=lambda t: self._proxy.wait(),
|
||||||
before_join=lambda t: self._proxy.stop())
|
before_join=lambda t: self._proxy.stop())
|
||||||
self._helpers.bind(lambda: tu.daemon_thread(self._periodic.start),
|
p_worker = periodic.PeriodicWorker.create([self._finder])
|
||||||
before_join=lambda t: self._periodic.stop(),
|
if p_worker:
|
||||||
after_join=lambda t: self._periodic.reset(),
|
self._helpers.bind(lambda: tu.daemon_thread(p_worker.start),
|
||||||
before_start=lambda t: self._periodic.reset())
|
before_join=lambda t: p_worker.stop(),
|
||||||
|
after_join=lambda t: p_worker.reset(),
|
||||||
|
before_start=lambda t: p_worker.reset())
|
||||||
|
|
||||||
def _process_notify(self, notify, message):
|
def _on_worker(self, worker):
|
||||||
"""Process notify message from remote side."""
|
"""Process new worker that has arrived (and fire off any work)."""
|
||||||
LOG.debug("Started processing notify message '%s'",
|
|
||||||
ku.DelayedPretty(message))
|
|
||||||
|
|
||||||
topic = notify['topic']
|
|
||||||
tasks = notify['tasks']
|
|
||||||
|
|
||||||
# Add worker info to the cache
|
|
||||||
worker = self._workers.add(topic, tasks)
|
|
||||||
LOG.debug("Received notification about worker '%s' (%s"
|
|
||||||
" total workers are currently known)", worker,
|
|
||||||
len(self._workers))
|
|
||||||
|
|
||||||
# Publish waiting requests
|
|
||||||
for request in self._requests_cache.get_waiting_requests(worker):
|
for request in self._requests_cache.get_waiting_requests(worker):
|
||||||
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
||||||
self._publish_request(request, worker)
|
self._publish_request(request, worker)
|
||||||
@@ -174,7 +164,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
|||||||
request.result.add_done_callback(lambda fut: cleaner())
|
request.result.add_done_callback(lambda fut: cleaner())
|
||||||
|
|
||||||
# Get task's worker and publish request if worker was found.
|
# Get task's worker and publish request if worker was found.
|
||||||
worker = self._workers.get_worker_for_task(task)
|
worker = self._finder.get_worker_for_task(task)
|
||||||
if worker is not None:
|
if worker is not None:
|
||||||
# NOTE(skudriashev): Make sure request is set to the PENDING state
|
# NOTE(skudriashev): Make sure request is set to the PENDING state
|
||||||
# before putting it into the requests cache to prevent the notify
|
# before putting it into the requests cache to prevent the notify
|
||||||
@@ -208,10 +198,6 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
|||||||
del self._requests_cache[request.uuid]
|
del self._requests_cache[request.uuid]
|
||||||
request.set_result(failure)
|
request.set_result(failure)
|
||||||
|
|
||||||
def _notify_topics(self):
|
|
||||||
"""Cyclically called to publish notify message to each topic."""
|
|
||||||
self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid)
|
|
||||||
|
|
||||||
def execute_task(self, task, task_uuid, arguments,
|
def execute_task(self, task, task_uuid, arguments,
|
||||||
progress_callback=None):
|
progress_callback=None):
|
||||||
return self._submit_task(task, task_uuid, pr.EXECUTE, arguments,
|
return self._submit_task(task, task_uuid, pr.EXECUTE, arguments,
|
||||||
@@ -232,7 +218,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
|||||||
return how many workers are still needed, otherwise it will
|
return how many workers are still needed, otherwise it will
|
||||||
return zero.
|
return zero.
|
||||||
"""
|
"""
|
||||||
return self._workers.wait_for_workers(workers=workers, timeout=timeout)
|
return self._finder.wait_for_workers(workers=workers,
|
||||||
|
timeout=timeout)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Starts proxy thread and associated topic notification thread."""
|
"""Starts proxy thread and associated topic notification thread."""
|
||||||
@@ -242,4 +229,4 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
|||||||
"""Stops proxy thread and associated topic notification thread."""
|
"""Stops proxy thread and associated topic notification thread."""
|
||||||
self._helpers.stop()
|
self._helpers.stop()
|
||||||
self._requests_cache.clear(self._handle_expired_request)
|
self._requests_cache.clear(self._handle_expired_request)
|
||||||
self._workers.clear()
|
self._finder.clear()
|
||||||
|
@@ -68,19 +68,19 @@ class Proxy(object):
|
|||||||
# value is valid...
|
# value is valid...
|
||||||
_RETRY_INT_OPTS = frozenset(['max_retries'])
|
_RETRY_INT_OPTS = frozenset(['max_retries'])
|
||||||
|
|
||||||
def __init__(self, topic, exchange, type_handlers,
|
def __init__(self, topic, exchange,
|
||||||
on_wait=None, url=None,
|
type_handlers=None, on_wait=None, url=None,
|
||||||
transport=None, transport_options=None,
|
transport=None, transport_options=None,
|
||||||
retry_options=None):
|
retry_options=None):
|
||||||
self._topic = topic
|
self._topic = topic
|
||||||
self._exchange_name = exchange
|
self._exchange_name = exchange
|
||||||
self._on_wait = on_wait
|
self._on_wait = on_wait
|
||||||
self._running = threading_utils.Event()
|
self._running = threading_utils.Event()
|
||||||
self._dispatcher = dispatcher.TypeDispatcher(type_handlers)
|
self._dispatcher = dispatcher.TypeDispatcher(
|
||||||
self._dispatcher.add_requeue_filter(
|
|
||||||
# NOTE(skudriashev): Process all incoming messages only if proxy is
|
# NOTE(skudriashev): Process all incoming messages only if proxy is
|
||||||
# running, otherwise requeue them.
|
# running, otherwise requeue them.
|
||||||
lambda data, message: not self.is_running)
|
requeue_filters=[lambda data, message: not self.is_running],
|
||||||
|
type_handlers=type_handlers)
|
||||||
|
|
||||||
ensure_options = self.DEFAULT_RETRY_OPTIONS.copy()
|
ensure_options = self.DEFAULT_RETRY_OPTIONS.copy()
|
||||||
if retry_options is not None:
|
if retry_options is not None:
|
||||||
@@ -112,11 +112,16 @@ class Proxy(object):
|
|||||||
|
|
||||||
# create exchange
|
# create exchange
|
||||||
self._exchange = kombu.Exchange(name=self._exchange_name,
|
self._exchange = kombu.Exchange(name=self._exchange_name,
|
||||||
durable=False,
|
durable=False, auto_delete=True)
|
||||||
auto_delete=True)
|
|
||||||
|
@property
|
||||||
|
def dispatcher(self):
|
||||||
|
"""Dispatcher internally used to dispatch message(s) that match."""
|
||||||
|
return self._dispatcher
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connection_details(self):
|
def connection_details(self):
|
||||||
|
"""Details about the connection (read-only)."""
|
||||||
# The kombu drivers seem to use 'N/A' when they don't have a version...
|
# The kombu drivers seem to use 'N/A' when they don't have a version...
|
||||||
driver_version = self._conn.transport.driver_version()
|
driver_version = self._conn.transport.driver_version()
|
||||||
if driver_version and driver_version.lower() == 'n/a':
|
if driver_version and driver_version.lower() == 'n/a':
|
||||||
|
@@ -59,7 +59,8 @@ class Server(object):
|
|||||||
pr.Request.validate,
|
pr.Request.validate,
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
self._proxy = proxy.Proxy(topic, exchange, type_handlers,
|
self._proxy = proxy.Proxy(topic, exchange,
|
||||||
|
type_handlers=type_handlers,
|
||||||
url=url, transport=transport,
|
url=url, transport=transport,
|
||||||
transport_options=transport_options,
|
transport_options=transport_options,
|
||||||
retry_options=retry_options)
|
retry_options=retry_options)
|
||||||
|
@@ -14,17 +14,21 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from oslo.utils import reflection
|
from oslo_utils import reflection
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from taskflow.engines.worker_based import protocol as pr
|
from taskflow.engines.worker_based import protocol as pr
|
||||||
|
from taskflow import logging
|
||||||
from taskflow.types import cache as base
|
from taskflow.types import cache as base
|
||||||
|
from taskflow.types import periodic
|
||||||
from taskflow.types import timing as tt
|
from taskflow.types import timing as tt
|
||||||
|
from taskflow.utils import kombu_utils as ku
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -91,8 +95,37 @@ class TopicWorker(object):
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
class TopicWorkers(object):
|
@six.add_metaclass(abc.ABCMeta)
|
||||||
"""A collection of topic based workers."""
|
class WorkerFinder(object):
|
||||||
|
"""Base class for worker finders..."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._cond = threading.Condition()
|
||||||
|
self.on_worker = None
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _total_workers(self):
|
||||||
|
"""Returns how many workers are known."""
|
||||||
|
|
||||||
|
def wait_for_workers(self, workers=1, timeout=None):
|
||||||
|
"""Waits for geq workers to notify they are ready to do work.
|
||||||
|
|
||||||
|
NOTE(harlowja): if a timeout is provided this function will wait
|
||||||
|
until that timeout expires, if the amount of workers does not reach
|
||||||
|
the desired amount of workers before the timeout expires then this will
|
||||||
|
return how many workers are still needed, otherwise it will
|
||||||
|
return zero.
|
||||||
|
"""
|
||||||
|
if workers <= 0:
|
||||||
|
raise ValueError("Worker amount must be greater than zero")
|
||||||
|
watch = tt.StopWatch(duration=timeout)
|
||||||
|
watch.start()
|
||||||
|
with self._cond:
|
||||||
|
while self._total_workers() < workers:
|
||||||
|
if watch.expired():
|
||||||
|
return max(0, workers - self._total_workers())
|
||||||
|
self._cond.wait(watch.leftover(return_none=True))
|
||||||
|
return 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _match_worker(task, available_workers):
|
def _match_worker(task, available_workers):
|
||||||
@@ -110,14 +143,30 @@ class TopicWorkers(object):
|
|||||||
else:
|
else:
|
||||||
return random.choice(available_workers)
|
return random.choice(available_workers)
|
||||||
|
|
||||||
def __init__(self):
|
@abc.abstractmethod
|
||||||
self._workers = {}
|
def get_worker_for_task(self, task):
|
||||||
self._cond = threading.Condition()
|
"""Gets a worker that can perform a given task."""
|
||||||
# Used to name workers with more useful identities...
|
|
||||||
self._counter = itertools.count()
|
|
||||||
|
|
||||||
def __len__(self):
|
def clear(self):
|
||||||
return len(self._workers)
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyWorkerFinder(WorkerFinder):
|
||||||
|
"""Requests and receives responses about workers topic+task details."""
|
||||||
|
|
||||||
|
def __init__(self, uuid, proxy, topics):
|
||||||
|
super(ProxyWorkerFinder, self).__init__()
|
||||||
|
self._proxy = proxy
|
||||||
|
self._topics = topics
|
||||||
|
self._workers = {}
|
||||||
|
self._uuid = uuid
|
||||||
|
self._proxy.dispatcher.type_handlers.update({
|
||||||
|
pr.NOTIFY: [
|
||||||
|
self._process_response,
|
||||||
|
functools.partial(pr.Notify.validate, response=True),
|
||||||
|
],
|
||||||
|
})
|
||||||
|
self._counter = itertools.count()
|
||||||
|
|
||||||
def _next_worker(self, topic, tasks, temporary=False):
|
def _next_worker(self, topic, tasks, temporary=False):
|
||||||
if not temporary:
|
if not temporary:
|
||||||
@@ -126,48 +175,54 @@ class TopicWorkers(object):
|
|||||||
else:
|
else:
|
||||||
return TopicWorker(topic, tasks)
|
return TopicWorker(topic, tasks)
|
||||||
|
|
||||||
def add(self, topic, tasks):
|
@periodic.periodic(pr.NOTIFY_PERIOD)
|
||||||
|
def beat(self):
|
||||||
|
"""Cyclically called to publish notify message to each topic."""
|
||||||
|
self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid)
|
||||||
|
|
||||||
|
def _total_workers(self):
|
||||||
|
return len(self._workers)
|
||||||
|
|
||||||
|
def _add(self, topic, tasks):
|
||||||
"""Adds/updates a worker for the topic for the given tasks."""
|
"""Adds/updates a worker for the topic for the given tasks."""
|
||||||
|
try:
|
||||||
|
worker = self._workers[topic]
|
||||||
|
# Check if we already have an equivalent worker, if so just
|
||||||
|
# return it...
|
||||||
|
if worker == self._next_worker(topic, tasks, temporary=True):
|
||||||
|
return (worker, False)
|
||||||
|
# This *fall through* is done so that if someone is using an
|
||||||
|
# active worker object that already exists that we just create
|
||||||
|
# a new one; so that the existing object doesn't get
|
||||||
|
# affected (workers objects are supposed to be immutable).
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
worker = self._next_worker(topic, tasks)
|
||||||
|
self._workers[topic] = worker
|
||||||
|
return (worker, True)
|
||||||
|
|
||||||
|
def _process_response(self, response, message):
|
||||||
|
"""Process notify message from remote side."""
|
||||||
|
LOG.debug("Started processing notify message '%s'",
|
||||||
|
ku.DelayedPretty(message))
|
||||||
|
topic = response['topic']
|
||||||
|
tasks = response['tasks']
|
||||||
with self._cond:
|
with self._cond:
|
||||||
try:
|
worker, new_or_updated = self._add(topic, tasks)
|
||||||
worker = self._workers[topic]
|
if new_or_updated:
|
||||||
# Check if we already have an equivalent worker, if so just
|
LOG.debug("Received notification about worker '%s' (%s"
|
||||||
# return it...
|
" total workers are currently known)", worker,
|
||||||
if worker == self._next_worker(topic, tasks, temporary=True):
|
self._total_workers())
|
||||||
return worker
|
self._cond.notify_all()
|
||||||
# This *fall through* is done so that if someone is using an
|
if self.on_worker is not None and new_or_updated:
|
||||||
# active worker object that already exists that we just create
|
self.on_worker(worker)
|
||||||
# a new one; so that the existing object doesn't get
|
|
||||||
# affected (workers objects are supposed to be immutable).
|
def clear(self):
|
||||||
except KeyError:
|
with self._cond:
|
||||||
pass
|
self._workers.clear()
|
||||||
worker = self._next_worker(topic, tasks)
|
|
||||||
self._workers[topic] = worker
|
|
||||||
self._cond.notify_all()
|
self._cond.notify_all()
|
||||||
return worker
|
|
||||||
|
|
||||||
def wait_for_workers(self, workers=1, timeout=None):
|
|
||||||
"""Waits for geq workers to notify they are ready to do work.
|
|
||||||
|
|
||||||
NOTE(harlowja): if a timeout is provided this function will wait
|
|
||||||
until that timeout expires, if the amount of workers does not reach
|
|
||||||
the desired amount of workers before the timeout expires then this will
|
|
||||||
return how many workers are still needed, otherwise it will
|
|
||||||
return zero.
|
|
||||||
"""
|
|
||||||
if workers <= 0:
|
|
||||||
raise ValueError("Worker amount must be greater than zero")
|
|
||||||
watch = tt.StopWatch(duration=timeout)
|
|
||||||
watch.start()
|
|
||||||
with self._cond:
|
|
||||||
while len(self._workers) < workers:
|
|
||||||
if watch.expired():
|
|
||||||
return max(0, workers - len(self._workers))
|
|
||||||
self._cond.wait(watch.leftover(return_none=True))
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def get_worker_for_task(self, task):
|
def get_worker_for_task(self, task):
|
||||||
"""Gets a worker that can perform a given task."""
|
|
||||||
available_workers = []
|
available_workers = []
|
||||||
with self._cond:
|
with self._cond:
|
||||||
for worker in six.itervalues(self._workers):
|
for worker in six.itervalues(self._workers):
|
||||||
@@ -177,37 +232,3 @@ class TopicWorkers(object):
|
|||||||
return self._match_worker(task, available_workers)
|
return self._match_worker(task, available_workers)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
with self._cond:
|
|
||||||
self._workers.clear()
|
|
||||||
self._cond.notify_all()
|
|
||||||
|
|
||||||
|
|
||||||
class PeriodicWorker(object):
|
|
||||||
"""Calls a set of functions when activated periodically.
|
|
||||||
|
|
||||||
NOTE(harlowja): the provided timeout object determines the periodicity.
|
|
||||||
"""
|
|
||||||
def __init__(self, timeout, functors):
|
|
||||||
self._timeout = timeout
|
|
||||||
self._functors = []
|
|
||||||
for f in functors:
|
|
||||||
self._functors.append((f, reflection.get_callable_name(f)))
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
while not self._timeout.is_stopped():
|
|
||||||
for (f, f_name) in self._functors:
|
|
||||||
LOG.debug("Calling periodic function '%s'", f_name)
|
|
||||||
try:
|
|
||||||
f()
|
|
||||||
except Exception:
|
|
||||||
LOG.warn("Failed to call periodic function '%s'", f_name,
|
|
||||||
exc_info=True)
|
|
||||||
self._timeout.wait()
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self._timeout.interrupt()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self._timeout.reset()
|
|
||||||
|
@@ -14,6 +14,8 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@@ -21,9 +23,31 @@ from taskflow import exceptions as excp
|
|||||||
from taskflow import test
|
from taskflow import test
|
||||||
from taskflow.types import fsm
|
from taskflow.types import fsm
|
||||||
from taskflow.types import graph
|
from taskflow.types import graph
|
||||||
|
from taskflow.types import latch
|
||||||
|
from taskflow.types import periodic
|
||||||
from taskflow.types import table
|
from taskflow.types import table
|
||||||
from taskflow.types import timing as tt
|
from taskflow.types import timing as tt
|
||||||
from taskflow.types import tree
|
from taskflow.types import tree
|
||||||
|
from taskflow.utils import threading_utils as tu
|
||||||
|
|
||||||
|
|
||||||
|
class PeriodicThingy(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.capture = []
|
||||||
|
|
||||||
|
@periodic.periodic(0.01)
|
||||||
|
def a(self):
|
||||||
|
self.capture.append('a')
|
||||||
|
|
||||||
|
@periodic.periodic(0.02)
|
||||||
|
def b(self):
|
||||||
|
self.capture.append('b')
|
||||||
|
|
||||||
|
def c(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def d(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GraphTest(test.TestCase):
|
class GraphTest(test.TestCase):
|
||||||
@@ -451,3 +475,112 @@ class FSMTest(test.TestCase):
|
|||||||
m.add_state('broken')
|
m.add_state('broken')
|
||||||
self.assertRaises(ValueError, m.add_state, 'b', on_enter=2)
|
self.assertRaises(ValueError, m.add_state, 'b', on_enter=2)
|
||||||
self.assertRaises(ValueError, m.add_state, 'b', on_exit=2)
|
self.assertRaises(ValueError, m.add_state, 'b', on_exit=2)
|
||||||
|
|
||||||
|
|
||||||
|
class PeriodicTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_invalid_periodic(self):
|
||||||
|
|
||||||
|
def no_op():
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertRaises(ValueError, periodic.periodic, -1)
|
||||||
|
|
||||||
|
def test_valid_periodic(self):
|
||||||
|
|
||||||
|
@periodic.periodic(2)
|
||||||
|
def no_op():
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertTrue(getattr(no_op, '_periodic'))
|
||||||
|
self.assertEqual(2, getattr(no_op, '_periodic_spacing'))
|
||||||
|
self.assertEqual(True, getattr(no_op, '_periodic_run_immediately'))
|
||||||
|
|
||||||
|
def test_scanning_periodic(self):
|
||||||
|
p = PeriodicThingy()
|
||||||
|
w = periodic.PeriodicWorker.create([p])
|
||||||
|
self.assertEqual(2, len(w))
|
||||||
|
|
||||||
|
t = tu.daemon_thread(target=w.start)
|
||||||
|
t.start()
|
||||||
|
time.sleep(0.1)
|
||||||
|
w.stop()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
b_calls = [c for c in p.capture if c == 'b']
|
||||||
|
self.assertGreater(0, len(b_calls))
|
||||||
|
a_calls = [c for c in p.capture if c == 'a']
|
||||||
|
self.assertGreater(0, len(a_calls))
|
||||||
|
|
||||||
|
def test_periodic_single(self):
|
||||||
|
barrier = latch.Latch(5)
|
||||||
|
capture = []
|
||||||
|
tombstone = tu.Event()
|
||||||
|
|
||||||
|
@periodic.periodic(0.01)
|
||||||
|
def callee():
|
||||||
|
barrier.countdown()
|
||||||
|
if barrier.needed == 0:
|
||||||
|
tombstone.set()
|
||||||
|
capture.append(1)
|
||||||
|
|
||||||
|
w = periodic.PeriodicWorker([callee], tombstone=tombstone)
|
||||||
|
t = tu.daemon_thread(target=w.start)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
self.assertEqual(0, barrier.needed)
|
||||||
|
self.assertEqual(5, sum(capture))
|
||||||
|
self.assertTrue(tombstone.is_set())
|
||||||
|
|
||||||
|
def test_immediate(self):
|
||||||
|
capture = []
|
||||||
|
|
||||||
|
@periodic.periodic(120, run_immediately=True)
|
||||||
|
def a():
|
||||||
|
capture.append('a')
|
||||||
|
|
||||||
|
w = periodic.PeriodicWorker([a])
|
||||||
|
t = tu.daemon_thread(target=w.start)
|
||||||
|
t.start()
|
||||||
|
time.sleep(0.1)
|
||||||
|
w.stop()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
a_calls = [c for c in capture if c == 'a']
|
||||||
|
self.assertGreater(0, len(a_calls))
|
||||||
|
|
||||||
|
def test_period_double_no_immediate(self):
|
||||||
|
capture = []
|
||||||
|
|
||||||
|
@periodic.periodic(0.01, run_immediately=False)
|
||||||
|
def a():
|
||||||
|
capture.append('a')
|
||||||
|
|
||||||
|
@periodic.periodic(0.02, run_immediately=False)
|
||||||
|
def b():
|
||||||
|
capture.append('b')
|
||||||
|
|
||||||
|
w = periodic.PeriodicWorker([a, b])
|
||||||
|
t = tu.daemon_thread(target=w.start)
|
||||||
|
t.start()
|
||||||
|
time.sleep(0.1)
|
||||||
|
w.stop()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
b_calls = [c for c in capture if c == 'b']
|
||||||
|
self.assertGreater(0, len(b_calls))
|
||||||
|
a_calls = [c for c in capture if c == 'a']
|
||||||
|
self.assertGreater(0, len(a_calls))
|
||||||
|
|
||||||
|
def test_start_nothing_error(self):
|
||||||
|
w = periodic.PeriodicWorker([])
|
||||||
|
self.assertRaises(RuntimeError, w.start)
|
||||||
|
|
||||||
|
def test_missing_function_attrs(self):
|
||||||
|
|
||||||
|
def fake_periodic():
|
||||||
|
pass
|
||||||
|
|
||||||
|
cb = fake_periodic
|
||||||
|
self.assertRaises(ValueError, periodic.PeriodicWorker, [cb])
|
||||||
|
@@ -41,12 +41,12 @@ class TestDispatcher(test.TestCase):
|
|||||||
def test_creation(self):
|
def test_creation(self):
|
||||||
on_hello = mock.MagicMock()
|
on_hello = mock.MagicMock()
|
||||||
handlers = {'hello': on_hello}
|
handlers = {'hello': on_hello}
|
||||||
dispatcher.TypeDispatcher(handlers)
|
dispatcher.TypeDispatcher(type_handlers=handlers)
|
||||||
|
|
||||||
def test_on_message(self):
|
def test_on_message(self):
|
||||||
on_hello = mock.MagicMock()
|
on_hello = mock.MagicMock()
|
||||||
handlers = {'hello': on_hello}
|
handlers = {'hello': on_hello}
|
||||||
d = dispatcher.TypeDispatcher(handlers)
|
d = dispatcher.TypeDispatcher(type_handlers=handlers)
|
||||||
msg = mock_acked_message(properties={'type': 'hello'})
|
msg = mock_acked_message(properties={'type': 'hello'})
|
||||||
d.on_message("", msg)
|
d.on_message("", msg)
|
||||||
self.assertTrue(on_hello.called)
|
self.assertTrue(on_hello.called)
|
||||||
@@ -54,15 +54,15 @@ class TestDispatcher(test.TestCase):
|
|||||||
self.assertTrue(msg.acknowledged)
|
self.assertTrue(msg.acknowledged)
|
||||||
|
|
||||||
def test_on_rejected_message(self):
|
def test_on_rejected_message(self):
|
||||||
d = dispatcher.TypeDispatcher({})
|
d = dispatcher.TypeDispatcher()
|
||||||
msg = mock_acked_message(properties={'type': 'hello'})
|
msg = mock_acked_message(properties={'type': 'hello'})
|
||||||
d.on_message("", msg)
|
d.on_message("", msg)
|
||||||
self.assertTrue(msg.reject_log_error.called)
|
self.assertTrue(msg.reject_log_error.called)
|
||||||
self.assertFalse(msg.acknowledged)
|
self.assertFalse(msg.acknowledged)
|
||||||
|
|
||||||
def test_on_requeue_message(self):
|
def test_on_requeue_message(self):
|
||||||
d = dispatcher.TypeDispatcher({})
|
d = dispatcher.TypeDispatcher()
|
||||||
d.add_requeue_filter(lambda data, message: True)
|
d.requeue_filters.append(lambda data, message: True)
|
||||||
msg = mock_acked_message()
|
msg = mock_acked_message()
|
||||||
d.on_message("", msg)
|
d.on_message("", msg)
|
||||||
self.assertTrue(msg.requeue.called)
|
self.assertTrue(msg.requeue.called)
|
||||||
@@ -71,7 +71,7 @@ class TestDispatcher(test.TestCase):
|
|||||||
def test_failed_ack(self):
|
def test_failed_ack(self):
|
||||||
on_hello = mock.MagicMock()
|
on_hello = mock.MagicMock()
|
||||||
handlers = {'hello': on_hello}
|
handlers = {'hello': on_hello}
|
||||||
d = dispatcher.TypeDispatcher(handlers)
|
d = dispatcher.TypeDispatcher(type_handlers=handlers)
|
||||||
msg = mock_acked_message(ack_ok=False,
|
msg = mock_acked_message(ack_ok=False,
|
||||||
properties={'type': 'hello'})
|
properties={'type': 'hello'})
|
||||||
d.on_message("", msg)
|
d.on_message("", msg)
|
||||||
|
@@ -86,11 +86,12 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
ex = self.executor(reset_master_mock=False)
|
ex = self.executor(reset_master_mock=False)
|
||||||
master_mock_calls = [
|
master_mock_calls = [
|
||||||
mock.call.Proxy(self.executor_uuid, self.executor_exchange,
|
mock.call.Proxy(self.executor_uuid, self.executor_exchange,
|
||||||
mock.ANY, on_wait=ex._on_wait,
|
on_wait=ex._on_wait,
|
||||||
url=self.broker_url, transport=mock.ANY,
|
url=self.broker_url, transport=mock.ANY,
|
||||||
transport_options=mock.ANY,
|
transport_options=mock.ANY,
|
||||||
retry_options=mock.ANY
|
retry_options=mock.ANY,
|
||||||
)
|
type_handlers=mock.ANY),
|
||||||
|
mock.call.proxy.dispatcher.type_handlers.update(mock.ANY),
|
||||||
]
|
]
|
||||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||||
|
|
||||||
@@ -212,10 +213,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.assertEqual(len(ex._requests_cache), 0)
|
self.assertEqual(len(ex._requests_cache), 0)
|
||||||
|
|
||||||
def test_execute_task(self):
|
def test_execute_task(self):
|
||||||
self.message_mock.properties['type'] = pr.NOTIFY
|
|
||||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
ex._finder._add(self.executor_topic, [self.task.name])
|
||||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||||
|
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
@@ -231,10 +230,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||||
|
|
||||||
def test_revert_task(self):
|
def test_revert_task(self):
|
||||||
self.message_mock.properties['type'] = pr.NOTIFY
|
|
||||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
ex._finder._add(self.executor_topic, [self.task.name])
|
||||||
ex.revert_task(self.task, self.task_uuid, self.task_args,
|
ex.revert_task(self.task, self.task_uuid, self.task_args,
|
||||||
self.task_result, self.task_failures)
|
self.task_result, self.task_failures)
|
||||||
|
|
||||||
@@ -263,11 +260,9 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
|||||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||||
|
|
||||||
def test_execute_task_publish_error(self):
|
def test_execute_task_publish_error(self):
|
||||||
self.message_mock.properties['type'] = pr.NOTIFY
|
|
||||||
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
|
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
|
||||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
|
||||||
ex = self.executor()
|
ex = self.executor()
|
||||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
ex._finder._add(self.executor_topic, [self.task.name])
|
||||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||||
|
|
||||||
expected_calls = [
|
expected_calls = [
|
||||||
|
@@ -86,7 +86,7 @@ class TestServer(test.MockTestCase):
|
|||||||
# check calls
|
# check calls
|
||||||
master_mock_calls = [
|
master_mock_calls = [
|
||||||
mock.call.Proxy(self.server_topic, self.server_exchange,
|
mock.call.Proxy(self.server_topic, self.server_exchange,
|
||||||
mock.ANY, url=self.broker_url,
|
type_handlers=mock.ANY, url=self.broker_url,
|
||||||
transport=mock.ANY, transport_options=mock.ANY,
|
transport=mock.ANY, transport_options=mock.ANY,
|
||||||
retry_options=mock.ANY)
|
retry_options=mock.ANY)
|
||||||
]
|
]
|
||||||
@@ -99,7 +99,7 @@ class TestServer(test.MockTestCase):
|
|||||||
# check calls
|
# check calls
|
||||||
master_mock_calls = [
|
master_mock_calls = [
|
||||||
mock.call.Proxy(self.server_topic, self.server_exchange,
|
mock.call.Proxy(self.server_topic, self.server_exchange,
|
||||||
mock.ANY, url=self.broker_url,
|
type_handlers=mock.ANY, url=self.broker_url,
|
||||||
transport=mock.ANY, transport_options=mock.ANY,
|
transport=mock.ANY, transport_options=mock.ANY,
|
||||||
retry_options=mock.ANY)
|
retry_options=mock.ANY)
|
||||||
]
|
]
|
||||||
|
@@ -14,23 +14,20 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
from oslo.utils import reflection
|
from oslo.utils import reflection
|
||||||
|
|
||||||
from taskflow.engines.worker_based import protocol as pr
|
from taskflow.engines.worker_based import protocol as pr
|
||||||
from taskflow.engines.worker_based import types as worker_types
|
from taskflow.engines.worker_based import types as worker_types
|
||||||
from taskflow import test
|
from taskflow import test
|
||||||
|
from taskflow.test import mock
|
||||||
from taskflow.tests import utils
|
from taskflow.tests import utils
|
||||||
from taskflow.types import latch
|
|
||||||
from taskflow.types import timing
|
from taskflow.types import timing
|
||||||
|
|
||||||
|
|
||||||
class TestWorkerTypes(test.TestCase):
|
class TestRequestCache(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestWorkerTypes, self).setUp()
|
super(TestRequestCache, self).setUp()
|
||||||
self.addCleanup(timing.StopWatch.clear_overrides)
|
self.addCleanup(timing.StopWatch.clear_overrides)
|
||||||
self.task = utils.DummyTask()
|
self.task = utils.DummyTask()
|
||||||
self.task_uuid = 'task-uuid'
|
self.task_uuid = 'task-uuid'
|
||||||
@@ -76,6 +73,8 @@ class TestWorkerTypes(test.TestCase):
|
|||||||
self.assertEqual(1, len(matches))
|
self.assertEqual(1, len(matches))
|
||||||
self.assertEqual(2, len(cache))
|
self.assertEqual(2, len(cache))
|
||||||
|
|
||||||
|
|
||||||
|
class TestTopicWorker(test.TestCase):
|
||||||
def test_topic_worker(self):
|
def test_topic_worker(self):
|
||||||
worker = worker_types.TopicWorker("dummy-topic",
|
worker = worker_types.TopicWorker("dummy-topic",
|
||||||
[utils.DummyTask], identity="dummy")
|
[utils.DummyTask], identity="dummy")
|
||||||
@@ -84,52 +83,37 @@ class TestWorkerTypes(test.TestCase):
|
|||||||
self.assertEqual('dummy', worker.identity)
|
self.assertEqual('dummy', worker.identity)
|
||||||
self.assertEqual('dummy-topic', worker.topic)
|
self.assertEqual('dummy-topic', worker.topic)
|
||||||
|
|
||||||
def test_single_topic_workers(self):
|
|
||||||
workers = worker_types.TopicWorkers()
|
class TestProxyFinder(test.TestCase):
|
||||||
w = workers.add('dummy-topic', [utils.DummyTask])
|
def test_single_topic_worker(self):
|
||||||
|
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
|
||||||
|
w, emit = finder._add('dummy-topic', [utils.DummyTask])
|
||||||
self.assertIsNotNone(w)
|
self.assertIsNotNone(w)
|
||||||
self.assertEqual(1, len(workers))
|
self.assertTrue(emit)
|
||||||
w2 = workers.get_worker_for_task(utils.DummyTask)
|
self.assertEqual(1, finder._total_workers())
|
||||||
|
w2 = finder.get_worker_for_task(utils.DummyTask)
|
||||||
self.assertEqual(w.identity, w2.identity)
|
self.assertEqual(w.identity, w2.identity)
|
||||||
|
|
||||||
def test_multi_same_topic_workers(self):
|
def test_multi_same_topic_workers(self):
|
||||||
workers = worker_types.TopicWorkers()
|
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
|
||||||
w = workers.add('dummy-topic', [utils.DummyTask])
|
w, emit = finder._add('dummy-topic', [utils.DummyTask])
|
||||||
self.assertIsNotNone(w)
|
self.assertIsNotNone(w)
|
||||||
w2 = workers.add('dummy-topic-2', [utils.DummyTask])
|
self.assertTrue(emit)
|
||||||
|
w2, emit = finder._add('dummy-topic-2', [utils.DummyTask])
|
||||||
self.assertIsNotNone(w2)
|
self.assertIsNotNone(w2)
|
||||||
w3 = workers.get_worker_for_task(
|
self.assertTrue(emit)
|
||||||
|
w3 = finder.get_worker_for_task(
|
||||||
reflection.get_class_name(utils.DummyTask))
|
reflection.get_class_name(utils.DummyTask))
|
||||||
self.assertIn(w3.identity, [w.identity, w2.identity])
|
self.assertIn(w3.identity, [w.identity, w2.identity])
|
||||||
|
|
||||||
def test_multi_different_topic_workers(self):
|
def test_multi_different_topic_workers(self):
|
||||||
workers = worker_types.TopicWorkers()
|
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
|
||||||
added = []
|
added = []
|
||||||
added.append(workers.add('dummy-topic', [utils.DummyTask]))
|
added.append(finder._add('dummy-topic', [utils.DummyTask]))
|
||||||
added.append(workers.add('dummy-topic-2', [utils.DummyTask]))
|
added.append(finder._add('dummy-topic-2', [utils.DummyTask]))
|
||||||
added.append(workers.add('dummy-topic-3', [utils.NastyTask]))
|
added.append(finder._add('dummy-topic-3', [utils.NastyTask]))
|
||||||
self.assertEqual(3, len(workers))
|
self.assertEqual(3, finder._total_workers())
|
||||||
w = workers.get_worker_for_task(utils.NastyTask)
|
w = finder.get_worker_for_task(utils.NastyTask)
|
||||||
self.assertEqual(added[-1].identity, w.identity)
|
self.assertEqual(added[-1][0].identity, w.identity)
|
||||||
w = workers.get_worker_for_task(utils.DummyTask)
|
w = finder.get_worker_for_task(utils.DummyTask)
|
||||||
self.assertIn(w.identity, [w_a.identity for w_a in added[0:2]])
|
self.assertIn(w.identity, [w_a[0].identity for w_a in added[0:2]])
|
||||||
|
|
||||||
def test_periodic_worker(self):
|
|
||||||
barrier = latch.Latch(5)
|
|
||||||
to = timing.Timeout(0.01)
|
|
||||||
called_at = []
|
|
||||||
|
|
||||||
def callee():
|
|
||||||
barrier.countdown()
|
|
||||||
if barrier.needed == 0:
|
|
||||||
to.interrupt()
|
|
||||||
called_at.append(time.time())
|
|
||||||
|
|
||||||
w = worker_types.PeriodicWorker(to, [callee])
|
|
||||||
t = threading.Thread(target=w.start)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
self.assertEqual(0, barrier.needed)
|
|
||||||
self.assertEqual(5, len(called_at))
|
|
||||||
self.assertTrue(to.is_stopped())
|
|
||||||
|
179
taskflow/types/periodic.py
Normal file
179
taskflow/types/periodic.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import heapq
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from oslo_utils import reflection
|
||||||
|
import six
|
||||||
|
|
||||||
|
from taskflow import logging
|
||||||
|
from taskflow.utils import misc
|
||||||
|
from taskflow.utils import threading_utils as tu
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Find a monotonic providing time (or fallback to using time.time()
|
||||||
|
# which isn't *always* accurate but will suffice).
|
||||||
|
_now = misc.find_monotonic(allow_time_time=True)
|
||||||
|
|
||||||
|
# Attributes expected on periodic tagged/decorated functions or methods...
|
||||||
|
_PERIODIC_ATTRS = tuple([
|
||||||
|
'_periodic',
|
||||||
|
'_periodic_spacing',
|
||||||
|
'_periodic_run_immediately',
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def periodic(spacing, run_immediately=True):
|
||||||
|
"""Tags a method/function as wanting/able to execute periodically."""
|
||||||
|
|
||||||
|
if spacing <= 0:
|
||||||
|
raise ValueError("Periodicity/spacing must be greater than"
|
||||||
|
" zero instead of %s" % spacing)
|
||||||
|
|
||||||
|
def wrapper(f):
|
||||||
|
f._periodic = True
|
||||||
|
f._periodic_spacing = spacing
|
||||||
|
f._periodic_run_immediately = run_immediately
|
||||||
|
|
||||||
|
@six.wraps(f)
|
||||||
|
def decorator(*args, **kwargs):
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class PeriodicWorker(object):
|
||||||
|
"""Calls a collection of callables periodically (sleeping as needed...).
|
||||||
|
|
||||||
|
NOTE(harlowja): typically the :py:meth:`.start` method is executed in a
|
||||||
|
background thread so that the periodic callables are executed in
|
||||||
|
the background/asynchronously (using the defined periods to determine
|
||||||
|
when each is called).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, objects, exclude_hidden=True):
|
||||||
|
"""Automatically creates a worker by analyzing object(s) methods.
|
||||||
|
|
||||||
|
Only picks up methods that have been tagged/decorated with
|
||||||
|
the :py:func:`.periodic` decorator (does not match against private
|
||||||
|
or protected methods unless explicitly requested to).
|
||||||
|
"""
|
||||||
|
callables = []
|
||||||
|
for obj in objects:
|
||||||
|
for (name, member) in inspect.getmembers(obj):
|
||||||
|
if name.startswith("_") and exclude_hidden:
|
||||||
|
continue
|
||||||
|
if reflection.is_bound_method(member):
|
||||||
|
consume = True
|
||||||
|
for attr_name in _PERIODIC_ATTRS:
|
||||||
|
if not hasattr(member, attr_name):
|
||||||
|
consume = False
|
||||||
|
break
|
||||||
|
if consume:
|
||||||
|
callables.append(member)
|
||||||
|
return cls(callables)
|
||||||
|
|
||||||
|
def __init__(self, callables, tombstone=None):
|
||||||
|
if tombstone is None:
|
||||||
|
self._tombstone = tu.Event()
|
||||||
|
else:
|
||||||
|
# Allows someone to share an event (if they so want to...)
|
||||||
|
self._tombstone = tombstone
|
||||||
|
almost_callables = list(callables)
|
||||||
|
for cb in almost_callables:
|
||||||
|
if not six.callable(cb):
|
||||||
|
raise ValueError("Periodic callback must be callable")
|
||||||
|
for attr_name in _PERIODIC_ATTRS:
|
||||||
|
if not hasattr(cb, attr_name):
|
||||||
|
raise ValueError("Periodic callback missing required"
|
||||||
|
" attribute '%s'" % attr_name)
|
||||||
|
self._callables = tuple((cb, reflection.get_callable_name(cb))
|
||||||
|
for cb in almost_callables)
|
||||||
|
self._schedule = []
|
||||||
|
self._immediates = []
|
||||||
|
now = _now()
|
||||||
|
for i, (cb, cb_name) in enumerate(self._callables):
|
||||||
|
spacing = getattr(cb, '_periodic_spacing')
|
||||||
|
next_run = now + spacing
|
||||||
|
heapq.heappush(self._schedule, (next_run, i))
|
||||||
|
for (cb, cb_name) in reversed(self._callables):
|
||||||
|
if getattr(cb, '_periodic_run_immediately', False):
|
||||||
|
self._immediates.append((cb, cb_name))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._callables)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _safe_call(cb, cb_name, kind='periodic'):
|
||||||
|
try:
|
||||||
|
cb()
|
||||||
|
except Exception:
|
||||||
|
LOG.warn("Failed to call %s callable '%s'",
|
||||||
|
kind, cb_name, exc_info=True)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Starts running (will not stop/return until the tombstone is set).
|
||||||
|
|
||||||
|
NOTE(harlowja): If this worker has no contained callables this raises
|
||||||
|
a runtime error and does not run since it is impossible to periodically
|
||||||
|
run nothing.
|
||||||
|
"""
|
||||||
|
if not self._callables:
|
||||||
|
raise RuntimeError("A periodic worker can not start"
|
||||||
|
" without any callables")
|
||||||
|
while not self._tombstone.is_set():
|
||||||
|
if self._immediates:
|
||||||
|
cb, cb_name = self._immediates.pop()
|
||||||
|
LOG.debug("Calling immediate callable '%s'", cb_name)
|
||||||
|
self._safe_call(cb, cb_name, kind='immediate')
|
||||||
|
else:
|
||||||
|
# Figure out when we should run next (by selecting the
|
||||||
|
# minimum item from the heap, where the minimum should be
|
||||||
|
# the callable that needs to run next and has the lowest
|
||||||
|
# next desired run time).
|
||||||
|
now = _now()
|
||||||
|
next_run, i = heapq.heappop(self._schedule)
|
||||||
|
when_next = next_run - now
|
||||||
|
if when_next <= 0:
|
||||||
|
cb, cb_name = self._callables[i]
|
||||||
|
spacing = getattr(cb, '_periodic_spacing')
|
||||||
|
LOG.debug("Calling periodic callable '%s' (it runs every"
|
||||||
|
" %s seconds)", cb_name, spacing)
|
||||||
|
self._safe_call(cb, cb_name)
|
||||||
|
# Run again someday...
|
||||||
|
next_run = now + spacing
|
||||||
|
heapq.heappush(self._schedule, (next_run, i))
|
||||||
|
else:
|
||||||
|
# Gotta wait...
|
||||||
|
heapq.heappush(self._schedule, (next_run, i))
|
||||||
|
self._tombstone.wait(when_next)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Sets the tombstone (this stops any further executions)."""
|
||||||
|
self._tombstone.set()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Resets the tombstone and re-queues up any immediate executions."""
|
||||||
|
self._tombstone.clear()
|
||||||
|
self._immediates = []
|
||||||
|
for (cb, cb_name) in reversed(self._callables):
|
||||||
|
if getattr(cb, '_periodic_run_immediately', False):
|
||||||
|
self._immediates.append((cb, cb_name))
|
Reference in New Issue
Block a user