diff --git a/doc/source/types.rst b/doc/source/types.rst index 001e15ae..47ba7e48 100644 --- a/doc/source/types.rst +++ b/doc/source/types.rst @@ -44,6 +44,11 @@ Notifier .. automodule:: taskflow.types.notifier +Periodic +======== + +.. automodule:: taskflow.types.periodic + Table ===== diff --git a/taskflow/engines/worker_based/dispatcher.py b/taskflow/engines/worker_based/dispatcher.py index 385fd13d..13470e08 100644 --- a/taskflow/engines/worker_based/dispatcher.py +++ b/taskflow/engines/worker_based/dispatcher.py @@ -15,7 +15,6 @@ # under the License. from kombu import exceptions as kombu_exc -import six from taskflow import exceptions as excp from taskflow import logging @@ -27,14 +26,35 @@ LOG = logging.getLogger(__name__) class TypeDispatcher(object): """Receives messages and dispatches to type specific handlers.""" - def __init__(self, type_handlers): - self._handlers = dict(type_handlers) - self._requeue_filters = [] + def __init__(self, type_handlers=None, requeue_filters=None): + if type_handlers is not None: + self._type_handlers = dict(type_handlers) + else: + self._type_handlers = {} + if requeue_filters is not None: + self._requeue_filters = list(requeue_filters) + else: + self._requeue_filters = [] - def add_requeue_filter(self, callback): - """Add a callback that can *request* message requeuing. + @property + def type_handlers(self): + """Dictionary of message type -> callback to handle that message. - The callback will be activated before the message has been acked and + The callback(s) will be activated by looking at the message + property 'type' and locating a callback in this dictionary that maps + to that type; if one is found it is expected to be a callback that + accepts two positional parameters, the first being the message data + and the second being the message object. If a callback is not found + then the message is rejected and it will be up to the underlying + message transport to determine what this means/implies... + """ + return self._type_handlers + + @property + def requeue_filters(self): + """List of filters (callbacks) to request a message to be requeued. + + The callback(s) will be activated before the message has been acked and it can be used to instruct the dispatcher to requeue the message instead of processing it. The callback, when called, will be provided two positional parameters; the first being the message data and the @@ -42,9 +62,7 @@ class TypeDispatcher(object): filter should return a truthy object if the message should be requeued and a falsey object if it should not. """ - if not six.callable(callback): - raise ValueError("Requeue filter callback must be callable") - self._requeue_filters.append(callback) + return self._requeue_filters def _collect_requeue_votes(self, data, message): # Returns how many of the filters asked for the message to be requeued.
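The reworked `TypeDispatcher` above moves handler registration to the constructor and exposes both collections as mutable properties. A minimal sketch of how it might now be wired up; the `on_response`/`validate_response` names are invented for illustration, the `[callback, validator]` list shape mirrors what the executor and server pass elsewhere in this patch, and how that pair is consumed is up to the dispatcher's message-processing path (not shown in this hunk):

```python
from taskflow.engines.worker_based import dispatcher


def validate_response(data):
    # Hypothetical validator; the real callers pass things like
    # pr.Response.validate in this position.
    pass


def on_response(data, message):
    # Handlers receive two positional parameters: the message data
    # and the (kombu) message object itself.
    print("response: %s" % (data,))


# Handlers and requeue filters can now be supplied at construction
# time (the removed add_requeue_filter() method is gone)...
d = dispatcher.TypeDispatcher(
    type_handlers={
        'RESPONSE': [on_response, validate_response],
    },
    # A filter returns a truthy object to request requeuing.
    requeue_filters=[lambda data, message: False],
)

# ...or merged in afterwards via the exposed properties, which is how
# ProxyWorkerFinder registers its NOTIFY handler on the proxy's
# dispatcher later in this patch.
d.type_handlers.update({
    'NOTIFY': [on_response, validate_response],
})
```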
@@ -74,7 +92,7 @@ class TypeDispatcher(object): LOG.debug("Message '%s' was requeued.", ku.DelayedPretty(message)) def _process_message(self, data, message, message_type): - handler = self._handlers.get(message_type) + handler = self._type_handlers.get(message_type) if handler is None: message.reject_log_error(logger=LOG, errors=(kombu_exc.MessageStateError,)) diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index 6ece1b84..a55229f1 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -25,7 +25,7 @@ from taskflow.engines.worker_based import types as wt from taskflow import exceptions as exc from taskflow import logging from taskflow import task as task_atom -from taskflow.types import timing as tt +from taskflow.types import periodic from taskflow.utils import kombu_utils as ku from taskflow.utils import misc from taskflow.utils import threading_utils as tu @@ -41,51 +41,41 @@ class WorkerTaskExecutor(executor.TaskExecutor): url=None, transport=None, transport_options=None, retry_options=None): self._uuid = uuid - self._topics = topics self._requests_cache = wt.RequestsCache() - self._workers = wt.TopicWorkers() self._transition_timeout = transition_timeout type_handlers = { - pr.NOTIFY: [ - self._process_notify, - functools.partial(pr.Notify.validate, response=True), - ], pr.RESPONSE: [ self._process_response, pr.Response.validate, ], } - self._proxy = proxy.Proxy(uuid, exchange, type_handlers, + self._proxy = proxy.Proxy(uuid, exchange, + type_handlers=type_handlers, on_wait=self._on_wait, url=url, transport=transport, transport_options=transport_options, retry_options=retry_options) - self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD), - [self._notify_topics]) + # NOTE(harlowja): This is the simplest finder impl. that + # doesn't have external dependencies (outside of what this engine + # already requires); though it does create periodic 'polling' traffic + # to workers to 'learn' of the tasks they can perform (and requires + # pre-existing knowledge of the topics those workers are on to gather + # and update this information).
+ self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics) + self._finder.on_worker = self._on_worker self._helpers = tu.ThreadBundle() self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start), after_start=lambda t: self._proxy.wait(), before_join=lambda t: self._proxy.stop()) - self._helpers.bind(lambda: tu.daemon_thread(self._periodic.start), - before_join=lambda t: self._periodic.stop(), - after_join=lambda t: self._periodic.reset(), - before_start=lambda t: self._periodic.reset()) + p_worker = periodic.PeriodicWorker.create([self._finder]) + if p_worker: + self._helpers.bind(lambda: tu.daemon_thread(p_worker.start), + before_join=lambda t: p_worker.stop(), + after_join=lambda t: p_worker.reset(), + before_start=lambda t: p_worker.reset()) - def _process_notify(self, notify, message): - """Process notify message from remote side.""" - LOG.debug("Started processing notify message '%s'", - ku.DelayedPretty(message)) - - topic = notify['topic'] - tasks = notify['tasks'] - - # Add worker info to the cache - worker = self._workers.add(topic, tasks) - LOG.debug("Received notification about worker '%s' (%s" - " total workers are currently known)", worker, - len(self._workers)) - - # Publish waiting requests + def _on_worker(self, worker): + """Process new worker that has arrived (and fire off any work).""" for request in self._requests_cache.get_waiting_requests(worker): if request.transition_and_log_error(pr.PENDING, logger=LOG): self._publish_request(request, worker) @@ -174,7 +164,7 @@ class WorkerTaskExecutor(executor.TaskExecutor): request.result.add_done_callback(lambda fut: cleaner()) # Get task's worker and publish request if worker was found. - worker = self._workers.get_worker_for_task(task) + worker = self._finder.get_worker_for_task(task) if worker is not None: # NOTE(skudriashev): Make sure request is set to the PENDING state # before putting it into the requests cache to prevent the notify @@ -208,10 +198,6 @@ class WorkerTaskExecutor(executor.TaskExecutor): del self._requests_cache[request.uuid] request.set_result(failure) - def _notify_topics(self): - """Cyclically called to publish notify message to each topic.""" - self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid) - def execute_task(self, task, task_uuid, arguments, progress_callback=None): return self._submit_task(task, task_uuid, pr.EXECUTE, arguments, @@ -232,7 +218,8 @@ class WorkerTaskExecutor(executor.TaskExecutor): return how many workers are still needed, otherwise it will return zero. """ - return self._workers.wait_for_workers(workers=workers, timeout=timeout) + return self._finder.wait_for_workers(workers=workers, + timeout=timeout) def start(self): """Starts proxy thread and associated topic notification thread.""" @@ -242,4 +229,4 @@ class WorkerTaskExecutor(executor.TaskExecutor): """Stops proxy thread and associated topic notification thread.""" self._helpers.stop() self._requests_cache.clear(self._handle_expired_request) - self._workers.clear() + self._finder.clear() diff --git a/taskflow/engines/worker_based/proxy.py b/taskflow/engines/worker_based/proxy.py index 505ead2a..e9d2ec22 100644 --- a/taskflow/engines/worker_based/proxy.py +++ b/taskflow/engines/worker_based/proxy.py @@ -68,19 +68,19 @@ class Proxy(object): # value is valid... 
_RETRY_INT_OPTS = frozenset(['max_retries']) - def __init__(self, topic, exchange, type_handlers, - on_wait=None, url=None, + def __init__(self, topic, exchange, + type_handlers=None, on_wait=None, url=None, transport=None, transport_options=None, retry_options=None): self._topic = topic self._exchange_name = exchange self._on_wait = on_wait self._running = threading_utils.Event() - self._dispatcher = dispatcher.TypeDispatcher(type_handlers) - self._dispatcher.add_requeue_filter( + self._dispatcher = dispatcher.TypeDispatcher( # NOTE(skudriashev): Process all incoming messages only if proxy is # running, otherwise requeue them. - lambda data, message: not self.is_running) + requeue_filters=[lambda data, message: not self.is_running], + type_handlers=type_handlers) ensure_options = self.DEFAULT_RETRY_OPTIONS.copy() if retry_options is not None: @@ -112,11 +112,16 @@ class Proxy(object): # create exchange self._exchange = kombu.Exchange(name=self._exchange_name, - durable=False, - auto_delete=True) + durable=False, auto_delete=True) + + @property + def dispatcher(self): + """Dispatcher internally used to dispatch message(s) that match.""" + return self._dispatcher @property def connection_details(self): + """Details about the connection (read-only).""" # The kombu drivers seem to use 'N/A' when they don't have a version... driver_version = self._conn.transport.driver_version() if driver_version and driver_version.lower() == 'n/a': diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index 5bb0b2ab..949b4691 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -59,7 +59,8 @@ class Server(object): pr.Request.validate, ], } - self._proxy = proxy.Proxy(topic, exchange, type_handlers, + self._proxy = proxy.Proxy(topic, exchange, + type_handlers=type_handlers, url=url, transport=transport, transport_options=transport_options, retry_options=retry_options) diff --git a/taskflow/engines/worker_based/types.py b/taskflow/engines/worker_based/types.py index d8a7e413..70185d52 100644 --- a/taskflow/engines/worker_based/types.py +++ b/taskflow/engines/worker_based/types.py @@ -14,17 +14,21 @@ # License for the specific language governing permissions and limitations # under the License. +import abc +import functools import itertools -import logging import random import threading -from oslo.utils import reflection +from oslo_utils import reflection import six from taskflow.engines.worker_based import protocol as pr +from taskflow import logging from taskflow.types import cache as base +from taskflow.types import periodic from taskflow.types import timing as tt +from taskflow.utils import kombu_utils as ku LOG = logging.getLogger(__name__) @@ -91,8 +95,37 @@ class TopicWorker(object): return r -class TopicWorkers(object): - """A collection of topic based workers.""" +@six.add_metaclass(abc.ABCMeta) +class WorkerFinder(object): + """Base class for worker finders...""" + + def __init__(self): + self._cond = threading.Condition() + self.on_worker = None + + @abc.abstractmethod + def _total_workers(self): + """Returns how many workers are known.""" + + def wait_for_workers(self, workers=1, timeout=None): + """Waits for at least 'workers' workers to notify they are ready to do work.
+ + NOTE(harlowja): if a timeout is provided this function will wait + until that timeout expires; if the desired amount of workers is not + reached before the timeout expires then this will + return how many workers are still needed, otherwise it will + return zero. + """ + if workers <= 0: + raise ValueError("Worker amount must be greater than zero") + watch = tt.StopWatch(duration=timeout) + watch.start() + with self._cond: + while self._total_workers() < workers: + if watch.expired(): + return max(0, workers - self._total_workers()) + self._cond.wait(watch.leftover(return_none=True)) + return 0 @staticmethod def _match_worker(task, available_workers): @@ -110,14 +143,30 @@ else: return random.choice(available_workers) - def __init__(self): - self._workers = {} - self._cond = threading.Condition() - # Used to name workers with more useful identities... - self._counter = itertools.count() + @abc.abstractmethod + def get_worker_for_task(self, task): + """Gets a worker that can perform a given task.""" - def __len__(self): - return len(self._workers) + def clear(self): + pass + + +class ProxyWorkerFinder(WorkerFinder): + """Requests and receives responses about workers' topic+task details.""" + + def __init__(self, uuid, proxy, topics): + super(ProxyWorkerFinder, self).__init__() + self._proxy = proxy + self._topics = topics + self._workers = {} + self._uuid = uuid + self._proxy.dispatcher.type_handlers.update({ + pr.NOTIFY: [ + self._process_response, + functools.partial(pr.Notify.validate, response=True), + ], + }) + self._counter = itertools.count() def _next_worker(self, topic, tasks, temporary=False): if not temporary: @@ -126,48 +175,54 @@ else: return TopicWorker(topic, tasks) - def add(self, topic, tasks): + @periodic.periodic(pr.NOTIFY_PERIOD) + def beat(self): + """Cyclically called to publish notify message to each topic.""" + self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid) + + def _total_workers(self): + return len(self._workers) + + def _add(self, topic, tasks): """Adds/updates a worker for the topic for the given tasks.""" + try: + worker = self._workers[topic] + # Check if we already have an equivalent worker, if so just + # return it... + if worker == self._next_worker(topic, tasks, temporary=True): + return (worker, False) + # This *fall through* is done so that if someone is using an + # active worker object that already exists we just create + # a new one, so that the existing object doesn't get + # affected (worker objects are supposed to be immutable). + except KeyError: + pass + worker = self._next_worker(topic, tasks) + self._workers[topic] = worker + return (worker, True) + + def _process_response(self, response, message): + """Process notify message from remote side.""" + LOG.debug("Started processing notify message '%s'", + ku.DelayedPretty(message)) + topic = response['topic'] + tasks = response['tasks'] with self._cond: - try: - worker = self._workers[topic] - # Check if we already have an equivalent worker, if so just - # return it... - if worker == self._next_worker(topic, tasks, temporary=True): - return worker - # This *fall through* is done so that if someone is using an - # active worker object that already exists that we just create - # a new one; so that the existing object doesn't get - # affected (workers objects are supposed to be immutable).
- except KeyError: - pass - worker = self._next_worker(topic, tasks) - self._workers[topic] = worker + worker, new_or_updated = self._add(topic, tasks) + if new_or_updated: + LOG.debug("Received notification about worker '%s' (%s" + " total workers are currently known)", worker, + self._total_workers()) + self._cond.notify_all() + if self.on_worker is not None and new_or_updated: + self.on_worker(worker) + + def clear(self): + with self._cond: + self._workers.clear() self._cond.notify_all() - return worker - - def wait_for_workers(self, workers=1, timeout=None): - """Waits for geq workers to notify they are ready to do work. - - NOTE(harlowja): if a timeout is provided this function will wait - until that timeout expires, if the amount of workers does not reach - the desired amount of workers before the timeout expires then this will - return how many workers are still needed, otherwise it will - return zero. - """ - if workers <= 0: - raise ValueError("Worker amount must be greater than zero") - watch = tt.StopWatch(duration=timeout) - watch.start() - with self._cond: - while len(self._workers) < workers: - if watch.expired(): - return max(0, workers - len(self._workers)) - self._cond.wait(watch.leftover(return_none=True)) - return 0 def get_worker_for_task(self, task): - """Gets a worker that can perform a given task.""" available_workers = [] with self._cond: for worker in six.itervalues(self._workers): @@ -177,37 +232,3 @@ class TopicWorkers(object): return self._match_worker(task, available_workers) else: return None - - def clear(self): - with self._cond: - self._workers.clear() - self._cond.notify_all() - - -class PeriodicWorker(object): - """Calls a set of functions when activated periodically. - - NOTE(harlowja): the provided timeout object determines the periodicity. - """ - def __init__(self, timeout, functors): - self._timeout = timeout - self._functors = [] - for f in functors: - self._functors.append((f, reflection.get_callable_name(f))) - - def start(self): - while not self._timeout.is_stopped(): - for (f, f_name) in self._functors: - LOG.debug("Calling periodic function '%s'", f_name) - try: - f() - except Exception: - LOG.warn("Failed to call periodic function '%s'", f_name, - exc_info=True) - self._timeout.wait() - - def stop(self): - self._timeout.interrupt() - - def reset(self): - self._timeout.reset() diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index 1e639767..4daea4bf 100644 --- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -14,6 +14,8 @@ # License for the specific language governing permissions and limitations # under the License. 
+import time + import networkx as nx import six @@ -21,9 +23,31 @@ from taskflow import exceptions as excp from taskflow import test from taskflow.types import fsm from taskflow.types import graph +from taskflow.types import latch +from taskflow.types import periodic from taskflow.types import table from taskflow.types import timing as tt from taskflow.types import tree +from taskflow.utils import threading_utils as tu + + +class PeriodicThingy(object): + def __init__(self): + self.capture = [] + + @periodic.periodic(0.01) + def a(self): + self.capture.append('a') + + @periodic.periodic(0.02) + def b(self): + self.capture.append('b') + + def c(self): + pass + + def d(self): + pass class GraphTest(test.TestCase): @@ -451,3 +475,112 @@ class FSMTest(test.TestCase): m.add_state('broken') self.assertRaises(ValueError, m.add_state, 'b', on_enter=2) self.assertRaises(ValueError, m.add_state, 'b', on_exit=2) + + +class PeriodicTest(test.TestCase): + + def test_invalid_periodic(self): + + def no_op(): + pass + + self.assertRaises(ValueError, periodic.periodic, -1) + + def test_valid_periodic(self): + + @periodic.periodic(2) + def no_op(): + pass + + self.assertTrue(getattr(no_op, '_periodic')) + self.assertEqual(2, getattr(no_op, '_periodic_spacing')) + self.assertEqual(True, getattr(no_op, '_periodic_run_immediately')) + + def test_scanning_periodic(self): + p = PeriodicThingy() + w = periodic.PeriodicWorker.create([p]) + self.assertEqual(2, len(w)) + + t = tu.daemon_thread(target=w.start) + t.start() + time.sleep(0.1) + w.stop() + t.join() + + b_calls = [c for c in p.capture if c == 'b'] + self.assertGreater(len(b_calls), 0) + a_calls = [c for c in p.capture if c == 'a'] + self.assertGreater(len(a_calls), 0) + + def test_periodic_single(self): + barrier = latch.Latch(5) + capture = [] + tombstone = tu.Event() + + @periodic.periodic(0.01) + def callee(): + barrier.countdown() + if barrier.needed == 0: + tombstone.set() + capture.append(1) + + w = periodic.PeriodicWorker([callee], tombstone=tombstone) + t = tu.daemon_thread(target=w.start) + t.start() + t.join() + + self.assertEqual(0, barrier.needed) + self.assertEqual(5, sum(capture)) + self.assertTrue(tombstone.is_set()) + + def test_immediate(self): + capture = [] + + @periodic.periodic(120, run_immediately=True) + def a(): + capture.append('a') + + w = periodic.PeriodicWorker([a]) + t = tu.daemon_thread(target=w.start) + t.start() + time.sleep(0.1) + w.stop() + t.join() + + a_calls = [c for c in capture if c == 'a'] + self.assertGreater(len(a_calls), 0) + + def test_period_double_no_immediate(self): + capture = [] + + @periodic.periodic(0.01, run_immediately=False) + def a(): + capture.append('a') + + @periodic.periodic(0.02, run_immediately=False) + def b(): + capture.append('b') + + w = periodic.PeriodicWorker([a, b]) + t = tu.daemon_thread(target=w.start) + t.start() + time.sleep(0.1) + w.stop() + t.join() + + b_calls = [c for c in capture if c == 'b'] + self.assertGreater(len(b_calls), 0) + a_calls = [c for c in capture if c == 'a'] + self.assertGreater(len(a_calls), 0) + + def test_start_nothing_error(self): + w = periodic.PeriodicWorker([]) + self.assertRaises(RuntimeError, w.start) + + def test_missing_function_attrs(self): + + def fake_periodic(): + pass + + cb = fake_periodic + self.assertRaises(ValueError, periodic.PeriodicWorker, [cb]) diff --git a/taskflow/tests/unit/worker_based/test_dispatcher.py b/taskflow/tests/unit/worker_based/test_dispatcher.py index db0a719c..21fccdcc 100644 ---
a/taskflow/tests/unit/worker_based/test_dispatcher.py +++ b/taskflow/tests/unit/worker_based/test_dispatcher.py @@ -41,12 +41,12 @@ class TestDispatcher(test.TestCase): def test_creation(self): on_hello = mock.MagicMock() handlers = {'hello': on_hello} - dispatcher.TypeDispatcher(handlers) + dispatcher.TypeDispatcher(type_handlers=handlers) def test_on_message(self): on_hello = mock.MagicMock() handlers = {'hello': on_hello} - d = dispatcher.TypeDispatcher(handlers) + d = dispatcher.TypeDispatcher(type_handlers=handlers) msg = mock_acked_message(properties={'type': 'hello'}) d.on_message("", msg) self.assertTrue(on_hello.called) @@ -54,15 +54,15 @@ class TestDispatcher(test.TestCase): self.assertTrue(msg.acknowledged) def test_on_rejected_message(self): - d = dispatcher.TypeDispatcher({}) + d = dispatcher.TypeDispatcher() msg = mock_acked_message(properties={'type': 'hello'}) d.on_message("", msg) self.assertTrue(msg.reject_log_error.called) self.assertFalse(msg.acknowledged) def test_on_requeue_message(self): - d = dispatcher.TypeDispatcher({}) - d.add_requeue_filter(lambda data, message: True) + d = dispatcher.TypeDispatcher() + d.requeue_filters.append(lambda data, message: True) msg = mock_acked_message() d.on_message("", msg) self.assertTrue(msg.requeue.called) @@ -71,7 +71,7 @@ class TestDispatcher(test.TestCase): def test_failed_ack(self): on_hello = mock.MagicMock() handlers = {'hello': on_hello} - d = dispatcher.TypeDispatcher(handlers) + d = dispatcher.TypeDispatcher(type_handlers=handlers) msg = mock_acked_message(ack_ok=False, properties={'type': 'hello'}) d.on_message("", msg) diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index 101031c4..e7831783 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -86,11 +86,12 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex = self.executor(reset_master_mock=False) master_mock_calls = [ mock.call.Proxy(self.executor_uuid, self.executor_exchange, - mock.ANY, on_wait=ex._on_wait, + on_wait=ex._on_wait, url=self.broker_url, transport=mock.ANY, transport_options=mock.ANY, - retry_options=mock.ANY - ) + retry_options=mock.ANY, + type_handlers=mock.ANY), + mock.call.proxy.dispatcher.type_handlers.update(mock.ANY), ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) @@ -212,10 +213,8 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertEqual(len(ex._requests_cache), 0) def test_execute_task(self): - self.message_mock.properties['type'] = pr.NOTIFY - notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name]) ex = self.executor() - ex._process_notify(notify.to_dict(), self.message_mock) + ex._finder._add(self.executor_topic, [self.task.name]) ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ @@ -231,10 +230,8 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertEqual(expected_calls, self.master_mock.mock_calls) def test_revert_task(self): - self.message_mock.properties['type'] = pr.NOTIFY - notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name]) ex = self.executor() - ex._process_notify(notify.to_dict(), self.message_mock) + ex._finder._add(self.executor_topic, [self.task.name]) ex.revert_task(self.task, self.task_uuid, self.task_args, self.task_result, self.task_failures) @@ -263,11 +260,9 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertEqual(self.master_mock.mock_calls, expected_calls) def 
test_execute_task_publish_error(self): - self.message_mock.properties['type'] = pr.NOTIFY self.proxy_inst_mock.publish.side_effect = Exception('Woot!') - notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name]) ex = self.executor() - ex._process_notify(notify.to_dict(), self.message_mock) + ex._finder._add(self.executor_topic, [self.task.name]) ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ diff --git a/taskflow/tests/unit/worker_based/test_server.py b/taskflow/tests/unit/worker_based/test_server.py index 1fb8aa5c..fea5d1cc 100644 --- a/taskflow/tests/unit/worker_based/test_server.py +++ b/taskflow/tests/unit/worker_based/test_server.py @@ -86,7 +86,7 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ mock.call.Proxy(self.server_topic, self.server_exchange, - mock.ANY, url=self.broker_url, + type_handlers=mock.ANY, url=self.broker_url, transport=mock.ANY, transport_options=mock.ANY, retry_options=mock.ANY) ] @@ -99,7 +99,7 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ mock.call.Proxy(self.server_topic, self.server_exchange, - mock.ANY, url=self.broker_url, + type_handlers=mock.ANY, url=self.broker_url, transport=mock.ANY, transport_options=mock.ANY, retry_options=mock.ANY) ] diff --git a/taskflow/tests/unit/worker_based/test_types.py b/taskflow/tests/unit/worker_based/test_types.py index e1bf949b..287283cf 100644 --- a/taskflow/tests/unit/worker_based/test_types.py +++ b/taskflow/tests/unit/worker_based/test_types.py @@ -14,23 +14,20 @@ # License for the specific language governing permissions and limitations # under the License. -import threading -import time - from oslo.utils import reflection from taskflow.engines.worker_based import protocol as pr from taskflow.engines.worker_based import types as worker_types from taskflow import test +from taskflow.test import mock from taskflow.tests import utils -from taskflow.types import latch from taskflow.types import timing -class TestWorkerTypes(test.TestCase): +class TestRequestCache(test.TestCase): def setUp(self): - super(TestWorkerTypes, self).setUp() + super(TestRequestCache, self).setUp() self.addCleanup(timing.StopWatch.clear_overrides) self.task = utils.DummyTask() self.task_uuid = 'task-uuid' @@ -76,6 +73,8 @@ class TestWorkerTypes(test.TestCase): self.assertEqual(1, len(matches)) self.assertEqual(2, len(cache)) + +class TestTopicWorker(test.TestCase): def test_topic_worker(self): worker = worker_types.TopicWorker("dummy-topic", [utils.DummyTask], identity="dummy") @@ -84,52 +83,37 @@ class TestWorkerTypes(test.TestCase): self.assertEqual('dummy', worker.identity) self.assertEqual('dummy-topic', worker.topic) - def test_single_topic_workers(self): - workers = worker_types.TopicWorkers() - w = workers.add('dummy-topic', [utils.DummyTask]) + +class TestProxyFinder(test.TestCase): + def test_single_topic_worker(self): + finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), []) + w, emit = finder._add('dummy-topic', [utils.DummyTask]) self.assertIsNotNone(w) - self.assertEqual(1, len(workers)) - w2 = workers.get_worker_for_task(utils.DummyTask) + self.assertTrue(emit) + self.assertEqual(1, finder._total_workers()) + w2 = finder.get_worker_for_task(utils.DummyTask) self.assertEqual(w.identity, w2.identity) def test_multi_same_topic_workers(self): - workers = worker_types.TopicWorkers() - w = workers.add('dummy-topic', [utils.DummyTask]) + finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), []) + w, emit = 
finder._add('dummy-topic', [utils.DummyTask]) self.assertIsNotNone(w) - w2 = workers.add('dummy-topic-2', [utils.DummyTask]) + self.assertTrue(emit) + w2, emit = finder._add('dummy-topic-2', [utils.DummyTask]) self.assertIsNotNone(w2) - w3 = workers.get_worker_for_task( + self.assertTrue(emit) + w3 = finder.get_worker_for_task( reflection.get_class_name(utils.DummyTask)) self.assertIn(w3.identity, [w.identity, w2.identity]) def test_multi_different_topic_workers(self): - workers = worker_types.TopicWorkers() + finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), []) added = [] - added.append(workers.add('dummy-topic', [utils.DummyTask])) - added.append(workers.add('dummy-topic-2', [utils.DummyTask])) - added.append(workers.add('dummy-topic-3', [utils.NastyTask])) - self.assertEqual(3, len(workers)) - w = workers.get_worker_for_task(utils.NastyTask) - self.assertEqual(added[-1].identity, w.identity) - w = workers.get_worker_for_task(utils.DummyTask) - self.assertIn(w.identity, [w_a.identity for w_a in added[0:2]]) - - def test_periodic_worker(self): - barrier = latch.Latch(5) - to = timing.Timeout(0.01) - called_at = [] - - def callee(): - barrier.countdown() - if barrier.needed == 0: - to.interrupt() - called_at.append(time.time()) - - w = worker_types.PeriodicWorker(to, [callee]) - t = threading.Thread(target=w.start) - t.start() - t.join() - - self.assertEqual(0, barrier.needed) - self.assertEqual(5, len(called_at)) - self.assertTrue(to.is_stopped()) + added.append(finder._add('dummy-topic', [utils.DummyTask])) + added.append(finder._add('dummy-topic-2', [utils.DummyTask])) + added.append(finder._add('dummy-topic-3', [utils.NastyTask])) + self.assertEqual(3, finder._total_workers()) + w = finder.get_worker_for_task(utils.NastyTask) + self.assertEqual(added[-1][0].identity, w.identity) + w = finder.get_worker_for_task(utils.DummyTask) + self.assertIn(w.identity, [w_a[0].identity for w_a in added[0:2]]) diff --git a/taskflow/types/periodic.py b/taskflow/types/periodic.py new file mode 100644 index 00000000..bbb494d3 --- /dev/null +++ b/taskflow/types/periodic.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import heapq +import inspect + +from oslo_utils import reflection +import six + +from taskflow import logging +from taskflow.utils import misc +from taskflow.utils import threading_utils as tu + +LOG = logging.getLogger(__name__) + +# Find a monotonic time provider (or fall back to using time.time(), +# which isn't *always* accurate but will suffice). +_now = misc.find_monotonic(allow_time_time=True) + +# Attributes expected on periodic tagged/decorated functions or methods...
+_PERIODIC_ATTRS = tuple([ + '_periodic', + '_periodic_spacing', + '_periodic_run_immediately', +]) + + +def periodic(spacing, run_immediately=True): + """Tags a method/function as wanting/able to execute periodically.""" + + if spacing <= 0: + raise ValueError("Periodicity/spacing must be greater than" + " zero instead of %s" % spacing) + + def wrapper(f): + f._periodic = True + f._periodic_spacing = spacing + f._periodic_run_immediately = run_immediately + + @six.wraps(f) + def decorator(*args, **kwargs): + return f(*args, **kwargs) + + return decorator + + return wrapper + + +class PeriodicWorker(object): + """Calls a collection of callables periodically (sleeping as needed...). + + NOTE(harlowja): typically the :py:meth:`.start` method is executed in a + background thread so that the periodic callables are executed in + the background/asynchronously (using the defined periods to determine + when each is called). + """ + + @classmethod + def create(cls, objects, exclude_hidden=True): + """Automatically creates a worker by analyzing object(s) methods. + + Only picks up methods that have been tagged/decorated with + the :py:func:`.periodic` decorator (does not match against private + or protected methods unless explicitly requested to). + """ + callables = [] + for obj in objects: + for (name, member) in inspect.getmembers(obj): + if name.startswith("_") and exclude_hidden: + continue + if reflection.is_bound_method(member): + consume = True + for attr_name in _PERIODIC_ATTRS: + if not hasattr(member, attr_name): + consume = False + break + if consume: + callables.append(member) + return cls(callables) + + def __init__(self, callables, tombstone=None): + if tombstone is None: + self._tombstone = tu.Event() + else: + # Allows someone to share an event (if they so want to...) + self._tombstone = tombstone + almost_callables = list(callables) + for cb in almost_callables: + if not six.callable(cb): + raise ValueError("Periodic callback must be callable") + for attr_name in _PERIODIC_ATTRS: + if not hasattr(cb, attr_name): + raise ValueError("Periodic callback missing required" + " attribute '%s'" % attr_name) + self._callables = tuple((cb, reflection.get_callable_name(cb)) + for cb in almost_callables) + self._schedule = [] + self._immediates = [] + now = _now() + for i, (cb, cb_name) in enumerate(self._callables): + spacing = getattr(cb, '_periodic_spacing') + next_run = now + spacing + heapq.heappush(self._schedule, (next_run, i)) + for (cb, cb_name) in reversed(self._callables): + if getattr(cb, '_periodic_run_immediately', False): + self._immediates.append((cb, cb_name)) + + def __len__(self): + return len(self._callables) + + @staticmethod + def _safe_call(cb, cb_name, kind='periodic'): + try: + cb() + except Exception: + LOG.warn("Failed to call %s callable '%s'", + kind, cb_name, exc_info=True) + + def start(self): + """Starts running (will not stop/return until the tombstone is set). + + NOTE(harlowja): If this worker has no contained callables this raises + a runtime error and does not run since it is impossible to periodically + run nothing. 
+ """ + if not self._callables: + raise RuntimeError("A periodic worker can not start" + " without any callables") + while not self._tombstone.is_set(): + if self._immediates: + cb, cb_name = self._immediates.pop() + LOG.debug("Calling immediate callable '%s'", cb_name) + self._safe_call(cb, cb_name, kind='immediate') + else: + # Figure out when we should run next (by selecting the + # minimum item from the heap, where the minimum should be + # the callable that needs to run next and has the lowest + # next desired run time). + now = _now() + next_run, i = heapq.heappop(self._schedule) + when_next = next_run - now + if when_next <= 0: + cb, cb_name = self._callables[i] + spacing = getattr(cb, '_periodic_spacing') + LOG.debug("Calling periodic callable '%s' (it runs every" + " %s seconds)", cb_name, spacing) + self._safe_call(cb, cb_name) + # Run again someday... + next_run = now + spacing + heapq.heappush(self._schedule, (next_run, i)) + else: + # Gotta wait... + heapq.heappush(self._schedule, (next_run, i)) + self._tombstone.wait(when_next) + + def stop(self): + """Sets the tombstone (this stops any further executions).""" + self._tombstone.set() + + def reset(self): + """Resets the tombstone and re-queues up any immediate executions.""" + self._tombstone.clear() + self._immediates = [] + for (cb, cb_name) in reversed(self._callables): + if getattr(cb, '_periodic_run_immediately', False): + self._immediates.append((cb, cb_name))