Use the notifier type in the task class/module directly

Instead of having code that is some what like the notifier
code we already have, but is duplicated and is slightly different
in the task class just move the code that was in the task class (and
doing similar actions) to instead now use a notifier that is directly
contained in the task base class for internal task triggering of
internal task events.

Breaking change: alters the capabilities of the task to process
notifications itself, most actions now must go through the task
notifier property and instead use that (update_progress still exists
as a common utility method, since it's likely the most common type
of notification that will be used).

Removes the following methods from task base class (as they are
no longer needed with a notifier attribute):

- trigger (replaced with notifier.notify)
- autobind (removed, not replaced, can be created by the user
            of taskflow in a simple manner, without requiring
            functionality in taskflow)
- bind (replaced with notifier.register)
- unbind (replaced with notifier.unregister)
- listeners_iter (replaced with notifier.listeners_iter)

Due to this change we can now also correctly proxy back events from
remote tasks to the engine for correct proxying back to the original
task.

Fixes bug 1370766

Change-Id: Ic9dfef516d72e6e32e71dda30a1cb3522c9e0be6
This commit is contained in:
Joshua Harlow 2014-12-12 15:57:53 -08:00 committed by Joshua Harlow
parent cdfd8ece61
commit 1f4dd72e6e
17 changed files with 492 additions and 364 deletions

View File

@ -158,10 +158,11 @@ engine executor in the following manner:
from dicts after receiving on both executor & worker sides (this from dicts after receiving on both executor & worker sides (this
translation is lossy since the traceback won't be fully retained). translation is lossy since the traceback won't be fully retained).
Executor request format Executor execute format
~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~
* **task** - full task name to be performed * **task_name** - full task name to be performed
* **task_cls** - full task class name to be performed
* **action** - task action to be performed (e.g. execute, revert) * **action** - task action to be performed (e.g. execute, revert)
* **arguments** - arguments the task action to be called with * **arguments** - arguments the task action to be called with
* **result** - task execution result (result or * **result** - task execution result (result or
@ -180,9 +181,14 @@ Additionally, the following parameters are added to the request message:
{ {
"action": "execute", "action": "execute",
"arguments": { "arguments": {
"joe_number": 444 "x": 111
}, },
"task": "tasks.CallJoe" "task_cls": "taskflow.tests.utils.TaskOneArgOneReturn",
"task_name": "taskflow.tests.utils.TaskOneArgOneReturn",
"task_version": [
1,
0
]
} }
Worker response format Worker response format
@ -193,7 +199,8 @@ When **running:**
.. code:: json .. code:: json
{ {
"status": "RUNNING" "data": {},
"state": "RUNNING"
} }
When **progressing:** When **progressing:**
@ -201,9 +208,11 @@ When **progressing:**
.. code:: json .. code:: json
{ {
"event_data": <event_data>, "details": {
"progress": <progress>, "progress": 0.5
"state": "PROGRESS" },
"event_type": "update_progress",
"state": "EVENT"
} }
When **succeeded:** When **succeeded:**
@ -211,8 +220,9 @@ When **succeeded:**
.. code:: json .. code:: json
{ {
"event": <event>, "data": {
"result": <result>, "result": 666
},
"state": "SUCCESS" "state": "SUCCESS"
} }
@ -221,11 +231,64 @@ When **failed:**
.. code:: json .. code:: json
{ {
"event": <event>, "data": {
"result": <types.failure.Failure>, "result": {
"exc_type_names": [
"RuntimeError",
"StandardError",
"Exception"
],
"exception_str": "Woot!",
"traceback_str": " File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n result = task.execute(**arguments)\n File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n raise RuntimeError('Woot!')\n",
"version": 1
}
},
"state": "FAILURE" "state": "FAILURE"
} }
Executor revert format
~~~~~~~~~~~~~~~~~~~~~~
When **reverting:**
.. code:: json
{
"action": "revert",
"arguments": {},
"failures": {
"taskflow.tests.utils.TaskWithFailure": {
"exc_type_names": [
"RuntimeError",
"StandardError",
"Exception"
],
"exception_str": "Woot!",
"traceback_str": " File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n result = task.execute(**arguments)\n File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n raise RuntimeError('Woot!')\n",
"version": 1
}
},
"result": [
"failure",
{
"exc_type_names": [
"RuntimeError",
"StandardError",
"Exception"
],
"exception_str": "Woot!",
"traceback_str": " File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n result = task.execute(**arguments)\n File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n raise RuntimeError('Woot!')\n",
"version": 1
}
],
"task_cls": "taskflow.tests.utils.TaskWithFailure",
"task_name": "taskflow.tests.utils.TaskWithFailure",
"task_version": [
1,
0
]
}
Usage Usage
===== =====

View File

@ -14,6 +14,8 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import functools
from taskflow import logging from taskflow import logging
from taskflow import states from taskflow import states
from taskflow import task as task_atom from taskflow import task as task_atom
@ -75,25 +77,37 @@ class TaskAction(object):
if progress is not None: if progress is not None:
task.update_progress(progress) task.update_progress(progress)
def _on_update_progress(self, task, event_data, progress, **kwargs): def _on_update_progress(self, task, event_type, details):
"""Should be called when task updates its progress.""" """Should be called when task updates its progress."""
try: try:
self._storage.set_task_progress(task.name, progress, kwargs) progress = details.pop('progress')
except KeyError:
pass
else:
try:
self._storage.set_task_progress(task.name, progress,
details=details)
except Exception: except Exception:
# Update progress callbacks should never fail, so capture and log # Update progress callbacks should never fail, so capture and
# the emitted exception instead of raising it. # log the emitted exception instead of raising it.
LOG.exception("Failed setting task progress for %s to %0.3f", LOG.exception("Failed setting task progress for %s to %0.3f",
task, progress) task, progress)
def schedule_execution(self, task): def schedule_execution(self, task):
self.change_state(task, states.RUNNING, progress=0.0) self.change_state(task, states.RUNNING, progress=0.0)
scope_walker = self._walker_factory(task) scope_walker = self._walker_factory(task)
kwargs = self._storage.fetch_mapped_args(task.rebind, arguments = self._storage.fetch_mapped_args(task.rebind,
atom_name=task.name, atom_name=task.name,
scope_walker=scope_walker) scope_walker=scope_walker)
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
progress_callback = functools.partial(self._on_update_progress,
task)
else:
progress_callback = None
task_uuid = self._storage.get_atom_uuid(task.name) task_uuid = self._storage.get_atom_uuid(task.name)
return self._task_executor.execute_task(task, task_uuid, kwargs, return self._task_executor.execute_task(
self._on_update_progress) task, task_uuid, arguments,
progress_callback=progress_callback)
def complete_execution(self, task, result): def complete_execution(self, task, result):
if isinstance(result, failure.Failure): if isinstance(result, failure.Failure):
@ -105,15 +119,20 @@ class TaskAction(object):
def schedule_reversion(self, task): def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0) self.change_state(task, states.REVERTING, progress=0.0)
scope_walker = self._walker_factory(task) scope_walker = self._walker_factory(task)
kwargs = self._storage.fetch_mapped_args(task.rebind, arguments = self._storage.fetch_mapped_args(task.rebind,
atom_name=task.name, atom_name=task.name,
scope_walker=scope_walker) scope_walker=scope_walker)
task_uuid = self._storage.get_atom_uuid(task.name) task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name) task_result = self._storage.get(task.name)
failures = self._storage.get_failures() failures = self._storage.get_failures()
future = self._task_executor.revert_task(task, task_uuid, kwargs, if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
task_result, failures, progress_callback = functools.partial(self._on_update_progress,
self._on_update_progress) task)
else:
progress_callback = None
future = self._task_executor.revert_task(
task, task_uuid, arguments, task_result, failures,
progress_callback=progress_callback)
return future return future
def complete_reversion(self, task, rev_result): def complete_reversion(self, task, rev_result):

View File

@ -15,10 +15,11 @@
# under the License. # under the License.
import abc import abc
import contextlib
import six import six
from taskflow import task as _task from taskflow import task as task_atom
from taskflow.types import failure from taskflow.types import failure
from taskflow.types import futures from taskflow.types import futures
from taskflow.utils import async_utils from taskflow.utils import async_utils
@ -29,8 +30,23 @@ EXECUTED = 'executed'
REVERTED = 'reverted' REVERTED = 'reverted'
def _execute_task(task, arguments, progress_callback): @contextlib.contextmanager
with task.autobind(_task.EVENT_UPDATE_PROGRESS, progress_callback): def _autobind(task, progress_callback=None):
bound = False
if progress_callback is not None:
task.notifier.register(task_atom.EVENT_UPDATE_PROGRESS,
progress_callback)
bound = True
try:
yield
finally:
if bound:
task.notifier.deregister(task_atom.EVENT_UPDATE_PROGRESS,
progress_callback)
def _execute_task(task, arguments, progress_callback=None):
with _autobind(task, progress_callback=progress_callback):
try: try:
task.pre_execute() task.pre_execute()
result = task.execute(**arguments) result = task.execute(**arguments)
@ -43,14 +59,14 @@ def _execute_task(task, arguments, progress_callback):
return (EXECUTED, result) return (EXECUTED, result)
def _revert_task(task, arguments, result, failures, progress_callback): def _revert_task(task, arguments, result, failures, progress_callback=None):
kwargs = arguments.copy() arguments = arguments.copy()
kwargs[_task.REVERT_RESULT] = result arguments[task_atom.REVERT_RESULT] = result
kwargs[_task.REVERT_FLOW_FAILURES] = failures arguments[task_atom.REVERT_FLOW_FAILURES] = failures
with task.autobind(_task.EVENT_UPDATE_PROGRESS, progress_callback): with _autobind(task, progress_callback=progress_callback):
try: try:
task.pre_revert() task.pre_revert()
result = task.revert(**kwargs) result = task.revert(**arguments)
except Exception: except Exception:
# NOTE(imelnikov): wrap current exception with Failure # NOTE(imelnikov): wrap current exception with Failure
# object and return it. # object and return it.
@ -98,15 +114,17 @@ class SerialTaskExecutor(TaskExecutor):
self._executor = futures.SynchronousExecutor() self._executor = futures.SynchronousExecutor()
def execute_task(self, task, task_uuid, arguments, progress_callback=None): def execute_task(self, task, task_uuid, arguments, progress_callback=None):
fut = self._executor.submit(_execute_task, task, arguments, fut = self._executor.submit(_execute_task,
progress_callback) task, arguments,
progress_callback=progress_callback)
fut.atom = task fut.atom = task
return fut return fut
def revert_task(self, task, task_uuid, arguments, result, failures, def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None): progress_callback=None):
fut = self._executor.submit(_revert_task, task, arguments, result, fut = self._executor.submit(_revert_task,
failures, progress_callback) task, arguments, result, failures,
progress_callback=progress_callback)
fut.atom = task fut.atom = task
return fut return fut
@ -127,15 +145,17 @@ class ParallelTaskExecutor(TaskExecutor):
self._create_executor = executor is None self._create_executor = executor is None
def execute_task(self, task, task_uuid, arguments, progress_callback=None): def execute_task(self, task, task_uuid, arguments, progress_callback=None):
fut = self._executor.submit(_execute_task, task, fut = self._executor.submit(_execute_task,
arguments, progress_callback) task, arguments,
progress_callback=progress_callback)
fut.atom = task fut.atom = task
return fut return fut
def revert_task(self, task, task_uuid, arguments, result, failures, def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None): progress_callback=None):
fut = self._executor.submit(_revert_task, task, arguments, fut = self._executor.submit(_revert_task,
result, failures, progress_callback) task, arguments, result, failures,
progress_callback=progress_callback)
fut.atom = task fut.atom = task
return fut return fut

View File

@ -33,18 +33,16 @@ class Endpoint(object):
def name(self): def name(self):
return self._task_cls_name return self._task_cls_name
def _get_task(self, name=None): def generate(self, name=None):
# NOTE(skudriashev): Note that task is created here with the `name` # NOTE(skudriashev): Note that task is created here with the `name`
# argument passed to its constructor. This will be a problem when # argument passed to its constructor. This will be a problem when
# task's constructor requires any other arguments. # task's constructor requires any other arguments.
return self._task_cls(name=name) return self._task_cls(name=name)
def execute(self, task_name, **kwargs): def execute(self, task, **kwargs):
task = self._get_task(task_name)
event, result = self._executor.execute_task(task, **kwargs).result() event, result = self._executor.execute_task(task, **kwargs).result()
return result return result
def revert(self, task_name, **kwargs): def revert(self, task, **kwargs):
task = self._get_task(task_name)
event, result = self._executor.revert_task(task, **kwargs).result() event, result = self._executor.revert_task(task, **kwargs).result()
return result return result

View File

@ -25,6 +25,7 @@ from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy from taskflow.engines.worker_based import proxy
from taskflow import exceptions as exc from taskflow import exceptions as exc
from taskflow import logging from taskflow import logging
from taskflow import task as task_atom
from taskflow.types import timing as tt from taskflow.types import timing as tt
from taskflow.utils import async_utils from taskflow.utils import async_utils
from taskflow.utils import misc from taskflow.utils import misc
@ -132,8 +133,11 @@ class WorkerTaskExecutor(executor.TaskExecutor):
response.state, request) response.state, request)
if response.state == pr.RUNNING: if response.state == pr.RUNNING:
request.transition_and_log_error(pr.RUNNING, logger=LOG) request.transition_and_log_error(pr.RUNNING, logger=LOG)
elif response.state == pr.PROGRESS: elif response.state == pr.EVENT:
request.on_progress(**response.data) # Proxy the event + details to the task/request notifier...
event_type = response.data['event_type']
details = response.data['details']
request.notifier.notify(event_type, details)
elif response.state in (pr.FAILURE, pr.SUCCESS): elif response.state in (pr.FAILURE, pr.SUCCESS):
moved = request.transition_and_log_error(response.state, moved = request.transition_and_log_error(response.state,
logger=LOG) logger=LOG)
@ -181,8 +185,18 @@ class WorkerTaskExecutor(executor.TaskExecutor):
progress_callback, **kwargs): progress_callback, **kwargs):
"""Submit task request to a worker.""" """Submit task request to a worker."""
request = pr.Request(task, task_uuid, action, arguments, request = pr.Request(task, task_uuid, action, arguments,
progress_callback, self._transition_timeout, self._transition_timeout, **kwargs)
**kwargs)
# Register the callback, so that we can proxy the progress correctly.
if (progress_callback is not None and
request.notifier.can_be_registered(
task_atom.EVENT_UPDATE_PROGRESS)):
request.notifier.register(task_atom.EVENT_UPDATE_PROGRESS,
progress_callback)
cleaner = functools.partial(request.notifier.deregister,
task_atom.EVENT_UPDATE_PROGRESS,
progress_callback)
request.result.add_done_callback(lambda fut: cleaner())
# Get task's topic and publish request if topic was found. # Get task's topic and publish request if topic was found.
topic = self._workers_cache.get_topic_by_task(request.task_cls) topic = self._workers_cache.get_topic_by_task(request.task_cls)

View File

@ -38,14 +38,14 @@ PENDING = 'PENDING'
RUNNING = 'RUNNING' RUNNING = 'RUNNING'
SUCCESS = 'SUCCESS' SUCCESS = 'SUCCESS'
FAILURE = 'FAILURE' FAILURE = 'FAILURE'
PROGRESS = 'PROGRESS' EVENT = 'EVENT'
# During these states the expiry is active (once out of these states the expiry # During these states the expiry is active (once out of these states the expiry
# no longer matters, since we have no way of knowing how long a task will run # no longer matters, since we have no way of knowing how long a task will run
# for). # for).
WAITING_STATES = (WAITING, PENDING) WAITING_STATES = (WAITING, PENDING)
_ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, PROGRESS) _ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, EVENT)
_STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE) _STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE)
# Transitions that a request state can go through. # Transitions that a request state can go through.
@ -219,22 +219,29 @@ class Request(Message):
'required': ['task_cls', 'task_name', 'task_version', 'action'], 'required': ['task_cls', 'task_name', 'task_version', 'action'],
} }
def __init__(self, task, uuid, action, arguments, progress_callback, def __init__(self, task, uuid, action, arguments, timeout, **kwargs):
timeout, **kwargs):
self._task = task self._task = task
self._task_cls = reflection.get_class_name(task) self._task_cls = reflection.get_class_name(task)
self._uuid = uuid self._uuid = uuid
self._action = action self._action = action
self._event = ACTION_TO_EVENT[action] self._event = ACTION_TO_EVENT[action]
self._arguments = arguments self._arguments = arguments
self._progress_callback = progress_callback
self._kwargs = kwargs self._kwargs = kwargs
self._watch = tt.StopWatch(duration=timeout).start() self._watch = tt.StopWatch(duration=timeout).start()
self._state = WAITING self._state = WAITING
self._lock = threading.Lock() self._lock = threading.Lock()
self._created_on = timeutils.utcnow() self._created_on = timeutils.utcnow()
self.result = futures.Future() self._result = futures.Future()
self.result.atom = task self._result.atom = task
self._notifier = task.notifier
@property
def result(self):
return self._result
@property
def notifier(self):
return self._notifier
@property @property
def uuid(self): def uuid(self):
@ -293,9 +300,6 @@ class Request(Message):
def set_result(self, result): def set_result(self, result):
self.result.set_result((self._event, result)) self.result.set_result((self._event, result))
def on_progress(self, event_data, progress):
self._progress_callback(self._task, event_data, progress)
def transition_and_log_error(self, new_state, logger=None): def transition_and_log_error(self, new_state, logger=None):
"""Transitions *and* logs an error if that transitioning raises. """Transitions *and* logs an error if that transitioning raises.
@ -362,7 +366,7 @@ class Response(Message):
'data': { 'data': {
"anyOf": [ "anyOf": [
{ {
"$ref": "#/definitions/progress", "$ref": "#/definitions/event",
}, },
{ {
"$ref": "#/definitions/completion", "$ref": "#/definitions/completion",
@ -376,17 +380,17 @@ class Response(Message):
"required": ["state", 'data'], "required": ["state", 'data'],
"additionalProperties": False, "additionalProperties": False,
"definitions": { "definitions": {
"progress": { "event": {
"type": "object", "type": "object",
"properties": { "properties": {
'progress': { 'event_type': {
'type': 'number', 'type': 'string',
}, },
'event_data': { 'details': {
'type': 'object', 'type': 'object',
}, },
}, },
"required": ["progress", 'event_data'], "required": ["event_type", 'details'],
"additionalProperties": False, "additionalProperties": False,
}, },
# Used when sending *only* request state changes (and no data is # Used when sending *only* request state changes (and no data is

View File

@ -22,6 +22,7 @@ from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy from taskflow.engines.worker_based import proxy
from taskflow import logging from taskflow import logging
from taskflow.types import failure as ft from taskflow.types import failure as ft
from taskflow.types import notifier as nt
from taskflow.utils import misc from taskflow.utils import misc
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -73,18 +74,23 @@ class Server(object):
All `failure.Failure` objects that have been converted to dict on the All `failure.Failure` objects that have been converted to dict on the
remote side will now converted back to `failure.Failure` objects. remote side will now converted back to `failure.Failure` objects.
""" """
action_args = dict(arguments=arguments, task_name=task_name) # These arguments will eventually be given to the task executor
# so they need to be in a format it will accept (and using keyword
# argument names that it accepts)...
arguments = {
'arguments': arguments,
}
if result is not None: if result is not None:
data_type, data = result data_type, data = result
if data_type == 'failure': if data_type == 'failure':
action_args['result'] = ft.Failure.from_dict(data) arguments['result'] = ft.Failure.from_dict(data)
else: else:
action_args['result'] = data arguments['result'] = data
if failures is not None: if failures is not None:
action_args['failures'] = {} arguments['failures'] = {}
for key, data in six.iteritems(failures): for key, data in six.iteritems(failures):
action_args['failures'][key] = ft.Failure.from_dict(data) arguments['failures'][key] = ft.Failure.from_dict(data)
return task_cls, action, action_args return (task_cls, task_name, action, arguments)
@staticmethod @staticmethod
def _parse_message(message): def _parse_message(message):
@ -122,14 +128,13 @@ class Server(object):
exc_info=True) exc_info=True)
return published return published
def _on_update_progress(self, reply_to, task_uuid, task, event_data, def _on_event(self, reply_to, task_uuid, event_type, details):
progress): """Send out a task event notification."""
"""Send task update progress notification."""
# NOTE(harlowja): the executor that will trigger this using the # NOTE(harlowja): the executor that will trigger this using the
# task notification/listener mechanism will handle logging if this # task notification/listener mechanism will handle logging if this
# fails, so thats why capture is 'False' is used here. # fails, so thats why capture is 'False' is used here.
self._reply(False, reply_to, task_uuid, pr.PROGRESS, self._reply(False, reply_to, task_uuid, pr.EVENT,
event_data=event_data, progress=progress) event_type=event_type, details=details)
def _process_notify(self, notify, message): def _process_notify(self, notify, message):
"""Process notify message and reply back.""" """Process notify message and reply back."""
@ -165,18 +170,15 @@ class Server(object):
message.delivery_tag, exc_info=True) message.delivery_tag, exc_info=True)
return return
else: else:
# prepare task progress callback
progress_callback = functools.partial(self._on_update_progress,
reply_to, task_uuid)
# prepare reply callback # prepare reply callback
reply_callback = functools.partial(self._reply, True, reply_to, reply_callback = functools.partial(self._reply, True, reply_to,
task_uuid) task_uuid)
# parse request to get task name, action and action arguments # parse request to get task name, action and action arguments
try: try:
task_cls, action, action_args = self._parse_request(**request) bundle = self._parse_request(**request)
action_args.update(task_uuid=task_uuid, task_cls, task_name, action, arguments = bundle
progress_callback=progress_callback) arguments['task_uuid'] = task_uuid
except ValueError: except ValueError:
with misc.capture_failure() as failure: with misc.capture_failure() as failure:
LOG.warn("Failed to parse request contents from message %r", LOG.warn("Failed to parse request contents from message %r",
@ -205,13 +207,37 @@ class Server(object):
message.delivery_tag, exc_info=True) message.delivery_tag, exc_info=True)
reply_callback(result=failure.to_dict()) reply_callback(result=failure.to_dict())
return return
else:
try:
task = endpoint.generate(name=task_name)
except Exception:
with misc.capture_failure() as failure:
LOG.warn("The '%s' task '%s' generation for request"
" message %r failed", endpoint, action,
message.delivery_tag, exc_info=True)
reply_callback(result=failure.to_dict())
return
else: else:
if not reply_callback(state=pr.RUNNING): if not reply_callback(state=pr.RUNNING):
return return
# perform task action # associate *any* events this task emits with a proxy that will
# emit them back to the engine... for handling at the engine side
# of things...
if task.notifier.can_be_registered(nt.Notifier.ANY):
task.notifier.register(nt.Notifier.ANY,
functools.partial(self._on_event,
reply_to, task_uuid))
elif isinstance(task.notifier, nt.RestrictedNotifier):
# only proxy the allowable events then...
for event_type in task.notifier.events_iter():
task.notifier.register(event_type,
functools.partial(self._on_event,
reply_to, task_uuid))
# perform the task action
try: try:
result = handler(**action_args) result = handler(task, **arguments)
except Exception: except Exception:
with misc.capture_failure() as failure: with misc.capture_failure() as failure:
LOG.warn("The '%s' endpoint '%s' execution for request" LOG.warn("The '%s' endpoint '%s' execution for request"

View File

@ -16,14 +16,13 @@
# under the License. # under the License.
import abc import abc
import collections
import contextlib
import copy import copy
import six import six
from taskflow import atom from taskflow import atom
from taskflow import logging from taskflow import logging
from taskflow.types import notifier
from taskflow.utils import misc from taskflow.utils import misc
from taskflow.utils import reflection from taskflow.utils import reflection
@ -51,18 +50,28 @@ class BaseTask(atom.Atom):
same piece of work. same piece of work.
""" """
# Known events this task can have callbacks bound to (others that are not # Known internal events this task can have callbacks bound to (others that
# in this set/tuple will not be able to be bound); this should be updated # are not in this set/tuple will not be able to be bound); this should be
# and/or extended in subclasses as needed to enable or disable new or # updated and/or extended in subclasses as needed to enable or disable new
# existing events... # or existing internal events...
TASK_EVENTS = (EVENT_UPDATE_PROGRESS,) TASK_EVENTS = (EVENT_UPDATE_PROGRESS,)
def __init__(self, name, provides=None, inject=None): def __init__(self, name, provides=None, inject=None):
if name is None: if name is None:
name = reflection.get_class_name(self) name = reflection.get_class_name(self)
super(BaseTask, self).__init__(name, provides, inject=inject) super(BaseTask, self).__init__(name, provides, inject=inject)
# Map of events => lists of callbacks to invoke on task events. self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS)
self._events_listeners = collections.defaultdict(list)
@property
def notifier(self):
"""Internal notification dispatcher/registry.
A notification object that will dispatch events that occur related
to *internal* notifications that the task internally emits to
listeners (for example for progress status updates, telling others
that a task has reached 50% completion...).
"""
return self._notifier
def pre_execute(self): def pre_execute(self):
"""Code to be run prior to executing the task. """Code to be run prior to executing the task.
@ -138,152 +147,31 @@ class BaseTask(atom.Atom):
def copy(self, retain_listeners=True): def copy(self, retain_listeners=True):
"""Clone/copy this task. """Clone/copy this task.
:param retain_listeners: retain the attached listeners when cloning, :param retain_listeners: retain the attached notification listeners
when false the listeners will be emptied, when when cloning, when false the listeners will
true the listeners will be copied and retained be emptied, when true the listeners will be
copied and retained
:return: the copied task :return: the copied task
""" """
c = copy.copy(self) c = copy.copy(self)
c._events_listeners = collections.defaultdict(list) c._notifier = self._notifier.copy()
if retain_listeners: if not retain_listeners:
for event_name, listeners in six.iteritems(self._events_listeners): c._notifier.reset()
c._events_listeners[event_name] = listeners[:]
return c return c
def update_progress(self, progress, **kwargs): def update_progress(self, progress):
"""Update task progress and notify all registered listeners. """Update task progress and notify all registered listeners.
:param progress: task progress float value between 0.0 and 1.0 :param progress: task progress float value between 0.0 and 1.0
:param kwargs: any keyword arguments that are tied to the specific
progress value.
""" """
def on_clamped(): def on_clamped():
LOG.warn("Progress value must be greater or equal to 0.0 or less" LOG.warn("Progress value must be greater or equal to 0.0 or less"
" than or equal to 1.0 instead of being '%s'", progress) " than or equal to 1.0 instead of being '%s'", progress)
cleaned_progress = misc.clamp(progress, 0.0, 1.0, cleaned_progress = misc.clamp(progress, 0.0, 1.0,
on_clamped=on_clamped) on_clamped=on_clamped)
self.trigger(EVENT_UPDATE_PROGRESS, cleaned_progress, **kwargs) self._notifier.notify(EVENT_UPDATE_PROGRESS,
{'progress': cleaned_progress})
def trigger(self, event_name, *args, **kwargs):
"""Execute all callbacks registered for the given event type.
NOTE(harlowja): if a bound callback raises an exception it will be
logged (at a ``WARNING`` level) and the exception
will be dropped.
:param event_name: event name to trigger
:param args: arbitrary positional arguments passed to the triggered
callbacks (if any are matched), these will be in addition
to any ``kwargs`` provided on binding (these are passed
as positional arguments to the callback).
:param kwargs: arbitrary keyword arguments passed to the triggered
callbacks (if any are matched), these will be in addition
to any ``kwargs`` provided on binding (these are passed
as keyword arguments to the callback).
"""
for (cb, event_data) in self._events_listeners.get(event_name, []):
try:
cb(self, event_data, *args, **kwargs)
except Exception:
LOG.warn("Failed calling callback `%s` on event '%s'",
reflection.get_callable_name(cb), event_name,
exc_info=True)
@contextlib.contextmanager
def autobind(self, event_name, callback, **kwargs):
"""Binds & unbinds a given callback to the task.
This function binds and unbinds using the context manager protocol.
When events are triggered on the task of the given event name this
callback will automatically be called with the provided
keyword arguments as the first argument (further arguments may be
provided by the entity triggering the event).
The arguments are interpreted as for :func:`bind() <bind>`.
"""
bound = False
if callback is not None:
try:
self.bind(event_name, callback, **kwargs)
bound = True
except ValueError:
LOG.warn("Failed binding callback `%s` as a receiver of"
" event '%s' notifications emitted from task '%s'",
reflection.get_callable_name(callback), event_name,
self, exc_info=True)
try:
yield self
finally:
if bound:
self.unbind(event_name, callback)
def bind(self, event_name, callback, **kwargs):
"""Attach a callback to be triggered on a task event.
Callbacks should *not* be bound, modified, or removed after execution
has commenced (they may be adjusted after execution has finished). This
is primarily due to the need to preserve the callbacks that exist at
execution time for engines which run tasks remotely or out of
process (so that those engines can correctly proxy back transmitted
events).
Callbacks should also be *quick* to execute so that the engine calling
them can continue execution in a timely manner (if long running
callbacks need to exist, consider creating a separate pool + queue
for those that the attached callbacks put long running operations into
for execution by other entities).
:param event_name: event type name
:param callback: callable to execute each time event is triggered
:param kwargs: optional named parameters that will be passed to the
callable object as a dictionary to the callbacks
*second* positional parameter.
:raises ValueError: if invalid event type, or callback is passed
"""
if event_name not in self.TASK_EVENTS:
raise ValueError("Unknown task event '%s', can only bind"
" to events %s" % (event_name, self.TASK_EVENTS))
if callback is not None:
if not six.callable(callback):
raise ValueError("Event handler callback must be callable")
self._events_listeners[event_name].append((callback, kwargs))
def unbind(self, event_name, callback=None):
"""Remove a previously-attached event callback from the task.
If a callback is not passed, then this will unbind *all* event
callbacks for the provided event. If multiple of the same callbacks
are bound, then the first match is removed (and only the first match).
:param event_name: event type
:param callback: callback previously bound
:rtype: boolean
:return: whether anything was removed
"""
removed_any = False
if not callback:
removed_any = self._events_listeners.pop(event_name, removed_any)
else:
event_listeners = self._events_listeners.get(event_name, [])
for i, (cb, _event_data) in enumerate(event_listeners):
if reflection.is_same_callback(cb, callback):
# NOTE(harlowja): its safe to do this as long as we stop
# iterating after we do the removal, otherwise its not
# safe (since this could have resized the list).
event_listeners.pop(i)
removed_any = True
break
return bool(removed_any)
def listeners_iter(self):
"""Return an iterator over the mapping of event => callbacks bound."""
for event_name in list(six.iterkeys(self._events_listeners)):
# Use get() just incase it was removed while iterating...
event_listeners = self._events_listeners.get(event_name, [])
if event_listeners:
yield (event_name, event_listeners[:])
class Task(BaseTask): class Task(BaseTask):

View File

@ -84,6 +84,31 @@ class NotifierTest(test.TestCase):
self.assertRaises(ValueError, notifier.register, self.assertRaises(ValueError, notifier.register,
nt.Notifier.ANY, 2) nt.Notifier.ANY, 2)
def test_restricted_notifier(self):
notifier = nt.RestrictedNotifier(['a', 'b'])
self.assertRaises(ValueError, notifier.register,
'c', lambda *args, **kargs: None)
notifier.register('b', lambda *args, **kargs: None)
self.assertEqual(1, len(notifier))
def test_restricted_notifier_any(self):
notifier = nt.RestrictedNotifier(['a', 'b'])
self.assertRaises(ValueError, notifier.register,
'c', lambda *args, **kargs: None)
notifier.register('b', lambda *args, **kargs: None)
self.assertEqual(1, len(notifier))
notifier.register(nt.RestrictedNotifier.ANY,
lambda *args, **kargs: None)
self.assertEqual(2, len(notifier))
def test_restricted_notifier_no_any(self):
notifier = nt.RestrictedNotifier(['a', 'b'], allow_any=False)
self.assertRaises(ValueError, notifier.register,
nt.RestrictedNotifier.ANY,
lambda *args, **kargs: None)
notifier.register('b', lambda *args, **kargs: None)
self.assertEqual(1, len(notifier))
def test_selective_notify(self): def test_selective_notify(self):
call_counts = collections.defaultdict(list) call_counts = collections.defaultdict(list)

View File

@ -39,7 +39,12 @@ class ProgressTask(task.Task):
class ProgressTaskWithDetails(task.Task): class ProgressTaskWithDetails(task.Task):
def execute(self): def execute(self):
self.update_progress(0.5, test='test data', foo='bar') details = {
'progress': 0.5,
'test': 'test data',
'foo': 'bar',
}
self.notifier.notify(task.EVENT_UPDATE_PROGRESS, details)
class TestProgress(test.TestCase): class TestProgress(test.TestCase):
@ -60,12 +65,12 @@ class TestProgress(test.TestCase):
def test_sanity_progress(self): def test_sanity_progress(self):
fired_events = [] fired_events = []
def notify_me(task, event_data, progress): def notify_me(event_type, details):
fired_events.append(progress) fired_events.append(details.pop('progress'))
ev_count = 5 ev_count = 5
t = ProgressTask("test", ev_count) t = ProgressTask("test", ev_count)
t.bind('update_progress', notify_me) t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me)
flo = lf.Flow("test") flo = lf.Flow("test")
flo.add(t) flo.add(t)
e = self._make_engine(flo) e = self._make_engine(flo)
@ -77,11 +82,11 @@ class TestProgress(test.TestCase):
def test_no_segments_progress(self): def test_no_segments_progress(self):
fired_events = [] fired_events = []
def notify_me(task, event_data, progress): def notify_me(event_type, details):
fired_events.append(progress) fired_events.append(details.pop('progress'))
t = ProgressTask("test", 0) t = ProgressTask("test", 0)
t.bind('update_progress', notify_me) t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me)
flo = lf.Flow("test") flo = lf.Flow("test")
flo.add(t) flo.add(t)
e = self._make_engine(flo) e = self._make_engine(flo)
@ -121,12 +126,12 @@ class TestProgress(test.TestCase):
def test_dual_storage_progress(self): def test_dual_storage_progress(self):
fired_events = [] fired_events = []
def notify_me(task, event_data, progress): def notify_me(event_type, details):
fired_events.append(progress) fired_events.append(details.pop('progress'))
with contextlib.closing(impl_memory.MemoryBackend({})) as be: with contextlib.closing(impl_memory.MemoryBackend({})) as be:
t = ProgressTask("test", 5) t = ProgressTask("test", 5)
t.bind('update_progress', notify_me) t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me)
flo = lf.Flow("test") flo = lf.Flow("test")
flo.add(t) flo.add(t)
b, fd = p_utils.temporary_flow_detail(be) b, fd = p_utils.temporary_flow_detail(be)

View File

@ -17,7 +17,7 @@
from taskflow import task from taskflow import task
from taskflow import test from taskflow import test
from taskflow.test import mock from taskflow.test import mock
from taskflow.utils import reflection from taskflow.types import notifier
class MyTask(task.Task): class MyTask(task.Task):
@ -198,11 +198,11 @@ class TaskTest(test.TestCase):
values = [0.0, 0.5, 1.0] values = [0.0, 0.5, 1.0]
result = [] result = []
def progress_callback(task, event_data, progress): def progress_callback(event_type, details):
result.append(progress) result.append(details.pop('progress'))
a_task = ProgressTask() a_task = ProgressTask()
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
a_task.execute(values) a_task.execute(values)
self.assertEqual(result, values) self.assertEqual(result, values)
@ -210,11 +210,11 @@ class TaskTest(test.TestCase):
def test_update_progress_lower_bound(self, mocked_warn): def test_update_progress_lower_bound(self, mocked_warn):
result = [] result = []
def progress_callback(task, event_data, progress): def progress_callback(event_type, details):
result.append(progress) result.append(details.pop('progress'))
a_task = ProgressTask() a_task = ProgressTask()
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
a_task.execute([-1.0, -0.5, 0.0]) a_task.execute([-1.0, -0.5, 0.0])
self.assertEqual(result, [0.0, 0.0, 0.0]) self.assertEqual(result, [0.0, 0.0, 0.0])
self.assertEqual(mocked_warn.call_count, 2) self.assertEqual(mocked_warn.call_count, 2)
@ -223,81 +223,87 @@ class TaskTest(test.TestCase):
def test_update_progress_upper_bound(self, mocked_warn): def test_update_progress_upper_bound(self, mocked_warn):
result = [] result = []
def progress_callback(task, event_data, progress): def progress_callback(event_type, details):
result.append(progress) result.append(details.pop('progress'))
a_task = ProgressTask() a_task = ProgressTask()
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
a_task.execute([1.0, 1.5, 2.0]) a_task.execute([1.0, 1.5, 2.0])
self.assertEqual(result, [1.0, 1.0, 1.0]) self.assertEqual(result, [1.0, 1.0, 1.0])
self.assertEqual(mocked_warn.call_count, 2) self.assertEqual(mocked_warn.call_count, 2)
@mock.patch.object(task.LOG, 'warn') @mock.patch.object(notifier.LOG, 'warn')
def test_update_progress_handler_failure(self, mocked_warn): def test_update_progress_handler_failure(self, mocked_warn):
def progress_callback(*args, **kwargs): def progress_callback(*args, **kwargs):
raise Exception('Woot!') raise Exception('Woot!')
a_task = ProgressTask() a_task = ProgressTask()
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
a_task.execute([0.5]) a_task.execute([0.5])
mocked_warn.assert_called_once_with( mocked_warn.assert_called_once()
mock.ANY, reflection.get_callable_name(progress_callback),
task.EVENT_UPDATE_PROGRESS, exc_info=mock.ANY)
def test_autobind_handler_is_none(self): def test_register_handler_is_none(self):
a_task = MyTask() a_task = MyTask()
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, None): self.assertRaises(ValueError, a_task.notifier.register,
self.assertEqual(len(list(a_task.listeners_iter())), 0) task.EVENT_UPDATE_PROGRESS, None)
self.assertEqual(len(a_task.notifier), 0)
def test_unbind_any_handler(self): def test_deregister_any_handler(self):
a_task = MyTask() a_task = MyTask()
self.assertEqual(len(list(a_task.listeners_iter())), 0) self.assertEqual(len(a_task.notifier), 0)
a_task.bind(task.EVENT_UPDATE_PROGRESS, lambda: None) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS,
self.assertEqual(len(list(a_task.listeners_iter())), 1) lambda event_type, details: None)
self.assertTrue(a_task.unbind(task.EVENT_UPDATE_PROGRESS)) self.assertEqual(len(a_task.notifier), 1)
self.assertEqual(len(list(a_task.listeners_iter())), 0) a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS)
self.assertEqual(len(a_task.notifier), 0)
def test_unbind_any_handler_empty_listeners(self): def test_deregister_any_handler_empty_listeners(self):
a_task = MyTask() a_task = MyTask()
self.assertEqual(len(list(a_task.listeners_iter())), 0) self.assertEqual(len(a_task.notifier), 0)
self.assertFalse(a_task.unbind(task.EVENT_UPDATE_PROGRESS)) self.assertFalse(a_task.notifier.deregister_event(
self.assertEqual(len(list(a_task.listeners_iter())), 0) task.EVENT_UPDATE_PROGRESS))
self.assertEqual(len(a_task.notifier), 0)
def test_unbind_non_existent_listener(self): def test_deregister_non_existent_listener(self):
handler1 = lambda: None handler1 = lambda event_type, details: None
handler2 = lambda: None handler2 = lambda event_type, details: None
a_task = MyTask() a_task = MyTask()
a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
self.assertEqual(len(list(a_task.listeners_iter())), 1) self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1)
self.assertFalse(a_task.unbind(task.EVENT_UPDATE_PROGRESS, handler2)) a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler2)
self.assertEqual(len(list(a_task.listeners_iter())), 1) self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1)
a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler1)
self.assertEqual(len(list(a_task.notifier.listeners_iter())), 0)
def test_bind_not_callable(self): def test_bind_not_callable(self):
task = MyTask() a_task = MyTask()
self.assertRaises(ValueError, task.bind, 'update_progress', 2) self.assertRaises(ValueError, a_task.notifier.register,
task.EVENT_UPDATE_PROGRESS, 2)
def test_copy_no_listeners(self): def test_copy_no_listeners(self):
handler1 = lambda: None handler1 = lambda event_type, details: None
a_task = MyTask() a_task = MyTask()
a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
b_task = a_task.copy(retain_listeners=False) b_task = a_task.copy(retain_listeners=False)
self.assertEqual(len(list(a_task.listeners_iter())), 1) self.assertEqual(len(a_task.notifier), 1)
self.assertEqual(len(list(b_task.listeners_iter())), 0) self.assertEqual(len(b_task.notifier), 0)
def test_copy_listeners(self): def test_copy_listeners(self):
handler1 = lambda: None handler1 = lambda event_type, details: None
handler2 = lambda: None handler2 = lambda event_type, details: None
a_task = MyTask() a_task = MyTask()
a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
b_task = a_task.copy() b_task = a_task.copy()
self.assertEqual(len(list(b_task.listeners_iter())), 1) self.assertEqual(len(b_task.notifier), 1)
self.assertTrue(a_task.unbind(task.EVENT_UPDATE_PROGRESS)) self.assertTrue(a_task.notifier.deregister_event(
self.assertEqual(len(list(a_task.listeners_iter())), 0) task.EVENT_UPDATE_PROGRESS))
self.assertEqual(len(list(b_task.listeners_iter())), 1) self.assertEqual(len(a_task.notifier), 0)
b_task.bind(task.EVENT_UPDATE_PROGRESS, handler2) self.assertEqual(len(b_task.notifier), 1)
listeners = dict(list(b_task.listeners_iter())) b_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler2)
listeners = dict(list(b_task.notifier.listeners_iter()))
self.assertEqual(len(listeners[task.EVENT_UPDATE_PROGRESS]), 2) self.assertEqual(len(listeners[task.EVENT_UPDATE_PROGRESS]), 2)
self.assertEqual(len(list(a_task.listeners_iter())), 0) self.assertEqual(len(a_task.notifier), 0)
class FunctorTaskTest(test.TestCase): class FunctorTaskTest(test.TestCase):

View File

@ -42,14 +42,14 @@ class TestEndpoint(test.TestCase):
self.task_result = 1 self.task_result = 1
def test_creation(self): def test_creation(self):
task = self.task_ep._get_task() task = self.task_ep.generate()
self.assertEqual(self.task_ep.name, self.task_cls_name) self.assertEqual(self.task_ep.name, self.task_cls_name)
self.assertIsInstance(task, self.task_cls) self.assertIsInstance(task, self.task_cls)
self.assertEqual(task.name, self.task_cls_name) self.assertEqual(task.name, self.task_cls_name)
def test_creation_with_task_name(self): def test_creation_with_task_name(self):
task_name = 'test' task_name = 'test'
task = self.task_ep._get_task(name=task_name) task = self.task_ep.generate(name=task_name)
self.assertEqual(self.task_ep.name, self.task_cls_name) self.assertEqual(self.task_ep.name, self.task_cls_name)
self.assertIsInstance(task, self.task_cls) self.assertIsInstance(task, self.task_cls)
self.assertEqual(task.name, task_name) self.assertEqual(task.name, task_name)
@ -58,20 +58,22 @@ class TestEndpoint(test.TestCase):
# NOTE(skudriashev): Exception is expected here since task # NOTE(skudriashev): Exception is expected here since task
# is created without any arguments passing to its constructor. # is created without any arguments passing to its constructor.
endpoint = ep.Endpoint(Task) endpoint = ep.Endpoint(Task)
self.assertRaises(TypeError, endpoint._get_task) self.assertRaises(TypeError, endpoint.generate)
def test_to_str(self): def test_to_str(self):
self.assertEqual(str(self.task_ep), self.task_cls_name) self.assertEqual(str(self.task_ep), self.task_cls_name)
def test_execute(self): def test_execute(self):
result = self.task_ep.execute(task_name=self.task_cls_name, task = self.task_ep.generate(self.task_cls_name)
result = self.task_ep.execute(task,
task_uuid=self.task_uuid, task_uuid=self.task_uuid,
arguments=self.task_args, arguments=self.task_args,
progress_callback=None) progress_callback=None)
self.assertEqual(result, self.task_result) self.assertEqual(result, self.task_result)
def test_revert(self): def test_revert(self):
result = self.task_ep.revert(task_name=self.task_cls_name, task = self.task_ep.generate(self.task_cls_name)
result = self.task_ep.revert(task,
task_uuid=self.task_uuid, task_uuid=self.task_uuid,
arguments=self.task_args, arguments=self.task_args,
progress_callback=None, progress_callback=None,

View File

@ -21,6 +21,7 @@ from oslo.utils import timeutils
from taskflow.engines.worker_based import executor from taskflow.engines.worker_based import executor
from taskflow.engines.worker_based import protocol as pr from taskflow.engines.worker_based import protocol as pr
from taskflow import task as task_atom
from taskflow import test from taskflow import test
from taskflow.test import mock from taskflow.test import mock
from taskflow.tests import utils as test_utils from taskflow.tests import utils as test_utils
@ -102,13 +103,18 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
def test_on_message_response_state_progress(self): def test_on_message_response_state_progress(self):
response = pr.Response(pr.PROGRESS, progress=1.0) response = pr.Response(pr.EVENT,
event_type=task_atom.EVENT_UPDATE_PROGRESS,
details={'progress': 1.0})
ex = self.executor() ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock) ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual(self.request_inst_mock.mock_calls, expected_calls = [
[mock.call.on_progress(progress=1.0)]) mock.call.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS,
{'progress': 1.0}),
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
def test_on_message_response_state_failure(self): def test_on_message_response_state_failure(self):
a_failure = failure.Failure.from_exception(Exception('test')) a_failure = failure.Failure.from_exception(Exception('test'))
@ -211,7 +217,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [ expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute', mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout), self.task_args, self.timeout),
mock.call.request.transition_and_log_error(pr.PENDING, mock.call.request.transition_and_log_error(pr.PENDING,
logger=mock.ANY), logger=mock.ANY),
mock.call.proxy.publish(self.request_inst_mock, mock.call.proxy.publish(self.request_inst_mock,
@ -231,7 +237,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [ expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'revert', mock.call.Request(self.task, self.task_uuid, 'revert',
self.task_args, None, self.timeout, self.task_args, self.timeout,
failures=self.task_failures, failures=self.task_failures,
result=self.task_result), result=self.task_result),
mock.call.request.transition_and_log_error(pr.PENDING, mock.call.request.transition_and_log_error(pr.PENDING,
@ -250,7 +256,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [ expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute', mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout) self.task_args, self.timeout),
] ]
self.assertEqual(self.master_mock.mock_calls, expected_calls) self.assertEqual(self.master_mock.mock_calls, expected_calls)
@ -264,7 +270,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [ expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute', mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout), self.task_args, self.timeout),
mock.call.request.transition_and_log_error(pr.PENDING, mock.call.request.transition_and_log_error(pr.PENDING,
logger=mock.ANY), logger=mock.ANY),
mock.call.proxy.publish(self.request_inst_mock, mock.call.proxy.publish(self.request_inst_mock,

View File

@ -118,7 +118,7 @@ class TestMessagePump(test.TestCase):
else: else:
p.publish(pr.Request(test_utils.DummyTask("dummy_%s" % i), p.publish(pr.Request(test_utils.DummyTask("dummy_%s" % i),
uuidutils.generate_uuid(), uuidutils.generate_uuid(),
pr.EXECUTE, [], None, None), TEST_TOPIC) pr.EXECUTE, [], None), TEST_TOPIC)
self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT)) self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
self.assertEqual(0, barrier.needed) self.assertEqual(0, barrier.needed)

View File

@ -22,7 +22,6 @@ from taskflow.engines.worker_based import protocol as pr
from taskflow import exceptions as excp from taskflow import exceptions as excp
from taskflow.openstack.common import uuidutils from taskflow.openstack.common import uuidutils
from taskflow import test from taskflow import test
from taskflow.test import mock
from taskflow.tests import utils from taskflow.tests import utils
from taskflow.types import failure from taskflow.types import failure
@ -53,7 +52,7 @@ class TestProtocolValidation(test.TestCase):
def test_request(self): def test_request(self):
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(), msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
pr.EXECUTE, {}, None, 1.0) pr.EXECUTE, {}, 1.0)
pr.Request.validate(msg.to_dict()) pr.Request.validate(msg.to_dict())
def test_request_invalid(self): def test_request_invalid(self):
@ -66,13 +65,14 @@ class TestProtocolValidation(test.TestCase):
def test_request_invalid_action(self): def test_request_invalid_action(self):
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(), msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
pr.EXECUTE, {}, None, 1.0) pr.EXECUTE, {}, 1.0)
msg = msg.to_dict() msg = msg.to_dict()
msg['action'] = 'NOTHING' msg['action'] = 'NOTHING'
self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg) self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg)
def test_response_progress(self): def test_response_progress(self):
msg = pr.Response(pr.PROGRESS, progress=0.5, event_data={}) msg = pr.Response(pr.EVENT, details={'progress': 0.5},
event_type='blah')
pr.Response.validate(msg.to_dict()) pr.Response.validate(msg.to_dict())
def test_response_completion(self): def test_response_completion(self):
@ -80,7 +80,9 @@ class TestProtocolValidation(test.TestCase):
pr.Response.validate(msg.to_dict()) pr.Response.validate(msg.to_dict())
def test_response_mixed_invalid(self): def test_response_mixed_invalid(self):
msg = pr.Response(pr.PROGRESS, progress=0.5, event_data={}, result=1) msg = pr.Response(pr.EVENT,
details={'progress': 0.5},
event_type='blah', result=1)
self.assertRaises(excp.InvalidFormat, pr.Response.validate, msg) self.assertRaises(excp.InvalidFormat, pr.Response.validate, msg)
def test_response_bad_state(self): def test_response_bad_state(self):
@ -184,16 +186,3 @@ class TestProtocol(test.TestCase):
request.set_result(111) request.set_result(111)
result = request.result.result() result = request.result.result()
self.assertEqual(result, (executor.EXECUTED, 111)) self.assertEqual(result, (executor.EXECUTED, 111))
def test_on_progress(self):
progress_callback = mock.MagicMock(name='progress_callback')
request = self.request(task=self.task,
progress_callback=progress_callback)
request.on_progress('event_data', 0.0)
request.on_progress('event_data', 1.0)
expected_calls = [
mock.call(self.task, 'event_data', 0.0),
mock.call(self.task, 'event_data', 1.0)
]
self.assertEqual(progress_callback.mock_calls, expected_calls)

View File

@ -19,6 +19,7 @@ import six
from taskflow.engines.worker_based import endpoint as ep from taskflow.engines.worker_based import endpoint as ep
from taskflow.engines.worker_based import protocol as pr from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import server from taskflow.engines.worker_based import server
from taskflow import task as task_atom
from taskflow import test from taskflow import test
from taskflow.test import mock from taskflow.test import mock
from taskflow.tests import utils from taskflow.tests import utils
@ -103,45 +104,41 @@ class TestServer(test.MockTestCase):
def test_parse_request(self): def test_parse_request(self):
request = self.make_request() request = self.make_request()
task_cls, action, task_args = server.Server._parse_request(**request) bundle = server.Server._parse_request(**request)
task_cls, task_name, action, task_args = bundle
self.assertEqual((task_cls, action, task_args), self.assertEqual((task_cls, task_name, action, task_args),
(self.task.name, self.task_action, (self.task.name, self.task.name, self.task_action,
dict(task_name=self.task.name, dict(arguments=self.task_args)))
arguments=self.task_args)))
def test_parse_request_with_success_result(self): def test_parse_request_with_success_result(self):
request = self.make_request(action='revert', result=1) request = self.make_request(action='revert', result=1)
task_cls, action, task_args = server.Server._parse_request(**request) bundle = server.Server._parse_request(**request)
task_cls, task_name, action, task_args = bundle
self.assertEqual((task_cls, action, task_args), self.assertEqual((task_cls, task_name, action, task_args),
(self.task.name, 'revert', (self.task.name, self.task.name, 'revert',
dict(task_name=self.task.name, dict(arguments=self.task_args,
arguments=self.task_args,
result=1))) result=1)))
def test_parse_request_with_failure_result(self): def test_parse_request_with_failure_result(self):
a_failure = failure.Failure.from_exception(Exception('test')) a_failure = failure.Failure.from_exception(Exception('test'))
request = self.make_request(action='revert', result=a_failure) request = self.make_request(action='revert', result=a_failure)
task_cls, action, task_args = server.Server._parse_request(**request) bundle = server.Server._parse_request(**request)
task_cls, task_name, action, task_args = bundle
self.assertEqual((task_cls, action, task_args), self.assertEqual((task_cls, task_name, action, task_args),
(self.task.name, 'revert', (self.task.name, self.task.name, 'revert',
dict(task_name=self.task.name, dict(arguments=self.task_args,
arguments=self.task_args,
result=utils.FailureMatcher(a_failure)))) result=utils.FailureMatcher(a_failure))))
def test_parse_request_with_failures(self): def test_parse_request_with_failures(self):
failures = {'0': failure.Failure.from_exception(Exception('test1')), failures = {'0': failure.Failure.from_exception(Exception('test1')),
'1': failure.Failure.from_exception(Exception('test2'))} '1': failure.Failure.from_exception(Exception('test2'))}
request = self.make_request(action='revert', failures=failures) request = self.make_request(action='revert', failures=failures)
task_cls, action, task_args = server.Server._parse_request(**request) bundle = server.Server._parse_request(**request)
task_cls, task_name, action, task_args = bundle
self.assertEqual( self.assertEqual(
(task_cls, action, task_args), (task_cls, task_name, action, task_args),
(self.task.name, 'revert', (self.task.name, self.task.name, 'revert',
dict(task_name=self.task.name, dict(arguments=self.task_args,
arguments=self.task_args,
failures=dict((i, utils.FailureMatcher(f)) failures=dict((i, utils.FailureMatcher(f))
for i, f in six.iteritems(failures))))) for i, f in six.iteritems(failures)))))
@ -182,17 +179,19 @@ class TestServer(test.MockTestCase):
mock.call.Response(pr.RUNNING), mock.call.Response(pr.RUNNING),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to, mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid), correlation_id=self.task_uuid),
mock.call.Response(pr.PROGRESS, progress=0.0, event_data={}), mock.call.Response(pr.EVENT, details={'progress': 0.0},
event_type=task_atom.EVENT_UPDATE_PROGRESS),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to, mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid), correlation_id=self.task_uuid),
mock.call.Response(pr.PROGRESS, progress=1.0, event_data={}), mock.call.Response(pr.EVENT, details={'progress': 1.0},
event_type=task_atom.EVENT_UPDATE_PROGRESS),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to, mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid), correlation_id=self.task_uuid),
mock.call.Response(pr.SUCCESS, result=5), mock.call.Response(pr.SUCCESS, result=5),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to, mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid) correlation_id=self.task_uuid)
] ]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls) self.assertEqual(master_mock_calls, self.master_mock.mock_calls)
def test_process_request(self): def test_process_request(self):
# create server and process request # create server and process request

View File

@ -15,6 +15,7 @@
# under the License. # under the License.
import collections import collections
import copy
import logging import logging
import six import six
@ -39,6 +40,14 @@ class _Listener(object):
else: else:
self._kwargs = kwargs.copy() self._kwargs = kwargs.copy()
@property
def kwargs(self):
return self._kwargs
@property
def args(self):
return self._args
def __call__(self, event_type, details): def __call__(self, event_type, details):
if self._details_filter is not None: if self._details_filter is not None:
if not self._details_filter(details): if not self._details_filter(details):
@ -117,17 +126,18 @@ class Notifier(object):
event type will be called. event type will be called.
:param event_type: event type that occurred :param event_type: event type that occurred
:param details: addition event details :param details: additional event details *dictionary* passed to
callback keyword argument with the same name.
""" """
listeners = list(self._listeners.get(self.ANY, [])) listeners = list(self._listeners.get(self.ANY, []))
for listener in self._listeners[event_type]: listeners.extend(self._listeners.get(event_type, []))
if listener not in listeners:
listeners.append(listener)
if not listeners: if not listeners:
return return
if not details:
details = {}
for listener in listeners: for listener in listeners:
try: try:
listener(event_type, details) listener(event_type, details.copy())
except Exception: except Exception:
LOG.warn("Failure calling listener %s to notify about event" LOG.warn("Failure calling listener %s to notify about event"
" %s, details: %s", listener, event_type, " %s, details: %s", listener, event_type,
@ -150,6 +160,9 @@ class Notifier(object):
if details_filter is not None: if details_filter is not None:
if not six.callable(details_filter): if not six.callable(details_filter):
raise ValueError("Details filter must be callable") raise ValueError("Details filter must be callable")
if not self.can_be_registered(event_type):
raise ValueError("Disallowed event type '%s' can not have a"
" callback registered" % event_type)
if self.is_registered(event_type, callback, if self.is_registered(event_type, callback,
details_filter=details_filter): details_filter=details_filter):
raise ValueError("Event callback already registered with" raise ValueError("Event callback already registered with"
@ -165,10 +178,61 @@ class Notifier(object):
details_filter=details_filter)) details_filter=details_filter))
def deregister(self, event_type, callback, details_filter=None): def deregister(self, event_type, callback, details_filter=None):
"""Remove a single callback from listening to event ``event_type``.""" """Remove a single listener bound to event ``event_type``."""
if event_type not in self._listeners: if event_type not in self._listeners:
return return False
for i, listener in enumerate(self._listeners[event_type]): for i, listener in enumerate(self._listeners.get(event_type, [])):
if listener.is_equivalent(callback, details_filter=details_filter): if listener.is_equivalent(callback, details_filter=details_filter):
self._listeners[event_type].pop(i) self._listeners[event_type].pop(i)
break return True
return False
def deregister_event(self, event_type):
"""Remove a group of listeners bound to event ``event_type``."""
return len(self._listeners.pop(event_type, []))
def copy(self):
c = copy.copy(self)
c._listeners = collections.defaultdict(list)
for event_type, listeners in six.iteritems(self._listeners):
c._listeners[event_type] = listeners[:]
return c
def listeners_iter(self):
"""Return an iterator over the mapping of event => listeners bound."""
for event_type, listeners in six.iteritems(self._listeners):
if listeners:
yield (event_type, listeners)
def can_be_registered(self, event_type):
"""Checks if the event can be registered/subscribed to."""
return True
class RestrictedNotifier(Notifier):
"""A notification class that restricts events registered/triggered.
NOTE(harlowja): This class unlike :class:`.Notifier` restricts and
disallows registering callbacks for event types that are not declared
when constructing the notifier.
"""
def __init__(self, watchable_events, allow_any=True):
super(RestrictedNotifier, self).__init__()
self._watchable_events = frozenset(watchable_events)
self._allow_any = allow_any
def events_iter(self):
"""Returns iterator of events that can be registered/subscribed to.
NOTE(harlowja): does not include back the ``ANY`` event type as that
meta-type is not a specific event but is a capture-all that does not
imply the same meaning as specific event types.
"""
for event_type in self._watchable_events:
yield event_type
def can_be_registered(self, event_type):
"""Checks if the event can be registered/subscribed to."""
return (event_type in self._watchable_events or
(event_type == self.ANY and self._allow_any))