Use the notifier type in the task class/module directly
Instead of having code that is some what like the notifier code we already have, but is duplicated and is slightly different in the task class just move the code that was in the task class (and doing similar actions) to instead now use a notifier that is directly contained in the task base class for internal task triggering of internal task events. Breaking change: alters the capabilities of the task to process notifications itself, most actions now must go through the task notifier property and instead use that (update_progress still exists as a common utility method, since it's likely the most common type of notification that will be used). Removes the following methods from task base class (as they are no longer needed with a notifier attribute): - trigger (replaced with notifier.notify) - autobind (removed, not replaced, can be created by the user of taskflow in a simple manner, without requiring functionality in taskflow) - bind (replaced with notifier.register) - unbind (replaced with notifier.unregister) - listeners_iter (replaced with notifier.listeners_iter) Due to this change we can now also correctly proxy back events from remote tasks to the engine for correct proxying back to the original task. Fixes bug 1370766 Change-Id: Ic9dfef516d72e6e32e71dda30a1cb3522c9e0be6
This commit is contained in:
parent
cdfd8ece61
commit
1f4dd72e6e
@ -158,10 +158,11 @@ engine executor in the following manner:
|
||||
from dicts after receiving on both executor & worker sides (this
|
||||
translation is lossy since the traceback won't be fully retained).
|
||||
|
||||
Executor request format
|
||||
Executor execute format
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
* **task** - full task name to be performed
|
||||
* **task_name** - full task name to be performed
|
||||
* **task_cls** - full task class name to be performed
|
||||
* **action** - task action to be performed (e.g. execute, revert)
|
||||
* **arguments** - arguments the task action to be called with
|
||||
* **result** - task execution result (result or
|
||||
@ -180,9 +181,14 @@ Additionally, the following parameters are added to the request message:
|
||||
{
|
||||
"action": "execute",
|
||||
"arguments": {
|
||||
"joe_number": 444
|
||||
"x": 111
|
||||
},
|
||||
"task": "tasks.CallJoe"
|
||||
"task_cls": "taskflow.tests.utils.TaskOneArgOneReturn",
|
||||
"task_name": "taskflow.tests.utils.TaskOneArgOneReturn",
|
||||
"task_version": [
|
||||
1,
|
||||
0
|
||||
]
|
||||
}
|
||||
|
||||
Worker response format
|
||||
@ -193,7 +199,8 @@ When **running:**
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"status": "RUNNING"
|
||||
"data": {},
|
||||
"state": "RUNNING"
|
||||
}
|
||||
|
||||
When **progressing:**
|
||||
@ -201,9 +208,11 @@ When **progressing:**
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"event_data": <event_data>,
|
||||
"progress": <progress>,
|
||||
"state": "PROGRESS"
|
||||
"details": {
|
||||
"progress": 0.5
|
||||
},
|
||||
"event_type": "update_progress",
|
||||
"state": "EVENT"
|
||||
}
|
||||
|
||||
When **succeeded:**
|
||||
@ -211,8 +220,9 @@ When **succeeded:**
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"event": <event>,
|
||||
"result": <result>,
|
||||
"data": {
|
||||
"result": 666
|
||||
},
|
||||
"state": "SUCCESS"
|
||||
}
|
||||
|
||||
@ -221,11 +231,64 @@ When **failed:**
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"event": <event>,
|
||||
"result": <types.failure.Failure>,
|
||||
"data": {
|
||||
"result": {
|
||||
"exc_type_names": [
|
||||
"RuntimeError",
|
||||
"StandardError",
|
||||
"Exception"
|
||||
],
|
||||
"exception_str": "Woot!",
|
||||
"traceback_str": " File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n result = task.execute(**arguments)\n File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n raise RuntimeError('Woot!')\n",
|
||||
"version": 1
|
||||
}
|
||||
},
|
||||
"state": "FAILURE"
|
||||
}
|
||||
|
||||
Executor revert format
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
When **reverting:**
|
||||
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"action": "revert",
|
||||
"arguments": {},
|
||||
"failures": {
|
||||
"taskflow.tests.utils.TaskWithFailure": {
|
||||
"exc_type_names": [
|
||||
"RuntimeError",
|
||||
"StandardError",
|
||||
"Exception"
|
||||
],
|
||||
"exception_str": "Woot!",
|
||||
"traceback_str": " File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n result = task.execute(**arguments)\n File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n raise RuntimeError('Woot!')\n",
|
||||
"version": 1
|
||||
}
|
||||
},
|
||||
"result": [
|
||||
"failure",
|
||||
{
|
||||
"exc_type_names": [
|
||||
"RuntimeError",
|
||||
"StandardError",
|
||||
"Exception"
|
||||
],
|
||||
"exception_str": "Woot!",
|
||||
"traceback_str": " File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n result = task.execute(**arguments)\n File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n raise RuntimeError('Woot!')\n",
|
||||
"version": 1
|
||||
}
|
||||
],
|
||||
"task_cls": "taskflow.tests.utils.TaskWithFailure",
|
||||
"task_name": "taskflow.tests.utils.TaskWithFailure",
|
||||
"task_version": [
|
||||
1,
|
||||
0
|
||||
]
|
||||
}
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
|
@ -14,6 +14,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import functools
|
||||
|
||||
from taskflow import logging
|
||||
from taskflow import states
|
||||
from taskflow import task as task_atom
|
||||
@ -75,25 +77,37 @@ class TaskAction(object):
|
||||
if progress is not None:
|
||||
task.update_progress(progress)
|
||||
|
||||
def _on_update_progress(self, task, event_data, progress, **kwargs):
|
||||
def _on_update_progress(self, task, event_type, details):
|
||||
"""Should be called when task updates its progress."""
|
||||
try:
|
||||
self._storage.set_task_progress(task.name, progress, kwargs)
|
||||
except Exception:
|
||||
# Update progress callbacks should never fail, so capture and log
|
||||
# the emitted exception instead of raising it.
|
||||
LOG.exception("Failed setting task progress for %s to %0.3f",
|
||||
task, progress)
|
||||
progress = details.pop('progress')
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
self._storage.set_task_progress(task.name, progress,
|
||||
details=details)
|
||||
except Exception:
|
||||
# Update progress callbacks should never fail, so capture and
|
||||
# log the emitted exception instead of raising it.
|
||||
LOG.exception("Failed setting task progress for %s to %0.3f",
|
||||
task, progress)
|
||||
|
||||
def schedule_execution(self, task):
|
||||
self.change_state(task, states.RUNNING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
|
||||
progress_callback = functools.partial(self._on_update_progress,
|
||||
task)
|
||||
else:
|
||||
progress_callback = None
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
return self._task_executor.execute_task(task, task_uuid, kwargs,
|
||||
self._on_update_progress)
|
||||
return self._task_executor.execute_task(
|
||||
task, task_uuid, arguments,
|
||||
progress_callback=progress_callback)
|
||||
|
||||
def complete_execution(self, task, result):
|
||||
if isinstance(result, failure.Failure):
|
||||
@ -105,15 +119,20 @@ class TaskAction(object):
|
||||
def schedule_reversion(self, task):
|
||||
self.change_state(task, states.REVERTING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
arguments = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
task_result = self._storage.get(task.name)
|
||||
failures = self._storage.get_failures()
|
||||
future = self._task_executor.revert_task(task, task_uuid, kwargs,
|
||||
task_result, failures,
|
||||
self._on_update_progress)
|
||||
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
|
||||
progress_callback = functools.partial(self._on_update_progress,
|
||||
task)
|
||||
else:
|
||||
progress_callback = None
|
||||
future = self._task_executor.revert_task(
|
||||
task, task_uuid, arguments, task_result, failures,
|
||||
progress_callback=progress_callback)
|
||||
return future
|
||||
|
||||
def complete_reversion(self, task, rev_result):
|
||||
|
@ -15,10 +15,11 @@
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import contextlib
|
||||
|
||||
import six
|
||||
|
||||
from taskflow import task as _task
|
||||
from taskflow import task as task_atom
|
||||
from taskflow.types import failure
|
||||
from taskflow.types import futures
|
||||
from taskflow.utils import async_utils
|
||||
@ -29,8 +30,23 @@ EXECUTED = 'executed'
|
||||
REVERTED = 'reverted'
|
||||
|
||||
|
||||
def _execute_task(task, arguments, progress_callback):
|
||||
with task.autobind(_task.EVENT_UPDATE_PROGRESS, progress_callback):
|
||||
@contextlib.contextmanager
|
||||
def _autobind(task, progress_callback=None):
|
||||
bound = False
|
||||
if progress_callback is not None:
|
||||
task.notifier.register(task_atom.EVENT_UPDATE_PROGRESS,
|
||||
progress_callback)
|
||||
bound = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if bound:
|
||||
task.notifier.deregister(task_atom.EVENT_UPDATE_PROGRESS,
|
||||
progress_callback)
|
||||
|
||||
|
||||
def _execute_task(task, arguments, progress_callback=None):
|
||||
with _autobind(task, progress_callback=progress_callback):
|
||||
try:
|
||||
task.pre_execute()
|
||||
result = task.execute(**arguments)
|
||||
@ -43,14 +59,14 @@ def _execute_task(task, arguments, progress_callback):
|
||||
return (EXECUTED, result)
|
||||
|
||||
|
||||
def _revert_task(task, arguments, result, failures, progress_callback):
|
||||
kwargs = arguments.copy()
|
||||
kwargs[_task.REVERT_RESULT] = result
|
||||
kwargs[_task.REVERT_FLOW_FAILURES] = failures
|
||||
with task.autobind(_task.EVENT_UPDATE_PROGRESS, progress_callback):
|
||||
def _revert_task(task, arguments, result, failures, progress_callback=None):
|
||||
arguments = arguments.copy()
|
||||
arguments[task_atom.REVERT_RESULT] = result
|
||||
arguments[task_atom.REVERT_FLOW_FAILURES] = failures
|
||||
with _autobind(task, progress_callback=progress_callback):
|
||||
try:
|
||||
task.pre_revert()
|
||||
result = task.revert(**kwargs)
|
||||
result = task.revert(**arguments)
|
||||
except Exception:
|
||||
# NOTE(imelnikov): wrap current exception with Failure
|
||||
# object and return it.
|
||||
@ -98,15 +114,17 @@ class SerialTaskExecutor(TaskExecutor):
|
||||
self._executor = futures.SynchronousExecutor()
|
||||
|
||||
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
|
||||
fut = self._executor.submit(_execute_task, task, arguments,
|
||||
progress_callback)
|
||||
fut = self._executor.submit(_execute_task,
|
||||
task, arguments,
|
||||
progress_callback=progress_callback)
|
||||
fut.atom = task
|
||||
return fut
|
||||
|
||||
def revert_task(self, task, task_uuid, arguments, result, failures,
|
||||
progress_callback=None):
|
||||
fut = self._executor.submit(_revert_task, task, arguments, result,
|
||||
failures, progress_callback)
|
||||
fut = self._executor.submit(_revert_task,
|
||||
task, arguments, result, failures,
|
||||
progress_callback=progress_callback)
|
||||
fut.atom = task
|
||||
return fut
|
||||
|
||||
@ -127,15 +145,17 @@ class ParallelTaskExecutor(TaskExecutor):
|
||||
self._create_executor = executor is None
|
||||
|
||||
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
|
||||
fut = self._executor.submit(_execute_task, task,
|
||||
arguments, progress_callback)
|
||||
fut = self._executor.submit(_execute_task,
|
||||
task, arguments,
|
||||
progress_callback=progress_callback)
|
||||
fut.atom = task
|
||||
return fut
|
||||
|
||||
def revert_task(self, task, task_uuid, arguments, result, failures,
|
||||
progress_callback=None):
|
||||
fut = self._executor.submit(_revert_task, task, arguments,
|
||||
result, failures, progress_callback)
|
||||
fut = self._executor.submit(_revert_task,
|
||||
task, arguments, result, failures,
|
||||
progress_callback=progress_callback)
|
||||
fut.atom = task
|
||||
return fut
|
||||
|
||||
|
@ -33,18 +33,16 @@ class Endpoint(object):
|
||||
def name(self):
|
||||
return self._task_cls_name
|
||||
|
||||
def _get_task(self, name=None):
|
||||
def generate(self, name=None):
|
||||
# NOTE(skudriashev): Note that task is created here with the `name`
|
||||
# argument passed to its constructor. This will be a problem when
|
||||
# task's constructor requires any other arguments.
|
||||
return self._task_cls(name=name)
|
||||
|
||||
def execute(self, task_name, **kwargs):
|
||||
task = self._get_task(task_name)
|
||||
def execute(self, task, **kwargs):
|
||||
event, result = self._executor.execute_task(task, **kwargs).result()
|
||||
return result
|
||||
|
||||
def revert(self, task_name, **kwargs):
|
||||
task = self._get_task(task_name)
|
||||
def revert(self, task, **kwargs):
|
||||
event, result = self._executor.revert_task(task, **kwargs).result()
|
||||
return result
|
||||
|
@ -25,6 +25,7 @@ from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import logging
|
||||
from taskflow import task as task_atom
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.utils import async_utils
|
||||
from taskflow.utils import misc
|
||||
@ -132,8 +133,11 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
response.state, request)
|
||||
if response.state == pr.RUNNING:
|
||||
request.transition_and_log_error(pr.RUNNING, logger=LOG)
|
||||
elif response.state == pr.PROGRESS:
|
||||
request.on_progress(**response.data)
|
||||
elif response.state == pr.EVENT:
|
||||
# Proxy the event + details to the task/request notifier...
|
||||
event_type = response.data['event_type']
|
||||
details = response.data['details']
|
||||
request.notifier.notify(event_type, details)
|
||||
elif response.state in (pr.FAILURE, pr.SUCCESS):
|
||||
moved = request.transition_and_log_error(response.state,
|
||||
logger=LOG)
|
||||
@ -181,8 +185,18 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
progress_callback, **kwargs):
|
||||
"""Submit task request to a worker."""
|
||||
request = pr.Request(task, task_uuid, action, arguments,
|
||||
progress_callback, self._transition_timeout,
|
||||
**kwargs)
|
||||
self._transition_timeout, **kwargs)
|
||||
|
||||
# Register the callback, so that we can proxy the progress correctly.
|
||||
if (progress_callback is not None and
|
||||
request.notifier.can_be_registered(
|
||||
task_atom.EVENT_UPDATE_PROGRESS)):
|
||||
request.notifier.register(task_atom.EVENT_UPDATE_PROGRESS,
|
||||
progress_callback)
|
||||
cleaner = functools.partial(request.notifier.deregister,
|
||||
task_atom.EVENT_UPDATE_PROGRESS,
|
||||
progress_callback)
|
||||
request.result.add_done_callback(lambda fut: cleaner())
|
||||
|
||||
# Get task's topic and publish request if topic was found.
|
||||
topic = self._workers_cache.get_topic_by_task(request.task_cls)
|
||||
|
@ -38,14 +38,14 @@ PENDING = 'PENDING'
|
||||
RUNNING = 'RUNNING'
|
||||
SUCCESS = 'SUCCESS'
|
||||
FAILURE = 'FAILURE'
|
||||
PROGRESS = 'PROGRESS'
|
||||
EVENT = 'EVENT'
|
||||
|
||||
# During these states the expiry is active (once out of these states the expiry
|
||||
# no longer matters, since we have no way of knowing how long a task will run
|
||||
# for).
|
||||
WAITING_STATES = (WAITING, PENDING)
|
||||
|
||||
_ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, PROGRESS)
|
||||
_ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, EVENT)
|
||||
_STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE)
|
||||
|
||||
# Transitions that a request state can go through.
|
||||
@ -219,22 +219,29 @@ class Request(Message):
|
||||
'required': ['task_cls', 'task_name', 'task_version', 'action'],
|
||||
}
|
||||
|
||||
def __init__(self, task, uuid, action, arguments, progress_callback,
|
||||
timeout, **kwargs):
|
||||
def __init__(self, task, uuid, action, arguments, timeout, **kwargs):
|
||||
self._task = task
|
||||
self._task_cls = reflection.get_class_name(task)
|
||||
self._uuid = uuid
|
||||
self._action = action
|
||||
self._event = ACTION_TO_EVENT[action]
|
||||
self._arguments = arguments
|
||||
self._progress_callback = progress_callback
|
||||
self._kwargs = kwargs
|
||||
self._watch = tt.StopWatch(duration=timeout).start()
|
||||
self._state = WAITING
|
||||
self._lock = threading.Lock()
|
||||
self._created_on = timeutils.utcnow()
|
||||
self.result = futures.Future()
|
||||
self.result.atom = task
|
||||
self._result = futures.Future()
|
||||
self._result.atom = task
|
||||
self._notifier = task.notifier
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return self._result
|
||||
|
||||
@property
|
||||
def notifier(self):
|
||||
return self._notifier
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
@ -293,9 +300,6 @@ class Request(Message):
|
||||
def set_result(self, result):
|
||||
self.result.set_result((self._event, result))
|
||||
|
||||
def on_progress(self, event_data, progress):
|
||||
self._progress_callback(self._task, event_data, progress)
|
||||
|
||||
def transition_and_log_error(self, new_state, logger=None):
|
||||
"""Transitions *and* logs an error if that transitioning raises.
|
||||
|
||||
@ -362,7 +366,7 @@ class Response(Message):
|
||||
'data': {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/definitions/progress",
|
||||
"$ref": "#/definitions/event",
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/completion",
|
||||
@ -376,17 +380,17 @@ class Response(Message):
|
||||
"required": ["state", 'data'],
|
||||
"additionalProperties": False,
|
||||
"definitions": {
|
||||
"progress": {
|
||||
"event": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
'progress': {
|
||||
'type': 'number',
|
||||
'event_type': {
|
||||
'type': 'string',
|
||||
},
|
||||
'event_data': {
|
||||
'details': {
|
||||
'type': 'object',
|
||||
},
|
||||
},
|
||||
"required": ["progress", 'event_data'],
|
||||
"required": ["event_type", 'details'],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
# Used when sending *only* request state changes (and no data is
|
||||
|
@ -22,6 +22,7 @@ from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow import logging
|
||||
from taskflow.types import failure as ft
|
||||
from taskflow.types import notifier as nt
|
||||
from taskflow.utils import misc
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -73,18 +74,23 @@ class Server(object):
|
||||
All `failure.Failure` objects that have been converted to dict on the
|
||||
remote side will now converted back to `failure.Failure` objects.
|
||||
"""
|
||||
action_args = dict(arguments=arguments, task_name=task_name)
|
||||
# These arguments will eventually be given to the task executor
|
||||
# so they need to be in a format it will accept (and using keyword
|
||||
# argument names that it accepts)...
|
||||
arguments = {
|
||||
'arguments': arguments,
|
||||
}
|
||||
if result is not None:
|
||||
data_type, data = result
|
||||
if data_type == 'failure':
|
||||
action_args['result'] = ft.Failure.from_dict(data)
|
||||
arguments['result'] = ft.Failure.from_dict(data)
|
||||
else:
|
||||
action_args['result'] = data
|
||||
arguments['result'] = data
|
||||
if failures is not None:
|
||||
action_args['failures'] = {}
|
||||
arguments['failures'] = {}
|
||||
for key, data in six.iteritems(failures):
|
||||
action_args['failures'][key] = ft.Failure.from_dict(data)
|
||||
return task_cls, action, action_args
|
||||
arguments['failures'][key] = ft.Failure.from_dict(data)
|
||||
return (task_cls, task_name, action, arguments)
|
||||
|
||||
@staticmethod
|
||||
def _parse_message(message):
|
||||
@ -122,14 +128,13 @@ class Server(object):
|
||||
exc_info=True)
|
||||
return published
|
||||
|
||||
def _on_update_progress(self, reply_to, task_uuid, task, event_data,
|
||||
progress):
|
||||
"""Send task update progress notification."""
|
||||
def _on_event(self, reply_to, task_uuid, event_type, details):
|
||||
"""Send out a task event notification."""
|
||||
# NOTE(harlowja): the executor that will trigger this using the
|
||||
# task notification/listener mechanism will handle logging if this
|
||||
# fails, so thats why capture is 'False' is used here.
|
||||
self._reply(False, reply_to, task_uuid, pr.PROGRESS,
|
||||
event_data=event_data, progress=progress)
|
||||
self._reply(False, reply_to, task_uuid, pr.EVENT,
|
||||
event_type=event_type, details=details)
|
||||
|
||||
def _process_notify(self, notify, message):
|
||||
"""Process notify message and reply back."""
|
||||
@ -165,18 +170,15 @@ class Server(object):
|
||||
message.delivery_tag, exc_info=True)
|
||||
return
|
||||
else:
|
||||
# prepare task progress callback
|
||||
progress_callback = functools.partial(self._on_update_progress,
|
||||
reply_to, task_uuid)
|
||||
# prepare reply callback
|
||||
reply_callback = functools.partial(self._reply, True, reply_to,
|
||||
task_uuid)
|
||||
|
||||
# parse request to get task name, action and action arguments
|
||||
try:
|
||||
task_cls, action, action_args = self._parse_request(**request)
|
||||
action_args.update(task_uuid=task_uuid,
|
||||
progress_callback=progress_callback)
|
||||
bundle = self._parse_request(**request)
|
||||
task_cls, task_name, action, arguments = bundle
|
||||
arguments['task_uuid'] = task_uuid
|
||||
except ValueError:
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.warn("Failed to parse request contents from message %r",
|
||||
@ -206,12 +208,36 @@ class Server(object):
|
||||
reply_callback(result=failure.to_dict())
|
||||
return
|
||||
else:
|
||||
if not reply_callback(state=pr.RUNNING):
|
||||
return
|
||||
try:
|
||||
task = endpoint.generate(name=task_name)
|
||||
except Exception:
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.warn("The '%s' task '%s' generation for request"
|
||||
" message %r failed", endpoint, action,
|
||||
message.delivery_tag, exc_info=True)
|
||||
reply_callback(result=failure.to_dict())
|
||||
return
|
||||
else:
|
||||
if not reply_callback(state=pr.RUNNING):
|
||||
return
|
||||
|
||||
# perform task action
|
||||
# associate *any* events this task emits with a proxy that will
|
||||
# emit them back to the engine... for handling at the engine side
|
||||
# of things...
|
||||
if task.notifier.can_be_registered(nt.Notifier.ANY):
|
||||
task.notifier.register(nt.Notifier.ANY,
|
||||
functools.partial(self._on_event,
|
||||
reply_to, task_uuid))
|
||||
elif isinstance(task.notifier, nt.RestrictedNotifier):
|
||||
# only proxy the allowable events then...
|
||||
for event_type in task.notifier.events_iter():
|
||||
task.notifier.register(event_type,
|
||||
functools.partial(self._on_event,
|
||||
reply_to, task_uuid))
|
||||
|
||||
# perform the task action
|
||||
try:
|
||||
result = handler(**action_args)
|
||||
result = handler(task, **arguments)
|
||||
except Exception:
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.warn("The '%s' endpoint '%s' execution for request"
|
||||
|
166
taskflow/task.py
166
taskflow/task.py
@ -16,14 +16,13 @@
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
|
||||
import six
|
||||
|
||||
from taskflow import atom
|
||||
from taskflow import logging
|
||||
from taskflow.types import notifier
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import reflection
|
||||
|
||||
@ -51,18 +50,28 @@ class BaseTask(atom.Atom):
|
||||
same piece of work.
|
||||
"""
|
||||
|
||||
# Known events this task can have callbacks bound to (others that are not
|
||||
# in this set/tuple will not be able to be bound); this should be updated
|
||||
# and/or extended in subclasses as needed to enable or disable new or
|
||||
# existing events...
|
||||
# Known internal events this task can have callbacks bound to (others that
|
||||
# are not in this set/tuple will not be able to be bound); this should be
|
||||
# updated and/or extended in subclasses as needed to enable or disable new
|
||||
# or existing internal events...
|
||||
TASK_EVENTS = (EVENT_UPDATE_PROGRESS,)
|
||||
|
||||
def __init__(self, name, provides=None, inject=None):
|
||||
if name is None:
|
||||
name = reflection.get_class_name(self)
|
||||
super(BaseTask, self).__init__(name, provides, inject=inject)
|
||||
# Map of events => lists of callbacks to invoke on task events.
|
||||
self._events_listeners = collections.defaultdict(list)
|
||||
self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS)
|
||||
|
||||
@property
|
||||
def notifier(self):
|
||||
"""Internal notification dispatcher/registry.
|
||||
|
||||
A notification object that will dispatch events that occur related
|
||||
to *internal* notifications that the task internally emits to
|
||||
listeners (for example for progress status updates, telling others
|
||||
that a task has reached 50% completion...).
|
||||
"""
|
||||
return self._notifier
|
||||
|
||||
def pre_execute(self):
|
||||
"""Code to be run prior to executing the task.
|
||||
@ -138,152 +147,31 @@ class BaseTask(atom.Atom):
|
||||
def copy(self, retain_listeners=True):
|
||||
"""Clone/copy this task.
|
||||
|
||||
:param retain_listeners: retain the attached listeners when cloning,
|
||||
when false the listeners will be emptied, when
|
||||
true the listeners will be copied and retained
|
||||
:param retain_listeners: retain the attached notification listeners
|
||||
when cloning, when false the listeners will
|
||||
be emptied, when true the listeners will be
|
||||
copied and retained
|
||||
|
||||
:return: the copied task
|
||||
"""
|
||||
c = copy.copy(self)
|
||||
c._events_listeners = collections.defaultdict(list)
|
||||
if retain_listeners:
|
||||
for event_name, listeners in six.iteritems(self._events_listeners):
|
||||
c._events_listeners[event_name] = listeners[:]
|
||||
c._notifier = self._notifier.copy()
|
||||
if not retain_listeners:
|
||||
c._notifier.reset()
|
||||
return c
|
||||
|
||||
def update_progress(self, progress, **kwargs):
|
||||
def update_progress(self, progress):
|
||||
"""Update task progress and notify all registered listeners.
|
||||
|
||||
:param progress: task progress float value between 0.0 and 1.0
|
||||
:param kwargs: any keyword arguments that are tied to the specific
|
||||
progress value.
|
||||
"""
|
||||
def on_clamped():
|
||||
LOG.warn("Progress value must be greater or equal to 0.0 or less"
|
||||
" than or equal to 1.0 instead of being '%s'", progress)
|
||||
cleaned_progress = misc.clamp(progress, 0.0, 1.0,
|
||||
on_clamped=on_clamped)
|
||||
self.trigger(EVENT_UPDATE_PROGRESS, cleaned_progress, **kwargs)
|
||||
|
||||
def trigger(self, event_name, *args, **kwargs):
|
||||
"""Execute all callbacks registered for the given event type.
|
||||
|
||||
NOTE(harlowja): if a bound callback raises an exception it will be
|
||||
logged (at a ``WARNING`` level) and the exception
|
||||
will be dropped.
|
||||
|
||||
:param event_name: event name to trigger
|
||||
:param args: arbitrary positional arguments passed to the triggered
|
||||
callbacks (if any are matched), these will be in addition
|
||||
to any ``kwargs`` provided on binding (these are passed
|
||||
as positional arguments to the callback).
|
||||
:param kwargs: arbitrary keyword arguments passed to the triggered
|
||||
callbacks (if any are matched), these will be in addition
|
||||
to any ``kwargs`` provided on binding (these are passed
|
||||
as keyword arguments to the callback).
|
||||
"""
|
||||
for (cb, event_data) in self._events_listeners.get(event_name, []):
|
||||
try:
|
||||
cb(self, event_data, *args, **kwargs)
|
||||
except Exception:
|
||||
LOG.warn("Failed calling callback `%s` on event '%s'",
|
||||
reflection.get_callable_name(cb), event_name,
|
||||
exc_info=True)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def autobind(self, event_name, callback, **kwargs):
|
||||
"""Binds & unbinds a given callback to the task.
|
||||
|
||||
This function binds and unbinds using the context manager protocol.
|
||||
When events are triggered on the task of the given event name this
|
||||
callback will automatically be called with the provided
|
||||
keyword arguments as the first argument (further arguments may be
|
||||
provided by the entity triggering the event).
|
||||
|
||||
The arguments are interpreted as for :func:`bind() <bind>`.
|
||||
"""
|
||||
bound = False
|
||||
if callback is not None:
|
||||
try:
|
||||
self.bind(event_name, callback, **kwargs)
|
||||
bound = True
|
||||
except ValueError:
|
||||
LOG.warn("Failed binding callback `%s` as a receiver of"
|
||||
" event '%s' notifications emitted from task '%s'",
|
||||
reflection.get_callable_name(callback), event_name,
|
||||
self, exc_info=True)
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
if bound:
|
||||
self.unbind(event_name, callback)
|
||||
|
||||
def bind(self, event_name, callback, **kwargs):
|
||||
"""Attach a callback to be triggered on a task event.
|
||||
|
||||
Callbacks should *not* be bound, modified, or removed after execution
|
||||
has commenced (they may be adjusted after execution has finished). This
|
||||
is primarily due to the need to preserve the callbacks that exist at
|
||||
execution time for engines which run tasks remotely or out of
|
||||
process (so that those engines can correctly proxy back transmitted
|
||||
events).
|
||||
|
||||
Callbacks should also be *quick* to execute so that the engine calling
|
||||
them can continue execution in a timely manner (if long running
|
||||
callbacks need to exist, consider creating a separate pool + queue
|
||||
for those that the attached callbacks put long running operations into
|
||||
for execution by other entities).
|
||||
|
||||
:param event_name: event type name
|
||||
:param callback: callable to execute each time event is triggered
|
||||
:param kwargs: optional named parameters that will be passed to the
|
||||
callable object as a dictionary to the callbacks
|
||||
*second* positional parameter.
|
||||
:raises ValueError: if invalid event type, or callback is passed
|
||||
"""
|
||||
if event_name not in self.TASK_EVENTS:
|
||||
raise ValueError("Unknown task event '%s', can only bind"
|
||||
" to events %s" % (event_name, self.TASK_EVENTS))
|
||||
if callback is not None:
|
||||
if not six.callable(callback):
|
||||
raise ValueError("Event handler callback must be callable")
|
||||
self._events_listeners[event_name].append((callback, kwargs))
|
||||
|
||||
def unbind(self, event_name, callback=None):
|
||||
"""Remove a previously-attached event callback from the task.
|
||||
|
||||
If a callback is not passed, then this will unbind *all* event
|
||||
callbacks for the provided event. If multiple of the same callbacks
|
||||
are bound, then the first match is removed (and only the first match).
|
||||
|
||||
:param event_name: event type
|
||||
:param callback: callback previously bound
|
||||
|
||||
:rtype: boolean
|
||||
:return: whether anything was removed
|
||||
"""
|
||||
removed_any = False
|
||||
if not callback:
|
||||
removed_any = self._events_listeners.pop(event_name, removed_any)
|
||||
else:
|
||||
event_listeners = self._events_listeners.get(event_name, [])
|
||||
for i, (cb, _event_data) in enumerate(event_listeners):
|
||||
if reflection.is_same_callback(cb, callback):
|
||||
# NOTE(harlowja): its safe to do this as long as we stop
|
||||
# iterating after we do the removal, otherwise its not
|
||||
# safe (since this could have resized the list).
|
||||
event_listeners.pop(i)
|
||||
removed_any = True
|
||||
break
|
||||
return bool(removed_any)
|
||||
|
||||
def listeners_iter(self):
|
||||
"""Return an iterator over the mapping of event => callbacks bound."""
|
||||
for event_name in list(six.iterkeys(self._events_listeners)):
|
||||
# Use get() just incase it was removed while iterating...
|
||||
event_listeners = self._events_listeners.get(event_name, [])
|
||||
if event_listeners:
|
||||
yield (event_name, event_listeners[:])
|
||||
self._notifier.notify(EVENT_UPDATE_PROGRESS,
|
||||
{'progress': cleaned_progress})
|
||||
|
||||
|
||||
class Task(BaseTask):
|
||||
|
@ -84,6 +84,31 @@ class NotifierTest(test.TestCase):
|
||||
self.assertRaises(ValueError, notifier.register,
|
||||
nt.Notifier.ANY, 2)
|
||||
|
||||
def test_restricted_notifier(self):
|
||||
notifier = nt.RestrictedNotifier(['a', 'b'])
|
||||
self.assertRaises(ValueError, notifier.register,
|
||||
'c', lambda *args, **kargs: None)
|
||||
notifier.register('b', lambda *args, **kargs: None)
|
||||
self.assertEqual(1, len(notifier))
|
||||
|
||||
def test_restricted_notifier_any(self):
|
||||
notifier = nt.RestrictedNotifier(['a', 'b'])
|
||||
self.assertRaises(ValueError, notifier.register,
|
||||
'c', lambda *args, **kargs: None)
|
||||
notifier.register('b', lambda *args, **kargs: None)
|
||||
self.assertEqual(1, len(notifier))
|
||||
notifier.register(nt.RestrictedNotifier.ANY,
|
||||
lambda *args, **kargs: None)
|
||||
self.assertEqual(2, len(notifier))
|
||||
|
||||
def test_restricted_notifier_no_any(self):
|
||||
notifier = nt.RestrictedNotifier(['a', 'b'], allow_any=False)
|
||||
self.assertRaises(ValueError, notifier.register,
|
||||
nt.RestrictedNotifier.ANY,
|
||||
lambda *args, **kargs: None)
|
||||
notifier.register('b', lambda *args, **kargs: None)
|
||||
self.assertEqual(1, len(notifier))
|
||||
|
||||
def test_selective_notify(self):
|
||||
call_counts = collections.defaultdict(list)
|
||||
|
||||
|
@ -39,7 +39,12 @@ class ProgressTask(task.Task):
|
||||
|
||||
class ProgressTaskWithDetails(task.Task):
|
||||
def execute(self):
|
||||
self.update_progress(0.5, test='test data', foo='bar')
|
||||
details = {
|
||||
'progress': 0.5,
|
||||
'test': 'test data',
|
||||
'foo': 'bar',
|
||||
}
|
||||
self.notifier.notify(task.EVENT_UPDATE_PROGRESS, details)
|
||||
|
||||
|
||||
class TestProgress(test.TestCase):
|
||||
@ -60,12 +65,12 @@ class TestProgress(test.TestCase):
|
||||
def test_sanity_progress(self):
|
||||
fired_events = []
|
||||
|
||||
def notify_me(task, event_data, progress):
|
||||
fired_events.append(progress)
|
||||
def notify_me(event_type, details):
|
||||
fired_events.append(details.pop('progress'))
|
||||
|
||||
ev_count = 5
|
||||
t = ProgressTask("test", ev_count)
|
||||
t.bind('update_progress', notify_me)
|
||||
t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me)
|
||||
flo = lf.Flow("test")
|
||||
flo.add(t)
|
||||
e = self._make_engine(flo)
|
||||
@ -77,11 +82,11 @@ class TestProgress(test.TestCase):
|
||||
def test_no_segments_progress(self):
|
||||
fired_events = []
|
||||
|
||||
def notify_me(task, event_data, progress):
|
||||
fired_events.append(progress)
|
||||
def notify_me(event_type, details):
|
||||
fired_events.append(details.pop('progress'))
|
||||
|
||||
t = ProgressTask("test", 0)
|
||||
t.bind('update_progress', notify_me)
|
||||
t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me)
|
||||
flo = lf.Flow("test")
|
||||
flo.add(t)
|
||||
e = self._make_engine(flo)
|
||||
@ -121,12 +126,12 @@ class TestProgress(test.TestCase):
|
||||
def test_dual_storage_progress(self):
|
||||
fired_events = []
|
||||
|
||||
def notify_me(task, event_data, progress):
|
||||
fired_events.append(progress)
|
||||
def notify_me(event_type, details):
|
||||
fired_events.append(details.pop('progress'))
|
||||
|
||||
with contextlib.closing(impl_memory.MemoryBackend({})) as be:
|
||||
t = ProgressTask("test", 5)
|
||||
t.bind('update_progress', notify_me)
|
||||
t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me)
|
||||
flo = lf.Flow("test")
|
||||
flo.add(t)
|
||||
b, fd = p_utils.temporary_flow_detail(be)
|
||||
|
@ -17,7 +17,7 @@
|
||||
from taskflow import task
|
||||
from taskflow import test
|
||||
from taskflow.test import mock
|
||||
from taskflow.utils import reflection
|
||||
from taskflow.types import notifier
|
||||
|
||||
|
||||
class MyTask(task.Task):
|
||||
@ -198,24 +198,24 @@ class TaskTest(test.TestCase):
|
||||
values = [0.0, 0.5, 1.0]
|
||||
result = []
|
||||
|
||||
def progress_callback(task, event_data, progress):
|
||||
result.append(progress)
|
||||
def progress_callback(event_type, details):
|
||||
result.append(details.pop('progress'))
|
||||
|
||||
a_task = ProgressTask()
|
||||
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback):
|
||||
a_task.execute(values)
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
|
||||
a_task.execute(values)
|
||||
self.assertEqual(result, values)
|
||||
|
||||
@mock.patch.object(task.LOG, 'warn')
|
||||
def test_update_progress_lower_bound(self, mocked_warn):
|
||||
result = []
|
||||
|
||||
def progress_callback(task, event_data, progress):
|
||||
result.append(progress)
|
||||
def progress_callback(event_type, details):
|
||||
result.append(details.pop('progress'))
|
||||
|
||||
a_task = ProgressTask()
|
||||
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback):
|
||||
a_task.execute([-1.0, -0.5, 0.0])
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
|
||||
a_task.execute([-1.0, -0.5, 0.0])
|
||||
self.assertEqual(result, [0.0, 0.0, 0.0])
|
||||
self.assertEqual(mocked_warn.call_count, 2)
|
||||
|
||||
@ -223,81 +223,87 @@ class TaskTest(test.TestCase):
|
||||
def test_update_progress_upper_bound(self, mocked_warn):
|
||||
result = []
|
||||
|
||||
def progress_callback(task, event_data, progress):
|
||||
result.append(progress)
|
||||
def progress_callback(event_type, details):
|
||||
result.append(details.pop('progress'))
|
||||
|
||||
a_task = ProgressTask()
|
||||
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback):
|
||||
a_task.execute([1.0, 1.5, 2.0])
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
|
||||
a_task.execute([1.0, 1.5, 2.0])
|
||||
self.assertEqual(result, [1.0, 1.0, 1.0])
|
||||
self.assertEqual(mocked_warn.call_count, 2)
|
||||
|
||||
@mock.patch.object(task.LOG, 'warn')
|
||||
@mock.patch.object(notifier.LOG, 'warn')
|
||||
def test_update_progress_handler_failure(self, mocked_warn):
|
||||
|
||||
def progress_callback(*args, **kwargs):
|
||||
raise Exception('Woot!')
|
||||
|
||||
a_task = ProgressTask()
|
||||
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback):
|
||||
a_task.execute([0.5])
|
||||
mocked_warn.assert_called_once_with(
|
||||
mock.ANY, reflection.get_callable_name(progress_callback),
|
||||
task.EVENT_UPDATE_PROGRESS, exc_info=mock.ANY)
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback)
|
||||
a_task.execute([0.5])
|
||||
mocked_warn.assert_called_once()
|
||||
|
||||
def test_autobind_handler_is_none(self):
|
||||
def test_register_handler_is_none(self):
|
||||
a_task = MyTask()
|
||||
with a_task.autobind(task.EVENT_UPDATE_PROGRESS, None):
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 0)
|
||||
self.assertRaises(ValueError, a_task.notifier.register,
|
||||
task.EVENT_UPDATE_PROGRESS, None)
|
||||
self.assertEqual(len(a_task.notifier), 0)
|
||||
|
||||
def test_unbind_any_handler(self):
|
||||
def test_deregister_any_handler(self):
|
||||
a_task = MyTask()
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 0)
|
||||
a_task.bind(task.EVENT_UPDATE_PROGRESS, lambda: None)
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 1)
|
||||
self.assertTrue(a_task.unbind(task.EVENT_UPDATE_PROGRESS))
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 0)
|
||||
self.assertEqual(len(a_task.notifier), 0)
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS,
|
||||
lambda event_type, details: None)
|
||||
self.assertEqual(len(a_task.notifier), 1)
|
||||
a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS)
|
||||
self.assertEqual(len(a_task.notifier), 0)
|
||||
|
||||
def test_unbind_any_handler_empty_listeners(self):
|
||||
def test_deregister_any_handler_empty_listeners(self):
|
||||
a_task = MyTask()
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 0)
|
||||
self.assertFalse(a_task.unbind(task.EVENT_UPDATE_PROGRESS))
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 0)
|
||||
self.assertEqual(len(a_task.notifier), 0)
|
||||
self.assertFalse(a_task.notifier.deregister_event(
|
||||
task.EVENT_UPDATE_PROGRESS))
|
||||
self.assertEqual(len(a_task.notifier), 0)
|
||||
|
||||
def test_unbind_non_existent_listener(self):
|
||||
handler1 = lambda: None
|
||||
handler2 = lambda: None
|
||||
def test_deregister_non_existent_listener(self):
|
||||
handler1 = lambda event_type, details: None
|
||||
handler2 = lambda event_type, details: None
|
||||
a_task = MyTask()
|
||||
a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1)
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 1)
|
||||
self.assertFalse(a_task.unbind(task.EVENT_UPDATE_PROGRESS, handler2))
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 1)
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
|
||||
self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1)
|
||||
a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler2)
|
||||
self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1)
|
||||
a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler1)
|
||||
self.assertEqual(len(list(a_task.notifier.listeners_iter())), 0)
|
||||
|
||||
def test_bind_not_callable(self):
|
||||
task = MyTask()
|
||||
self.assertRaises(ValueError, task.bind, 'update_progress', 2)
|
||||
a_task = MyTask()
|
||||
self.assertRaises(ValueError, a_task.notifier.register,
|
||||
task.EVENT_UPDATE_PROGRESS, 2)
|
||||
|
||||
def test_copy_no_listeners(self):
|
||||
handler1 = lambda: None
|
||||
handler1 = lambda event_type, details: None
|
||||
a_task = MyTask()
|
||||
a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1)
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
|
||||
b_task = a_task.copy(retain_listeners=False)
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 1)
|
||||
self.assertEqual(len(list(b_task.listeners_iter())), 0)
|
||||
self.assertEqual(len(a_task.notifier), 1)
|
||||
self.assertEqual(len(b_task.notifier), 0)
|
||||
|
||||
def test_copy_listeners(self):
|
||||
handler1 = lambda: None
|
||||
handler2 = lambda: None
|
||||
handler1 = lambda event_type, details: None
|
||||
handler2 = lambda event_type, details: None
|
||||
a_task = MyTask()
|
||||
a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1)
|
||||
a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1)
|
||||
b_task = a_task.copy()
|
||||
self.assertEqual(len(list(b_task.listeners_iter())), 1)
|
||||
self.assertTrue(a_task.unbind(task.EVENT_UPDATE_PROGRESS))
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 0)
|
||||
self.assertEqual(len(list(b_task.listeners_iter())), 1)
|
||||
b_task.bind(task.EVENT_UPDATE_PROGRESS, handler2)
|
||||
listeners = dict(list(b_task.listeners_iter()))
|
||||
self.assertEqual(len(b_task.notifier), 1)
|
||||
self.assertTrue(a_task.notifier.deregister_event(
|
||||
task.EVENT_UPDATE_PROGRESS))
|
||||
self.assertEqual(len(a_task.notifier), 0)
|
||||
self.assertEqual(len(b_task.notifier), 1)
|
||||
b_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler2)
|
||||
listeners = dict(list(b_task.notifier.listeners_iter()))
|
||||
self.assertEqual(len(listeners[task.EVENT_UPDATE_PROGRESS]), 2)
|
||||
self.assertEqual(len(list(a_task.listeners_iter())), 0)
|
||||
self.assertEqual(len(a_task.notifier), 0)
|
||||
|
||||
|
||||
class FunctorTaskTest(test.TestCase):
|
||||
|
@ -42,14 +42,14 @@ class TestEndpoint(test.TestCase):
|
||||
self.task_result = 1
|
||||
|
||||
def test_creation(self):
|
||||
task = self.task_ep._get_task()
|
||||
task = self.task_ep.generate()
|
||||
self.assertEqual(self.task_ep.name, self.task_cls_name)
|
||||
self.assertIsInstance(task, self.task_cls)
|
||||
self.assertEqual(task.name, self.task_cls_name)
|
||||
|
||||
def test_creation_with_task_name(self):
|
||||
task_name = 'test'
|
||||
task = self.task_ep._get_task(name=task_name)
|
||||
task = self.task_ep.generate(name=task_name)
|
||||
self.assertEqual(self.task_ep.name, self.task_cls_name)
|
||||
self.assertIsInstance(task, self.task_cls)
|
||||
self.assertEqual(task.name, task_name)
|
||||
@ -58,20 +58,22 @@ class TestEndpoint(test.TestCase):
|
||||
# NOTE(skudriashev): Exception is expected here since task
|
||||
# is created without any arguments passing to its constructor.
|
||||
endpoint = ep.Endpoint(Task)
|
||||
self.assertRaises(TypeError, endpoint._get_task)
|
||||
self.assertRaises(TypeError, endpoint.generate)
|
||||
|
||||
def test_to_str(self):
|
||||
self.assertEqual(str(self.task_ep), self.task_cls_name)
|
||||
|
||||
def test_execute(self):
|
||||
result = self.task_ep.execute(task_name=self.task_cls_name,
|
||||
task = self.task_ep.generate(self.task_cls_name)
|
||||
result = self.task_ep.execute(task,
|
||||
task_uuid=self.task_uuid,
|
||||
arguments=self.task_args,
|
||||
progress_callback=None)
|
||||
self.assertEqual(result, self.task_result)
|
||||
|
||||
def test_revert(self):
|
||||
result = self.task_ep.revert(task_name=self.task_cls_name,
|
||||
task = self.task_ep.generate(self.task_cls_name)
|
||||
result = self.task_ep.revert(task,
|
||||
task_uuid=self.task_uuid,
|
||||
arguments=self.task_args,
|
||||
progress_callback=None,
|
||||
|
@ -21,6 +21,7 @@ from oslo.utils import timeutils
|
||||
|
||||
from taskflow.engines.worker_based import executor
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow import task as task_atom
|
||||
from taskflow import test
|
||||
from taskflow.test import mock
|
||||
from taskflow.tests import utils as test_utils
|
||||
@ -102,13 +103,18 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_state_progress(self):
|
||||
response = pr.Response(pr.PROGRESS, progress=1.0)
|
||||
response = pr.Response(pr.EVENT,
|
||||
event_type=task_atom.EVENT_UPDATE_PROGRESS,
|
||||
details={'progress': 1.0})
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.on_progress(progress=1.0)])
|
||||
expected_calls = [
|
||||
mock.call.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS,
|
||||
{'progress': 1.0}),
|
||||
]
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_state_failure(self):
|
||||
a_failure = failure.Failure.from_exception(Exception('test'))
|
||||
@ -211,7 +217,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
self.task_args, self.timeout),
|
||||
mock.call.request.transition_and_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(self.request_inst_mock,
|
||||
@ -231,7 +237,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'revert',
|
||||
self.task_args, None, self.timeout,
|
||||
self.task_args, self.timeout,
|
||||
failures=self.task_failures,
|
||||
result=self.task_result),
|
||||
mock.call.request.transition_and_log_error(pr.PENDING,
|
||||
@ -250,7 +256,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout)
|
||||
self.task_args, self.timeout),
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
|
||||
@ -264,7 +270,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
self.task_args, self.timeout),
|
||||
mock.call.request.transition_and_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(self.request_inst_mock,
|
||||
|
@ -118,7 +118,7 @@ class TestMessagePump(test.TestCase):
|
||||
else:
|
||||
p.publish(pr.Request(test_utils.DummyTask("dummy_%s" % i),
|
||||
uuidutils.generate_uuid(),
|
||||
pr.EXECUTE, [], None, None), TEST_TOPIC)
|
||||
pr.EXECUTE, [], None), TEST_TOPIC)
|
||||
|
||||
self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT))
|
||||
self.assertEqual(0, barrier.needed)
|
||||
|
@ -22,7 +22,6 @@ from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow import exceptions as excp
|
||||
from taskflow.openstack.common import uuidutils
|
||||
from taskflow import test
|
||||
from taskflow.test import mock
|
||||
from taskflow.tests import utils
|
||||
from taskflow.types import failure
|
||||
|
||||
@ -53,7 +52,7 @@ class TestProtocolValidation(test.TestCase):
|
||||
|
||||
def test_request(self):
|
||||
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
|
||||
pr.EXECUTE, {}, None, 1.0)
|
||||
pr.EXECUTE, {}, 1.0)
|
||||
pr.Request.validate(msg.to_dict())
|
||||
|
||||
def test_request_invalid(self):
|
||||
@ -66,13 +65,14 @@ class TestProtocolValidation(test.TestCase):
|
||||
|
||||
def test_request_invalid_action(self):
|
||||
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
|
||||
pr.EXECUTE, {}, None, 1.0)
|
||||
pr.EXECUTE, {}, 1.0)
|
||||
msg = msg.to_dict()
|
||||
msg['action'] = 'NOTHING'
|
||||
self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg)
|
||||
|
||||
def test_response_progress(self):
|
||||
msg = pr.Response(pr.PROGRESS, progress=0.5, event_data={})
|
||||
msg = pr.Response(pr.EVENT, details={'progress': 0.5},
|
||||
event_type='blah')
|
||||
pr.Response.validate(msg.to_dict())
|
||||
|
||||
def test_response_completion(self):
|
||||
@ -80,7 +80,9 @@ class TestProtocolValidation(test.TestCase):
|
||||
pr.Response.validate(msg.to_dict())
|
||||
|
||||
def test_response_mixed_invalid(self):
|
||||
msg = pr.Response(pr.PROGRESS, progress=0.5, event_data={}, result=1)
|
||||
msg = pr.Response(pr.EVENT,
|
||||
details={'progress': 0.5},
|
||||
event_type='blah', result=1)
|
||||
self.assertRaises(excp.InvalidFormat, pr.Response.validate, msg)
|
||||
|
||||
def test_response_bad_state(self):
|
||||
@ -184,16 +186,3 @@ class TestProtocol(test.TestCase):
|
||||
request.set_result(111)
|
||||
result = request.result.result()
|
||||
self.assertEqual(result, (executor.EXECUTED, 111))
|
||||
|
||||
def test_on_progress(self):
|
||||
progress_callback = mock.MagicMock(name='progress_callback')
|
||||
request = self.request(task=self.task,
|
||||
progress_callback=progress_callback)
|
||||
request.on_progress('event_data', 0.0)
|
||||
request.on_progress('event_data', 1.0)
|
||||
|
||||
expected_calls = [
|
||||
mock.call(self.task, 'event_data', 0.0),
|
||||
mock.call(self.task, 'event_data', 1.0)
|
||||
]
|
||||
self.assertEqual(progress_callback.mock_calls, expected_calls)
|
||||
|
@ -19,6 +19,7 @@ import six
|
||||
from taskflow.engines.worker_based import endpoint as ep
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import server
|
||||
from taskflow import task as task_atom
|
||||
from taskflow import test
|
||||
from taskflow.test import mock
|
||||
from taskflow.tests import utils
|
||||
@ -103,45 +104,41 @@ class TestServer(test.MockTestCase):
|
||||
|
||||
def test_parse_request(self):
|
||||
request = self.make_request()
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task_cls, action, task_args),
|
||||
(self.task.name, self.task_action,
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args)))
|
||||
bundle = server.Server._parse_request(**request)
|
||||
task_cls, task_name, action, task_args = bundle
|
||||
self.assertEqual((task_cls, task_name, action, task_args),
|
||||
(self.task.name, self.task.name, self.task_action,
|
||||
dict(arguments=self.task_args)))
|
||||
|
||||
def test_parse_request_with_success_result(self):
|
||||
request = self.make_request(action='revert', result=1)
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task_cls, action, task_args),
|
||||
(self.task.name, 'revert',
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args,
|
||||
bundle = server.Server._parse_request(**request)
|
||||
task_cls, task_name, action, task_args = bundle
|
||||
self.assertEqual((task_cls, task_name, action, task_args),
|
||||
(self.task.name, self.task.name, 'revert',
|
||||
dict(arguments=self.task_args,
|
||||
result=1)))
|
||||
|
||||
def test_parse_request_with_failure_result(self):
|
||||
a_failure = failure.Failure.from_exception(Exception('test'))
|
||||
request = self.make_request(action='revert', result=a_failure)
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task_cls, action, task_args),
|
||||
(self.task.name, 'revert',
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args,
|
||||
bundle = server.Server._parse_request(**request)
|
||||
task_cls, task_name, action, task_args = bundle
|
||||
self.assertEqual((task_cls, task_name, action, task_args),
|
||||
(self.task.name, self.task.name, 'revert',
|
||||
dict(arguments=self.task_args,
|
||||
result=utils.FailureMatcher(a_failure))))
|
||||
|
||||
def test_parse_request_with_failures(self):
|
||||
failures = {'0': failure.Failure.from_exception(Exception('test1')),
|
||||
'1': failure.Failure.from_exception(Exception('test2'))}
|
||||
request = self.make_request(action='revert', failures=failures)
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
bundle = server.Server._parse_request(**request)
|
||||
task_cls, task_name, action, task_args = bundle
|
||||
self.assertEqual(
|
||||
(task_cls, action, task_args),
|
||||
(self.task.name, 'revert',
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args,
|
||||
(task_cls, task_name, action, task_args),
|
||||
(self.task.name, self.task.name, 'revert',
|
||||
dict(arguments=self.task_args,
|
||||
failures=dict((i, utils.FailureMatcher(f))
|
||||
for i, f in six.iteritems(failures)))))
|
||||
|
||||
@ -182,17 +179,19 @@ class TestServer(test.MockTestCase):
|
||||
mock.call.Response(pr.RUNNING),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.Response(pr.PROGRESS, progress=0.0, event_data={}),
|
||||
mock.call.Response(pr.EVENT, details={'progress': 0.0},
|
||||
event_type=task_atom.EVENT_UPDATE_PROGRESS),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.Response(pr.PROGRESS, progress=1.0, event_data={}),
|
||||
mock.call.Response(pr.EVENT, details={'progress': 1.0},
|
||||
event_type=task_atom.EVENT_UPDATE_PROGRESS),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.Response(pr.SUCCESS, result=5),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
self.assertEqual(master_mock_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_process_request(self):
|
||||
# create server and process request
|
||||
|
@ -15,6 +15,7 @@
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import six
|
||||
@ -39,6 +40,14 @@ class _Listener(object):
|
||||
else:
|
||||
self._kwargs = kwargs.copy()
|
||||
|
||||
@property
|
||||
def kwargs(self):
|
||||
return self._kwargs
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self._args
|
||||
|
||||
def __call__(self, event_type, details):
|
||||
if self._details_filter is not None:
|
||||
if not self._details_filter(details):
|
||||
@ -117,17 +126,18 @@ class Notifier(object):
|
||||
event type will be called.
|
||||
|
||||
:param event_type: event type that occurred
|
||||
:param details: addition event details
|
||||
:param details: additional event details *dictionary* passed to
|
||||
callback keyword argument with the same name.
|
||||
"""
|
||||
listeners = list(self._listeners.get(self.ANY, []))
|
||||
for listener in self._listeners[event_type]:
|
||||
if listener not in listeners:
|
||||
listeners.append(listener)
|
||||
listeners.extend(self._listeners.get(event_type, []))
|
||||
if not listeners:
|
||||
return
|
||||
if not details:
|
||||
details = {}
|
||||
for listener in listeners:
|
||||
try:
|
||||
listener(event_type, details)
|
||||
listener(event_type, details.copy())
|
||||
except Exception:
|
||||
LOG.warn("Failure calling listener %s to notify about event"
|
||||
" %s, details: %s", listener, event_type,
|
||||
@ -150,6 +160,9 @@ class Notifier(object):
|
||||
if details_filter is not None:
|
||||
if not six.callable(details_filter):
|
||||
raise ValueError("Details filter must be callable")
|
||||
if not self.can_be_registered(event_type):
|
||||
raise ValueError("Disallowed event type '%s' can not have a"
|
||||
" callback registered" % event_type)
|
||||
if self.is_registered(event_type, callback,
|
||||
details_filter=details_filter):
|
||||
raise ValueError("Event callback already registered with"
|
||||
@ -165,10 +178,61 @@ class Notifier(object):
|
||||
details_filter=details_filter))
|
||||
|
||||
def deregister(self, event_type, callback, details_filter=None):
|
||||
"""Remove a single callback from listening to event ``event_type``."""
|
||||
"""Remove a single listener bound to event ``event_type``."""
|
||||
if event_type not in self._listeners:
|
||||
return
|
||||
for i, listener in enumerate(self._listeners[event_type]):
|
||||
return False
|
||||
for i, listener in enumerate(self._listeners.get(event_type, [])):
|
||||
if listener.is_equivalent(callback, details_filter=details_filter):
|
||||
self._listeners[event_type].pop(i)
|
||||
break
|
||||
return True
|
||||
return False
|
||||
|
||||
def deregister_event(self, event_type):
|
||||
"""Remove a group of listeners bound to event ``event_type``."""
|
||||
return len(self._listeners.pop(event_type, []))
|
||||
|
||||
def copy(self):
|
||||
c = copy.copy(self)
|
||||
c._listeners = collections.defaultdict(list)
|
||||
for event_type, listeners in six.iteritems(self._listeners):
|
||||
c._listeners[event_type] = listeners[:]
|
||||
return c
|
||||
|
||||
def listeners_iter(self):
|
||||
"""Return an iterator over the mapping of event => listeners bound."""
|
||||
for event_type, listeners in six.iteritems(self._listeners):
|
||||
if listeners:
|
||||
yield (event_type, listeners)
|
||||
|
||||
def can_be_registered(self, event_type):
|
||||
"""Checks if the event can be registered/subscribed to."""
|
||||
return True
|
||||
|
||||
|
||||
class RestrictedNotifier(Notifier):
|
||||
"""A notification class that restricts events registered/triggered.
|
||||
|
||||
NOTE(harlowja): This class unlike :class:`.Notifier` restricts and
|
||||
disallows registering callbacks for event types that are not declared
|
||||
when constructing the notifier.
|
||||
"""
|
||||
|
||||
def __init__(self, watchable_events, allow_any=True):
|
||||
super(RestrictedNotifier, self).__init__()
|
||||
self._watchable_events = frozenset(watchable_events)
|
||||
self._allow_any = allow_any
|
||||
|
||||
def events_iter(self):
|
||||
"""Returns iterator of events that can be registered/subscribed to.
|
||||
|
||||
NOTE(harlowja): does not include back the ``ANY`` event type as that
|
||||
meta-type is not a specific event but is a capture-all that does not
|
||||
imply the same meaning as specific event types.
|
||||
"""
|
||||
for event_type in self._watchable_events:
|
||||
yield event_type
|
||||
|
||||
def can_be_registered(self, event_type):
|
||||
"""Checks if the event can be registered/subscribed to."""
|
||||
return (event_type in self._watchable_events or
|
||||
(event_type == self.ANY and self._allow_any))
|
||||
|
Loading…
Reference in New Issue
Block a user