Use the notifier type in the task class/module directly
Instead of having code that is some what like the notifier
code we already have, but is duplicated and is slightly different
in the task class just move the code that was in the task class (and
doing similar actions) to instead now use a notifier that is directly
contained in the task base class for internal task triggering of
internal task events.
Breaking change: alters the capabilities of the task to process
notifications itself, most actions now must go through the task
notifier property and instead use that (update_progress still exists
as a common utility method, since it's likely the most common type
of notification that will be used).
Removes the following methods from task base class (as they are
no longer needed with a notifier attribute):
- trigger (replaced with notifier.notify)
- autobind (removed, not replaced, can be created by the user
            of taskflow in a simple manner, without requiring
            functionality in taskflow)
- bind (replaced with notifier.register)
- unbind (replaced with notifier.unregister)
- listeners_iter (replaced with notifier.listeners_iter)
Due to this change we can now also correctly proxy back events from
remote tasks to the engine for correct proxying back to the original
task.
Fixes bug 1370766
Change-Id: Ic9dfef516d72e6e32e71dda30a1cb3522c9e0be6
			
			
This commit is contained in:
		 Joshua Harlow
					Joshua Harlow
				
			
				
					committed by
					
						 Joshua Harlow
						Joshua Harlow
					
				
			
			
				
	
			
			
			 Joshua Harlow
						Joshua Harlow
					
				
			
						parent
						
							cdfd8ece61
						
					
				
				
					commit
					1f4dd72e6e
				
			| @@ -158,10 +158,11 @@ engine executor in the following manner: | |||||||
|     from dicts after receiving on both executor & worker sides (this |     from dicts after receiving on both executor & worker sides (this | ||||||
|     translation is lossy since the traceback won't be fully retained). |     translation is lossy since the traceback won't be fully retained). | ||||||
|  |  | ||||||
| Executor request format | Executor execute format | ||||||
| ~~~~~~~~~~~~~~~~~~~~~~~ | ~~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
| * **task** - full task name to be performed | * **task_name** - full task name to be performed | ||||||
|  | * **task_cls** - full task class name to be performed | ||||||
| * **action** - task action to be performed (e.g. execute, revert) | * **action** - task action to be performed (e.g. execute, revert) | ||||||
| * **arguments** - arguments the task action to be called with | * **arguments** - arguments the task action to be called with | ||||||
| * **result** - task execution result (result or | * **result** - task execution result (result or | ||||||
| @@ -180,9 +181,14 @@ Additionally, the following parameters are added to the request message: | |||||||
|     { |     { | ||||||
|         "action": "execute", |         "action": "execute", | ||||||
|         "arguments": { |         "arguments": { | ||||||
|             "joe_number": 444 |             "x": 111 | ||||||
|         }, |         }, | ||||||
|         "task": "tasks.CallJoe" |         "task_cls": "taskflow.tests.utils.TaskOneArgOneReturn", | ||||||
|  |         "task_name": "taskflow.tests.utils.TaskOneArgOneReturn", | ||||||
|  |         "task_version": [ | ||||||
|  |             1, | ||||||
|  |             0 | ||||||
|  |         ] | ||||||
|     } |     } | ||||||
|  |  | ||||||
| Worker response format | Worker response format | ||||||
| @@ -193,7 +199,8 @@ When **running:** | |||||||
| .. code:: json | .. code:: json | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         "status": "RUNNING" |         "data": {}, | ||||||
|  |         "state": "RUNNING" | ||||||
|     } |     } | ||||||
|  |  | ||||||
| When **progressing:** | When **progressing:** | ||||||
| @@ -201,9 +208,11 @@ When **progressing:** | |||||||
| .. code:: json | .. code:: json | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         "event_data": <event_data>, |         "details": { | ||||||
|         "progress": <progress>, |             "progress": 0.5 | ||||||
|         "state": "PROGRESS" |         }, | ||||||
|  |         "event_type": "update_progress", | ||||||
|  |         "state": "EVENT" | ||||||
|     } |     } | ||||||
|  |  | ||||||
| When **succeeded:** | When **succeeded:** | ||||||
| @@ -211,8 +220,9 @@ When **succeeded:** | |||||||
| .. code:: json | .. code:: json | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         "event": <event>, |         "data": { | ||||||
|         "result": <result>, |             "result": 666 | ||||||
|  |         }, | ||||||
|         "state": "SUCCESS" |         "state": "SUCCESS" | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -221,11 +231,64 @@ When **failed:** | |||||||
| .. code:: json | .. code:: json | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         "event": <event>, |         "data": { | ||||||
|         "result": <types.failure.Failure>, |             "result": { | ||||||
|  |                 "exc_type_names": [ | ||||||
|  |                     "RuntimeError", | ||||||
|  |                     "StandardError", | ||||||
|  |                     "Exception" | ||||||
|  |                 ], | ||||||
|  |                 "exception_str": "Woot!", | ||||||
|  |                 "traceback_str": "  File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n    result = task.execute(**arguments)\n  File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n    raise RuntimeError('Woot!')\n", | ||||||
|  |                 "version": 1 | ||||||
|  |             } | ||||||
|  |         }, | ||||||
|         "state": "FAILURE" |         "state": "FAILURE" | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | Executor revert format | ||||||
|  | ~~~~~~~~~~~~~~~~~~~~~~ | ||||||
|  |  | ||||||
|  | When **reverting:** | ||||||
|  |  | ||||||
|  | .. code:: json | ||||||
|  |  | ||||||
|  |     { | ||||||
|  |         "action": "revert", | ||||||
|  |         "arguments": {}, | ||||||
|  |         "failures": { | ||||||
|  |             "taskflow.tests.utils.TaskWithFailure": { | ||||||
|  |                 "exc_type_names": [ | ||||||
|  |                     "RuntimeError", | ||||||
|  |                     "StandardError", | ||||||
|  |                     "Exception" | ||||||
|  |                 ], | ||||||
|  |                 "exception_str": "Woot!", | ||||||
|  |                 "traceback_str": "  File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n    result = task.execute(**arguments)\n  File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n    raise RuntimeError('Woot!')\n", | ||||||
|  |                 "version": 1 | ||||||
|  |             } | ||||||
|  |         }, | ||||||
|  |         "result": [ | ||||||
|  |             "failure", | ||||||
|  |             { | ||||||
|  |                 "exc_type_names": [ | ||||||
|  |                     "RuntimeError", | ||||||
|  |                     "StandardError", | ||||||
|  |                     "Exception" | ||||||
|  |                 ], | ||||||
|  |                 "exception_str": "Woot!", | ||||||
|  |                 "traceback_str": "  File \"/homes/harlowja/dev/os/taskflow/taskflow/engines/action_engine/executor.py\", line 56, in _execute_task\n    result = task.execute(**arguments)\n  File \"/homes/harlowja/dev/os/taskflow/taskflow/tests/utils.py\", line 165, in execute\n    raise RuntimeError('Woot!')\n", | ||||||
|  |                 "version": 1 | ||||||
|  |             } | ||||||
|  |         ], | ||||||
|  |         "task_cls": "taskflow.tests.utils.TaskWithFailure", | ||||||
|  |         "task_name": "taskflow.tests.utils.TaskWithFailure", | ||||||
|  |         "task_version": [ | ||||||
|  |             1, | ||||||
|  |             0 | ||||||
|  |         ] | ||||||
|  |     } | ||||||
|  |  | ||||||
| Usage | Usage | ||||||
| ===== | ===== | ||||||
|  |  | ||||||
|   | |||||||
| @@ -14,6 +14,8 @@ | |||||||
| #    License for the specific language governing permissions and limitations | #    License for the specific language governing permissions and limitations | ||||||
| #    under the License. | #    under the License. | ||||||
|  |  | ||||||
|  | import functools | ||||||
|  |  | ||||||
| from taskflow import logging | from taskflow import logging | ||||||
| from taskflow import states | from taskflow import states | ||||||
| from taskflow import task as task_atom | from taskflow import task as task_atom | ||||||
| @@ -75,25 +77,37 @@ class TaskAction(object): | |||||||
|         if progress is not None: |         if progress is not None: | ||||||
|             task.update_progress(progress) |             task.update_progress(progress) | ||||||
|  |  | ||||||
|     def _on_update_progress(self, task, event_data, progress, **kwargs): |     def _on_update_progress(self, task, event_type, details): | ||||||
|         """Should be called when task updates its progress.""" |         """Should be called when task updates its progress.""" | ||||||
|         try: |         try: | ||||||
|             self._storage.set_task_progress(task.name, progress, kwargs) |             progress = details.pop('progress') | ||||||
|  |         except KeyError: | ||||||
|  |             pass | ||||||
|  |         else: | ||||||
|  |             try: | ||||||
|  |                 self._storage.set_task_progress(task.name, progress, | ||||||
|  |                                                 details=details) | ||||||
|             except Exception: |             except Exception: | ||||||
|             # Update progress callbacks should never fail, so capture and log |                 # Update progress callbacks should never fail, so capture and | ||||||
|             # the emitted exception instead of raising it. |                 # log the emitted exception instead of raising it. | ||||||
|                 LOG.exception("Failed setting task progress for %s to %0.3f", |                 LOG.exception("Failed setting task progress for %s to %0.3f", | ||||||
|                               task, progress) |                               task, progress) | ||||||
|  |  | ||||||
|     def schedule_execution(self, task): |     def schedule_execution(self, task): | ||||||
|         self.change_state(task, states.RUNNING, progress=0.0) |         self.change_state(task, states.RUNNING, progress=0.0) | ||||||
|         scope_walker = self._walker_factory(task) |         scope_walker = self._walker_factory(task) | ||||||
|         kwargs = self._storage.fetch_mapped_args(task.rebind, |         arguments = self._storage.fetch_mapped_args(task.rebind, | ||||||
|                                                     atom_name=task.name, |                                                     atom_name=task.name, | ||||||
|                                                     scope_walker=scope_walker) |                                                     scope_walker=scope_walker) | ||||||
|  |         if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): | ||||||
|  |             progress_callback = functools.partial(self._on_update_progress, | ||||||
|  |                                                   task) | ||||||
|  |         else: | ||||||
|  |             progress_callback = None | ||||||
|         task_uuid = self._storage.get_atom_uuid(task.name) |         task_uuid = self._storage.get_atom_uuid(task.name) | ||||||
|         return self._task_executor.execute_task(task, task_uuid, kwargs, |         return self._task_executor.execute_task( | ||||||
|                                                 self._on_update_progress) |             task, task_uuid, arguments, | ||||||
|  |             progress_callback=progress_callback) | ||||||
|  |  | ||||||
|     def complete_execution(self, task, result): |     def complete_execution(self, task, result): | ||||||
|         if isinstance(result, failure.Failure): |         if isinstance(result, failure.Failure): | ||||||
| @@ -105,15 +119,20 @@ class TaskAction(object): | |||||||
|     def schedule_reversion(self, task): |     def schedule_reversion(self, task): | ||||||
|         self.change_state(task, states.REVERTING, progress=0.0) |         self.change_state(task, states.REVERTING, progress=0.0) | ||||||
|         scope_walker = self._walker_factory(task) |         scope_walker = self._walker_factory(task) | ||||||
|         kwargs = self._storage.fetch_mapped_args(task.rebind, |         arguments = self._storage.fetch_mapped_args(task.rebind, | ||||||
|                                                     atom_name=task.name, |                                                     atom_name=task.name, | ||||||
|                                                     scope_walker=scope_walker) |                                                     scope_walker=scope_walker) | ||||||
|         task_uuid = self._storage.get_atom_uuid(task.name) |         task_uuid = self._storage.get_atom_uuid(task.name) | ||||||
|         task_result = self._storage.get(task.name) |         task_result = self._storage.get(task.name) | ||||||
|         failures = self._storage.get_failures() |         failures = self._storage.get_failures() | ||||||
|         future = self._task_executor.revert_task(task, task_uuid, kwargs, |         if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): | ||||||
|                                                  task_result, failures, |             progress_callback = functools.partial(self._on_update_progress, | ||||||
|                                                  self._on_update_progress) |                                                   task) | ||||||
|  |         else: | ||||||
|  |             progress_callback = None | ||||||
|  |         future = self._task_executor.revert_task( | ||||||
|  |             task, task_uuid, arguments, task_result, failures, | ||||||
|  |             progress_callback=progress_callback) | ||||||
|         return future |         return future | ||||||
|  |  | ||||||
|     def complete_reversion(self, task, rev_result): |     def complete_reversion(self, task, rev_result): | ||||||
|   | |||||||
| @@ -15,10 +15,11 @@ | |||||||
| #    under the License. | #    under the License. | ||||||
|  |  | ||||||
| import abc | import abc | ||||||
|  | import contextlib | ||||||
|  |  | ||||||
| import six | import six | ||||||
|  |  | ||||||
| from taskflow import task as _task | from taskflow import task as task_atom | ||||||
| from taskflow.types import failure | from taskflow.types import failure | ||||||
| from taskflow.types import futures | from taskflow.types import futures | ||||||
| from taskflow.utils import async_utils | from taskflow.utils import async_utils | ||||||
| @@ -29,8 +30,23 @@ EXECUTED = 'executed' | |||||||
| REVERTED = 'reverted' | REVERTED = 'reverted' | ||||||
|  |  | ||||||
|  |  | ||||||
| def _execute_task(task, arguments, progress_callback): | @contextlib.contextmanager | ||||||
|     with task.autobind(_task.EVENT_UPDATE_PROGRESS, progress_callback): | def _autobind(task, progress_callback=None): | ||||||
|  |     bound = False | ||||||
|  |     if progress_callback is not None: | ||||||
|  |         task.notifier.register(task_atom.EVENT_UPDATE_PROGRESS, | ||||||
|  |                                progress_callback) | ||||||
|  |         bound = True | ||||||
|  |     try: | ||||||
|  |         yield | ||||||
|  |     finally: | ||||||
|  |         if bound: | ||||||
|  |             task.notifier.deregister(task_atom.EVENT_UPDATE_PROGRESS, | ||||||
|  |                                      progress_callback) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _execute_task(task, arguments, progress_callback=None): | ||||||
|  |     with _autobind(task, progress_callback=progress_callback): | ||||||
|         try: |         try: | ||||||
|             task.pre_execute() |             task.pre_execute() | ||||||
|             result = task.execute(**arguments) |             result = task.execute(**arguments) | ||||||
| @@ -43,14 +59,14 @@ def _execute_task(task, arguments, progress_callback): | |||||||
|     return (EXECUTED, result) |     return (EXECUTED, result) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _revert_task(task, arguments, result, failures, progress_callback): | def _revert_task(task, arguments, result, failures, progress_callback=None): | ||||||
|     kwargs = arguments.copy() |     arguments = arguments.copy() | ||||||
|     kwargs[_task.REVERT_RESULT] = result |     arguments[task_atom.REVERT_RESULT] = result | ||||||
|     kwargs[_task.REVERT_FLOW_FAILURES] = failures |     arguments[task_atom.REVERT_FLOW_FAILURES] = failures | ||||||
|     with task.autobind(_task.EVENT_UPDATE_PROGRESS, progress_callback): |     with _autobind(task, progress_callback=progress_callback): | ||||||
|         try: |         try: | ||||||
|             task.pre_revert() |             task.pre_revert() | ||||||
|             result = task.revert(**kwargs) |             result = task.revert(**arguments) | ||||||
|         except Exception: |         except Exception: | ||||||
|             # NOTE(imelnikov): wrap current exception with Failure |             # NOTE(imelnikov): wrap current exception with Failure | ||||||
|             # object and return it. |             # object and return it. | ||||||
| @@ -98,15 +114,17 @@ class SerialTaskExecutor(TaskExecutor): | |||||||
|         self._executor = futures.SynchronousExecutor() |         self._executor = futures.SynchronousExecutor() | ||||||
|  |  | ||||||
|     def execute_task(self, task, task_uuid, arguments, progress_callback=None): |     def execute_task(self, task, task_uuid, arguments, progress_callback=None): | ||||||
|         fut = self._executor.submit(_execute_task, task, arguments, |         fut = self._executor.submit(_execute_task, | ||||||
|                                     progress_callback) |                                     task, arguments, | ||||||
|  |                                     progress_callback=progress_callback) | ||||||
|         fut.atom = task |         fut.atom = task | ||||||
|         return fut |         return fut | ||||||
|  |  | ||||||
|     def revert_task(self, task, task_uuid, arguments, result, failures, |     def revert_task(self, task, task_uuid, arguments, result, failures, | ||||||
|                     progress_callback=None): |                     progress_callback=None): | ||||||
|         fut = self._executor.submit(_revert_task, task, arguments, result, |         fut = self._executor.submit(_revert_task, | ||||||
|                                     failures, progress_callback) |                                     task, arguments, result, failures, | ||||||
|  |                                     progress_callback=progress_callback) | ||||||
|         fut.atom = task |         fut.atom = task | ||||||
|         return fut |         return fut | ||||||
|  |  | ||||||
| @@ -127,15 +145,17 @@ class ParallelTaskExecutor(TaskExecutor): | |||||||
|         self._create_executor = executor is None |         self._create_executor = executor is None | ||||||
|  |  | ||||||
|     def execute_task(self, task, task_uuid, arguments, progress_callback=None): |     def execute_task(self, task, task_uuid, arguments, progress_callback=None): | ||||||
|         fut = self._executor.submit(_execute_task, task, |         fut = self._executor.submit(_execute_task, | ||||||
|                                     arguments, progress_callback) |                                     task, arguments, | ||||||
|  |                                     progress_callback=progress_callback) | ||||||
|         fut.atom = task |         fut.atom = task | ||||||
|         return fut |         return fut | ||||||
|  |  | ||||||
|     def revert_task(self, task, task_uuid, arguments, result, failures, |     def revert_task(self, task, task_uuid, arguments, result, failures, | ||||||
|                     progress_callback=None): |                     progress_callback=None): | ||||||
|         fut = self._executor.submit(_revert_task, task, arguments, |         fut = self._executor.submit(_revert_task, | ||||||
|                                     result, failures, progress_callback) |                                     task, arguments, result, failures, | ||||||
|  |                                     progress_callback=progress_callback) | ||||||
|         fut.atom = task |         fut.atom = task | ||||||
|         return fut |         return fut | ||||||
|  |  | ||||||
|   | |||||||
| @@ -33,18 +33,16 @@ class Endpoint(object): | |||||||
|     def name(self): |     def name(self): | ||||||
|         return self._task_cls_name |         return self._task_cls_name | ||||||
|  |  | ||||||
|     def _get_task(self, name=None): |     def generate(self, name=None): | ||||||
|         # NOTE(skudriashev): Note that task is created here with the `name` |         # NOTE(skudriashev): Note that task is created here with the `name` | ||||||
|         # argument passed to its constructor. This will be a problem when |         # argument passed to its constructor. This will be a problem when | ||||||
|         # task's constructor requires any other arguments. |         # task's constructor requires any other arguments. | ||||||
|         return self._task_cls(name=name) |         return self._task_cls(name=name) | ||||||
|  |  | ||||||
|     def execute(self, task_name, **kwargs): |     def execute(self, task, **kwargs): | ||||||
|         task = self._get_task(task_name) |  | ||||||
|         event, result = self._executor.execute_task(task, **kwargs).result() |         event, result = self._executor.execute_task(task, **kwargs).result() | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
|     def revert(self, task_name, **kwargs): |     def revert(self, task, **kwargs): | ||||||
|         task = self._get_task(task_name) |  | ||||||
|         event, result = self._executor.revert_task(task, **kwargs).result() |         event, result = self._executor.revert_task(task, **kwargs).result() | ||||||
|         return result |         return result | ||||||
|   | |||||||
| @@ -25,6 +25,7 @@ from taskflow.engines.worker_based import protocol as pr | |||||||
| from taskflow.engines.worker_based import proxy | from taskflow.engines.worker_based import proxy | ||||||
| from taskflow import exceptions as exc | from taskflow import exceptions as exc | ||||||
| from taskflow import logging | from taskflow import logging | ||||||
|  | from taskflow import task as task_atom | ||||||
| from taskflow.types import timing as tt | from taskflow.types import timing as tt | ||||||
| from taskflow.utils import async_utils | from taskflow.utils import async_utils | ||||||
| from taskflow.utils import misc | from taskflow.utils import misc | ||||||
| @@ -132,8 +133,11 @@ class WorkerTaskExecutor(executor.TaskExecutor): | |||||||
|                           response.state, request) |                           response.state, request) | ||||||
|                 if response.state == pr.RUNNING: |                 if response.state == pr.RUNNING: | ||||||
|                     request.transition_and_log_error(pr.RUNNING, logger=LOG) |                     request.transition_and_log_error(pr.RUNNING, logger=LOG) | ||||||
|                 elif response.state == pr.PROGRESS: |                 elif response.state == pr.EVENT: | ||||||
|                     request.on_progress(**response.data) |                     # Proxy the event + details to the task/request notifier... | ||||||
|  |                     event_type = response.data['event_type'] | ||||||
|  |                     details = response.data['details'] | ||||||
|  |                     request.notifier.notify(event_type, details) | ||||||
|                 elif response.state in (pr.FAILURE, pr.SUCCESS): |                 elif response.state in (pr.FAILURE, pr.SUCCESS): | ||||||
|                     moved = request.transition_and_log_error(response.state, |                     moved = request.transition_and_log_error(response.state, | ||||||
|                                                              logger=LOG) |                                                              logger=LOG) | ||||||
| @@ -181,8 +185,18 @@ class WorkerTaskExecutor(executor.TaskExecutor): | |||||||
|                      progress_callback, **kwargs): |                      progress_callback, **kwargs): | ||||||
|         """Submit task request to a worker.""" |         """Submit task request to a worker.""" | ||||||
|         request = pr.Request(task, task_uuid, action, arguments, |         request = pr.Request(task, task_uuid, action, arguments, | ||||||
|                              progress_callback, self._transition_timeout, |                              self._transition_timeout, **kwargs) | ||||||
|                              **kwargs) |  | ||||||
|  |         # Register the callback, so that we can proxy the progress correctly. | ||||||
|  |         if (progress_callback is not None and | ||||||
|  |                 request.notifier.can_be_registered( | ||||||
|  |                     task_atom.EVENT_UPDATE_PROGRESS)): | ||||||
|  |             request.notifier.register(task_atom.EVENT_UPDATE_PROGRESS, | ||||||
|  |                                       progress_callback) | ||||||
|  |             cleaner = functools.partial(request.notifier.deregister, | ||||||
|  |                                         task_atom.EVENT_UPDATE_PROGRESS, | ||||||
|  |                                         progress_callback) | ||||||
|  |             request.result.add_done_callback(lambda fut: cleaner()) | ||||||
|  |  | ||||||
|         # Get task's topic and publish request if topic was found. |         # Get task's topic and publish request if topic was found. | ||||||
|         topic = self._workers_cache.get_topic_by_task(request.task_cls) |         topic = self._workers_cache.get_topic_by_task(request.task_cls) | ||||||
|   | |||||||
| @@ -38,14 +38,14 @@ PENDING = 'PENDING' | |||||||
| RUNNING = 'RUNNING' | RUNNING = 'RUNNING' | ||||||
| SUCCESS = 'SUCCESS' | SUCCESS = 'SUCCESS' | ||||||
| FAILURE = 'FAILURE' | FAILURE = 'FAILURE' | ||||||
| PROGRESS = 'PROGRESS' | EVENT = 'EVENT' | ||||||
|  |  | ||||||
| # During these states the expiry is active (once out of these states the expiry | # During these states the expiry is active (once out of these states the expiry | ||||||
| # no longer matters, since we have no way of knowing how long a task will run | # no longer matters, since we have no way of knowing how long a task will run | ||||||
| # for). | # for). | ||||||
| WAITING_STATES = (WAITING, PENDING) | WAITING_STATES = (WAITING, PENDING) | ||||||
|  |  | ||||||
| _ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, PROGRESS) | _ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, EVENT) | ||||||
| _STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE) | _STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE) | ||||||
|  |  | ||||||
| # Transitions that a request state can go through. | # Transitions that a request state can go through. | ||||||
| @@ -219,22 +219,29 @@ class Request(Message): | |||||||
|         'required': ['task_cls', 'task_name', 'task_version', 'action'], |         'required': ['task_cls', 'task_name', 'task_version', 'action'], | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     def __init__(self, task, uuid, action, arguments, progress_callback, |     def __init__(self, task, uuid, action, arguments, timeout, **kwargs): | ||||||
|                  timeout, **kwargs): |  | ||||||
|         self._task = task |         self._task = task | ||||||
|         self._task_cls = reflection.get_class_name(task) |         self._task_cls = reflection.get_class_name(task) | ||||||
|         self._uuid = uuid |         self._uuid = uuid | ||||||
|         self._action = action |         self._action = action | ||||||
|         self._event = ACTION_TO_EVENT[action] |         self._event = ACTION_TO_EVENT[action] | ||||||
|         self._arguments = arguments |         self._arguments = arguments | ||||||
|         self._progress_callback = progress_callback |  | ||||||
|         self._kwargs = kwargs |         self._kwargs = kwargs | ||||||
|         self._watch = tt.StopWatch(duration=timeout).start() |         self._watch = tt.StopWatch(duration=timeout).start() | ||||||
|         self._state = WAITING |         self._state = WAITING | ||||||
|         self._lock = threading.Lock() |         self._lock = threading.Lock() | ||||||
|         self._created_on = timeutils.utcnow() |         self._created_on = timeutils.utcnow() | ||||||
|         self.result = futures.Future() |         self._result = futures.Future() | ||||||
|         self.result.atom = task |         self._result.atom = task | ||||||
|  |         self._notifier = task.notifier | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def result(self): | ||||||
|  |         return self._result | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def notifier(self): | ||||||
|  |         return self._notifier | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def uuid(self): |     def uuid(self): | ||||||
| @@ -293,9 +300,6 @@ class Request(Message): | |||||||
|     def set_result(self, result): |     def set_result(self, result): | ||||||
|         self.result.set_result((self._event, result)) |         self.result.set_result((self._event, result)) | ||||||
|  |  | ||||||
|     def on_progress(self, event_data, progress): |  | ||||||
|         self._progress_callback(self._task, event_data, progress) |  | ||||||
|  |  | ||||||
|     def transition_and_log_error(self, new_state, logger=None): |     def transition_and_log_error(self, new_state, logger=None): | ||||||
|         """Transitions *and* logs an error if that transitioning raises. |         """Transitions *and* logs an error if that transitioning raises. | ||||||
|  |  | ||||||
| @@ -362,7 +366,7 @@ class Response(Message): | |||||||
|             'data': { |             'data': { | ||||||
|                 "anyOf": [ |                 "anyOf": [ | ||||||
|                     { |                     { | ||||||
|                         "$ref": "#/definitions/progress", |                         "$ref": "#/definitions/event", | ||||||
|                     }, |                     }, | ||||||
|                     { |                     { | ||||||
|                         "$ref": "#/definitions/completion", |                         "$ref": "#/definitions/completion", | ||||||
| @@ -376,17 +380,17 @@ class Response(Message): | |||||||
|         "required": ["state", 'data'], |         "required": ["state", 'data'], | ||||||
|         "additionalProperties": False, |         "additionalProperties": False, | ||||||
|         "definitions": { |         "definitions": { | ||||||
|             "progress": { |             "event": { | ||||||
|                 "type": "object", |                 "type": "object", | ||||||
|                 "properties": { |                 "properties": { | ||||||
|                     'progress': { |                     'event_type': { | ||||||
|                         'type': 'number', |                         'type': 'string', | ||||||
|                     }, |                     }, | ||||||
|                     'event_data': { |                     'details': { | ||||||
|                         'type': 'object', |                         'type': 'object', | ||||||
|                     }, |                     }, | ||||||
|                 }, |                 }, | ||||||
|                 "required": ["progress", 'event_data'], |                 "required": ["event_type", 'details'], | ||||||
|                 "additionalProperties": False, |                 "additionalProperties": False, | ||||||
|             }, |             }, | ||||||
|             # Used when sending *only* request state changes (and no data is |             # Used when sending *only* request state changes (and no data is | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ from taskflow.engines.worker_based import protocol as pr | |||||||
| from taskflow.engines.worker_based import proxy | from taskflow.engines.worker_based import proxy | ||||||
| from taskflow import logging | from taskflow import logging | ||||||
| from taskflow.types import failure as ft | from taskflow.types import failure as ft | ||||||
|  | from taskflow.types import notifier as nt | ||||||
| from taskflow.utils import misc | from taskflow.utils import misc | ||||||
|  |  | ||||||
| LOG = logging.getLogger(__name__) | LOG = logging.getLogger(__name__) | ||||||
| @@ -73,18 +74,23 @@ class Server(object): | |||||||
|         All `failure.Failure` objects that have been converted to dict on the |         All `failure.Failure` objects that have been converted to dict on the | ||||||
|         remote side will now converted back to `failure.Failure` objects. |         remote side will now converted back to `failure.Failure` objects. | ||||||
|         """ |         """ | ||||||
|         action_args = dict(arguments=arguments, task_name=task_name) |         # These arguments will eventually be given to the task executor | ||||||
|  |         # so they need to be in a format it will accept (and using keyword | ||||||
|  |         # argument names that it accepts)... | ||||||
|  |         arguments = { | ||||||
|  |             'arguments': arguments, | ||||||
|  |         } | ||||||
|         if result is not None: |         if result is not None: | ||||||
|             data_type, data = result |             data_type, data = result | ||||||
|             if data_type == 'failure': |             if data_type == 'failure': | ||||||
|                 action_args['result'] = ft.Failure.from_dict(data) |                 arguments['result'] = ft.Failure.from_dict(data) | ||||||
|             else: |             else: | ||||||
|                 action_args['result'] = data |                 arguments['result'] = data | ||||||
|         if failures is not None: |         if failures is not None: | ||||||
|             action_args['failures'] = {} |             arguments['failures'] = {} | ||||||
|             for key, data in six.iteritems(failures): |             for key, data in six.iteritems(failures): | ||||||
|                 action_args['failures'][key] = ft.Failure.from_dict(data) |                 arguments['failures'][key] = ft.Failure.from_dict(data) | ||||||
|         return task_cls, action, action_args |         return (task_cls, task_name, action, arguments) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _parse_message(message): |     def _parse_message(message): | ||||||
| @@ -122,14 +128,13 @@ class Server(object): | |||||||
|                          exc_info=True) |                          exc_info=True) | ||||||
|         return published |         return published | ||||||
|  |  | ||||||
|     def _on_update_progress(self, reply_to, task_uuid, task, event_data, |     def _on_event(self, reply_to, task_uuid, event_type, details): | ||||||
|                             progress): |         """Send out a task event notification.""" | ||||||
|         """Send task update progress notification.""" |  | ||||||
|         # NOTE(harlowja): the executor that will trigger this using the |         # NOTE(harlowja): the executor that will trigger this using the | ||||||
|         # task notification/listener mechanism will handle logging if this |         # task notification/listener mechanism will handle logging if this | ||||||
|         # fails, so thats why capture is 'False' is used here. |         # fails, so thats why capture is 'False' is used here. | ||||||
|         self._reply(False, reply_to, task_uuid, pr.PROGRESS, |         self._reply(False, reply_to, task_uuid, pr.EVENT, | ||||||
|                     event_data=event_data, progress=progress) |                     event_type=event_type, details=details) | ||||||
|  |  | ||||||
|     def _process_notify(self, notify, message): |     def _process_notify(self, notify, message): | ||||||
|         """Process notify message and reply back.""" |         """Process notify message and reply back.""" | ||||||
| @@ -165,18 +170,15 @@ class Server(object): | |||||||
|                      message.delivery_tag, exc_info=True) |                      message.delivery_tag, exc_info=True) | ||||||
|             return |             return | ||||||
|         else: |         else: | ||||||
|             # prepare task progress callback |  | ||||||
|             progress_callback = functools.partial(self._on_update_progress, |  | ||||||
|                                                   reply_to, task_uuid) |  | ||||||
|             # prepare reply callback |             # prepare reply callback | ||||||
|             reply_callback = functools.partial(self._reply, True, reply_to, |             reply_callback = functools.partial(self._reply, True, reply_to, | ||||||
|                                                task_uuid) |                                                task_uuid) | ||||||
|  |  | ||||||
|         # parse request to get task name, action and action arguments |         # parse request to get task name, action and action arguments | ||||||
|         try: |         try: | ||||||
|             task_cls, action, action_args = self._parse_request(**request) |             bundle = self._parse_request(**request) | ||||||
|             action_args.update(task_uuid=task_uuid, |             task_cls, task_name, action, arguments = bundle | ||||||
|                                progress_callback=progress_callback) |             arguments['task_uuid'] = task_uuid | ||||||
|         except ValueError: |         except ValueError: | ||||||
|             with misc.capture_failure() as failure: |             with misc.capture_failure() as failure: | ||||||
|                 LOG.warn("Failed to parse request contents from message %r", |                 LOG.warn("Failed to parse request contents from message %r", | ||||||
| @@ -205,13 +207,37 @@ class Server(object): | |||||||
|                              message.delivery_tag, exc_info=True) |                              message.delivery_tag, exc_info=True) | ||||||
|                     reply_callback(result=failure.to_dict()) |                     reply_callback(result=failure.to_dict()) | ||||||
|                     return |                     return | ||||||
|  |             else: | ||||||
|  |                 try: | ||||||
|  |                     task = endpoint.generate(name=task_name) | ||||||
|  |                 except Exception: | ||||||
|  |                     with misc.capture_failure() as failure: | ||||||
|  |                         LOG.warn("The '%s' task '%s' generation for request" | ||||||
|  |                                  " message %r failed", endpoint, action, | ||||||
|  |                                  message.delivery_tag, exc_info=True) | ||||||
|  |                         reply_callback(result=failure.to_dict()) | ||||||
|  |                         return | ||||||
|                 else: |                 else: | ||||||
|                     if not reply_callback(state=pr.RUNNING): |                     if not reply_callback(state=pr.RUNNING): | ||||||
|                         return |                         return | ||||||
|  |  | ||||||
|         # perform task action |         # associate *any* events this task emits with a proxy that will | ||||||
|  |         # emit them back to the engine... for handling at the engine side | ||||||
|  |         # of things... | ||||||
|  |         if task.notifier.can_be_registered(nt.Notifier.ANY): | ||||||
|  |             task.notifier.register(nt.Notifier.ANY, | ||||||
|  |                                    functools.partial(self._on_event, | ||||||
|  |                                                      reply_to, task_uuid)) | ||||||
|  |         elif isinstance(task.notifier, nt.RestrictedNotifier): | ||||||
|  |             # only proxy the allowable events then... | ||||||
|  |             for event_type in task.notifier.events_iter(): | ||||||
|  |                 task.notifier.register(event_type, | ||||||
|  |                                        functools.partial(self._on_event, | ||||||
|  |                                                          reply_to, task_uuid)) | ||||||
|  |  | ||||||
|  |         # perform the task action | ||||||
|         try: |         try: | ||||||
|             result = handler(**action_args) |             result = handler(task, **arguments) | ||||||
|         except Exception: |         except Exception: | ||||||
|             with misc.capture_failure() as failure: |             with misc.capture_failure() as failure: | ||||||
|                 LOG.warn("The '%s' endpoint '%s' execution for request" |                 LOG.warn("The '%s' endpoint '%s' execution for request" | ||||||
|   | |||||||
							
								
								
									
										166
									
								
								taskflow/task.py
									
									
									
									
									
								
							
							
						
						
									
										166
									
								
								taskflow/task.py
									
									
									
									
									
								
							| @@ -16,14 +16,13 @@ | |||||||
| #    under the License. | #    under the License. | ||||||
|  |  | ||||||
| import abc | import abc | ||||||
| import collections |  | ||||||
| import contextlib |  | ||||||
| import copy | import copy | ||||||
|  |  | ||||||
| import six | import six | ||||||
|  |  | ||||||
| from taskflow import atom | from taskflow import atom | ||||||
| from taskflow import logging | from taskflow import logging | ||||||
|  | from taskflow.types import notifier | ||||||
| from taskflow.utils import misc | from taskflow.utils import misc | ||||||
| from taskflow.utils import reflection | from taskflow.utils import reflection | ||||||
|  |  | ||||||
| @@ -51,18 +50,28 @@ class BaseTask(atom.Atom): | |||||||
|     same piece of work. |     same piece of work. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     # Known events this task can have callbacks bound to (others that are not |     # Known internal events this task can have callbacks bound to (others that | ||||||
|     # in this set/tuple will not be able to be bound); this should be updated |     # are not in this set/tuple will not be able to be bound); this should be | ||||||
|     # and/or extended in subclasses as needed to enable or disable new or |     # updated and/or extended in subclasses as needed to enable or disable new | ||||||
|     # existing events... |     # or existing internal events... | ||||||
|     TASK_EVENTS = (EVENT_UPDATE_PROGRESS,) |     TASK_EVENTS = (EVENT_UPDATE_PROGRESS,) | ||||||
|  |  | ||||||
|     def __init__(self, name, provides=None, inject=None): |     def __init__(self, name, provides=None, inject=None): | ||||||
|         if name is None: |         if name is None: | ||||||
|             name = reflection.get_class_name(self) |             name = reflection.get_class_name(self) | ||||||
|         super(BaseTask, self).__init__(name, provides, inject=inject) |         super(BaseTask, self).__init__(name, provides, inject=inject) | ||||||
|         # Map of events => lists of callbacks to invoke on task events. |         self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS) | ||||||
|         self._events_listeners = collections.defaultdict(list) |  | ||||||
|  |     @property | ||||||
|  |     def notifier(self): | ||||||
|  |         """Internal notification dispatcher/registry. | ||||||
|  |  | ||||||
|  |         A notification object that will dispatch events that occur related | ||||||
|  |         to *internal* notifications that the task internally emits to | ||||||
|  |         listeners (for example for progress status updates, telling others | ||||||
|  |         that a task has reached 50% completion...). | ||||||
|  |         """ | ||||||
|  |         return self._notifier | ||||||
|  |  | ||||||
|     def pre_execute(self): |     def pre_execute(self): | ||||||
|         """Code to be run prior to executing the task. |         """Code to be run prior to executing the task. | ||||||
| @@ -138,152 +147,31 @@ class BaseTask(atom.Atom): | |||||||
|     def copy(self, retain_listeners=True): |     def copy(self, retain_listeners=True): | ||||||
|         """Clone/copy this task. |         """Clone/copy this task. | ||||||
|  |  | ||||||
|         :param retain_listeners: retain the attached listeners when cloning, |         :param retain_listeners: retain the attached notification listeners | ||||||
|                                  when false the listeners will be emptied, when |                                  when cloning, when false the listeners will | ||||||
|                                  true the listeners will be copied and retained |                                  be emptied, when true the listeners will be | ||||||
|  |                                  copied and retained | ||||||
|  |  | ||||||
|         :return: the copied task |         :return: the copied task | ||||||
|         """ |         """ | ||||||
|         c = copy.copy(self) |         c = copy.copy(self) | ||||||
|         c._events_listeners = collections.defaultdict(list) |         c._notifier = self._notifier.copy() | ||||||
|         if retain_listeners: |         if not retain_listeners: | ||||||
|             for event_name, listeners in six.iteritems(self._events_listeners): |             c._notifier.reset() | ||||||
|                 c._events_listeners[event_name] = listeners[:] |  | ||||||
|         return c |         return c | ||||||
|  |  | ||||||
|     def update_progress(self, progress, **kwargs): |     def update_progress(self, progress): | ||||||
|         """Update task progress and notify all registered listeners. |         """Update task progress and notify all registered listeners. | ||||||
|  |  | ||||||
|         :param progress: task progress float value between 0.0 and 1.0 |         :param progress: task progress float value between 0.0 and 1.0 | ||||||
|         :param kwargs: any keyword arguments that are tied to the specific |  | ||||||
|                        progress value. |  | ||||||
|         """ |         """ | ||||||
|         def on_clamped(): |         def on_clamped(): | ||||||
|             LOG.warn("Progress value must be greater or equal to 0.0 or less" |             LOG.warn("Progress value must be greater or equal to 0.0 or less" | ||||||
|                      " than or equal to 1.0 instead of being '%s'", progress) |                      " than or equal to 1.0 instead of being '%s'", progress) | ||||||
|         cleaned_progress = misc.clamp(progress, 0.0, 1.0, |         cleaned_progress = misc.clamp(progress, 0.0, 1.0, | ||||||
|                                       on_clamped=on_clamped) |                                       on_clamped=on_clamped) | ||||||
|         self.trigger(EVENT_UPDATE_PROGRESS, cleaned_progress, **kwargs) |         self._notifier.notify(EVENT_UPDATE_PROGRESS, | ||||||
|  |                               {'progress': cleaned_progress}) | ||||||
|     def trigger(self, event_name, *args, **kwargs): |  | ||||||
|         """Execute all callbacks registered for the given event type. |  | ||||||
|  |  | ||||||
|         NOTE(harlowja): if a bound callback raises an exception it will be |  | ||||||
|                         logged (at a ``WARNING`` level) and the exception |  | ||||||
|                         will be dropped. |  | ||||||
|  |  | ||||||
|         :param event_name: event name to trigger |  | ||||||
|         :param args: arbitrary positional arguments passed to the triggered |  | ||||||
|                      callbacks (if any are matched), these will be in addition |  | ||||||
|                      to any ``kwargs`` provided on binding (these are passed |  | ||||||
|                      as positional arguments to the callback). |  | ||||||
|         :param kwargs: arbitrary keyword arguments passed to the triggered |  | ||||||
|                      callbacks (if any are matched), these will be in addition |  | ||||||
|                      to any ``kwargs`` provided on binding (these are passed |  | ||||||
|                      as keyword arguments to the callback). |  | ||||||
|         """ |  | ||||||
|         for (cb, event_data) in self._events_listeners.get(event_name, []): |  | ||||||
|             try: |  | ||||||
|                 cb(self, event_data, *args, **kwargs) |  | ||||||
|             except Exception: |  | ||||||
|                 LOG.warn("Failed calling callback `%s` on event '%s'", |  | ||||||
|                          reflection.get_callable_name(cb), event_name, |  | ||||||
|                          exc_info=True) |  | ||||||
|  |  | ||||||
|     @contextlib.contextmanager |  | ||||||
|     def autobind(self, event_name, callback, **kwargs): |  | ||||||
|         """Binds & unbinds a given callback to the task. |  | ||||||
|  |  | ||||||
|         This function binds and unbinds using the context manager protocol. |  | ||||||
|         When events are triggered on the task of the given event name this |  | ||||||
|         callback will automatically be called with the provided |  | ||||||
|         keyword arguments as the first argument (further arguments may be |  | ||||||
|         provided by the entity triggering the event). |  | ||||||
|  |  | ||||||
|         The arguments are interpreted as for :func:`bind() <bind>`. |  | ||||||
|         """ |  | ||||||
|         bound = False |  | ||||||
|         if callback is not None: |  | ||||||
|             try: |  | ||||||
|                 self.bind(event_name, callback, **kwargs) |  | ||||||
|                 bound = True |  | ||||||
|             except ValueError: |  | ||||||
|                 LOG.warn("Failed binding callback `%s` as a receiver of" |  | ||||||
|                          " event '%s' notifications emitted from task '%s'", |  | ||||||
|                          reflection.get_callable_name(callback), event_name, |  | ||||||
|                          self, exc_info=True) |  | ||||||
|         try: |  | ||||||
|             yield self |  | ||||||
|         finally: |  | ||||||
|             if bound: |  | ||||||
|                 self.unbind(event_name, callback) |  | ||||||
|  |  | ||||||
|     def bind(self, event_name, callback, **kwargs): |  | ||||||
|         """Attach a callback to be triggered on a task event. |  | ||||||
|  |  | ||||||
|         Callbacks should *not* be bound, modified, or removed after execution |  | ||||||
|         has commenced (they may be adjusted after execution has finished). This |  | ||||||
|         is primarily due to the need to preserve the callbacks that exist at |  | ||||||
|         execution time for engines which run tasks remotely or out of |  | ||||||
|         process (so that those engines can correctly proxy back transmitted |  | ||||||
|         events). |  | ||||||
|  |  | ||||||
|         Callbacks should also be *quick* to execute so that the engine calling |  | ||||||
|         them can continue execution in a timely manner (if long running |  | ||||||
|         callbacks need to exist, consider creating a separate pool + queue |  | ||||||
|         for those that the attached callbacks put long running operations into |  | ||||||
|         for execution by other entities). |  | ||||||
|  |  | ||||||
|         :param event_name: event type name |  | ||||||
|         :param callback: callable to execute each time event is triggered |  | ||||||
|         :param kwargs: optional named parameters that will be passed to the |  | ||||||
|                        callable object as a dictionary to the callbacks |  | ||||||
|                        *second* positional parameter. |  | ||||||
|         :raises ValueError: if invalid event type, or callback is passed |  | ||||||
|         """ |  | ||||||
|         if event_name not in self.TASK_EVENTS: |  | ||||||
|             raise ValueError("Unknown task event '%s', can only bind" |  | ||||||
|                              " to events %s" % (event_name, self.TASK_EVENTS)) |  | ||||||
|         if callback is not None: |  | ||||||
|             if not six.callable(callback): |  | ||||||
|                 raise ValueError("Event handler callback must be callable") |  | ||||||
|             self._events_listeners[event_name].append((callback, kwargs)) |  | ||||||
|  |  | ||||||
|     def unbind(self, event_name, callback=None): |  | ||||||
|         """Remove a previously-attached event callback from the task. |  | ||||||
|  |  | ||||||
|         If a callback is not passed, then this will unbind *all* event |  | ||||||
|         callbacks for the provided event. If multiple of the same callbacks |  | ||||||
|         are bound, then the first match is removed (and only the first match). |  | ||||||
|  |  | ||||||
|         :param event_name: event type |  | ||||||
|         :param callback: callback previously bound |  | ||||||
|  |  | ||||||
|         :rtype: boolean |  | ||||||
|         :return: whether anything was removed |  | ||||||
|         """ |  | ||||||
|         removed_any = False |  | ||||||
|         if not callback: |  | ||||||
|             removed_any = self._events_listeners.pop(event_name, removed_any) |  | ||||||
|         else: |  | ||||||
|             event_listeners = self._events_listeners.get(event_name, []) |  | ||||||
|             for i, (cb, _event_data) in enumerate(event_listeners): |  | ||||||
|                 if reflection.is_same_callback(cb, callback): |  | ||||||
|                     # NOTE(harlowja): its safe to do this as long as we stop |  | ||||||
|                     # iterating after we do the removal, otherwise its not |  | ||||||
|                     # safe (since this could have resized the list). |  | ||||||
|                     event_listeners.pop(i) |  | ||||||
|                     removed_any = True |  | ||||||
|                     break |  | ||||||
|         return bool(removed_any) |  | ||||||
|  |  | ||||||
|     def listeners_iter(self): |  | ||||||
|         """Return an iterator over the mapping of event => callbacks bound.""" |  | ||||||
|         for event_name in list(six.iterkeys(self._events_listeners)): |  | ||||||
|             # Use get() just incase it was removed while iterating... |  | ||||||
|             event_listeners = self._events_listeners.get(event_name, []) |  | ||||||
|             if event_listeners: |  | ||||||
|                 yield (event_name, event_listeners[:]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Task(BaseTask): | class Task(BaseTask): | ||||||
|   | |||||||
| @@ -84,6 +84,31 @@ class NotifierTest(test.TestCase): | |||||||
|         self.assertRaises(ValueError, notifier.register, |         self.assertRaises(ValueError, notifier.register, | ||||||
|                           nt.Notifier.ANY, 2) |                           nt.Notifier.ANY, 2) | ||||||
|  |  | ||||||
|  |     def test_restricted_notifier(self): | ||||||
|  |         notifier = nt.RestrictedNotifier(['a', 'b']) | ||||||
|  |         self.assertRaises(ValueError, notifier.register, | ||||||
|  |                           'c', lambda *args, **kargs: None) | ||||||
|  |         notifier.register('b', lambda *args, **kargs: None) | ||||||
|  |         self.assertEqual(1, len(notifier)) | ||||||
|  |  | ||||||
|  |     def test_restricted_notifier_any(self): | ||||||
|  |         notifier = nt.RestrictedNotifier(['a', 'b']) | ||||||
|  |         self.assertRaises(ValueError, notifier.register, | ||||||
|  |                           'c', lambda *args, **kargs: None) | ||||||
|  |         notifier.register('b', lambda *args, **kargs: None) | ||||||
|  |         self.assertEqual(1, len(notifier)) | ||||||
|  |         notifier.register(nt.RestrictedNotifier.ANY, | ||||||
|  |                           lambda *args, **kargs: None) | ||||||
|  |         self.assertEqual(2, len(notifier)) | ||||||
|  |  | ||||||
|  |     def test_restricted_notifier_no_any(self): | ||||||
|  |         notifier = nt.RestrictedNotifier(['a', 'b'], allow_any=False) | ||||||
|  |         self.assertRaises(ValueError, notifier.register, | ||||||
|  |                           nt.RestrictedNotifier.ANY, | ||||||
|  |                           lambda *args, **kargs: None) | ||||||
|  |         notifier.register('b', lambda *args, **kargs: None) | ||||||
|  |         self.assertEqual(1, len(notifier)) | ||||||
|  |  | ||||||
|     def test_selective_notify(self): |     def test_selective_notify(self): | ||||||
|         call_counts = collections.defaultdict(list) |         call_counts = collections.defaultdict(list) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -39,7 +39,12 @@ class ProgressTask(task.Task): | |||||||
|  |  | ||||||
| class ProgressTaskWithDetails(task.Task): | class ProgressTaskWithDetails(task.Task): | ||||||
|     def execute(self): |     def execute(self): | ||||||
|         self.update_progress(0.5, test='test data', foo='bar') |         details = { | ||||||
|  |             'progress': 0.5, | ||||||
|  |             'test': 'test data', | ||||||
|  |             'foo': 'bar', | ||||||
|  |         } | ||||||
|  |         self.notifier.notify(task.EVENT_UPDATE_PROGRESS, details) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestProgress(test.TestCase): | class TestProgress(test.TestCase): | ||||||
| @@ -60,12 +65,12 @@ class TestProgress(test.TestCase): | |||||||
|     def test_sanity_progress(self): |     def test_sanity_progress(self): | ||||||
|         fired_events = [] |         fired_events = [] | ||||||
|  |  | ||||||
|         def notify_me(task, event_data, progress): |         def notify_me(event_type, details): | ||||||
|             fired_events.append(progress) |             fired_events.append(details.pop('progress')) | ||||||
|  |  | ||||||
|         ev_count = 5 |         ev_count = 5 | ||||||
|         t = ProgressTask("test", ev_count) |         t = ProgressTask("test", ev_count) | ||||||
|         t.bind('update_progress', notify_me) |         t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me) | ||||||
|         flo = lf.Flow("test") |         flo = lf.Flow("test") | ||||||
|         flo.add(t) |         flo.add(t) | ||||||
|         e = self._make_engine(flo) |         e = self._make_engine(flo) | ||||||
| @@ -77,11 +82,11 @@ class TestProgress(test.TestCase): | |||||||
|     def test_no_segments_progress(self): |     def test_no_segments_progress(self): | ||||||
|         fired_events = [] |         fired_events = [] | ||||||
|  |  | ||||||
|         def notify_me(task, event_data, progress): |         def notify_me(event_type, details): | ||||||
|             fired_events.append(progress) |             fired_events.append(details.pop('progress')) | ||||||
|  |  | ||||||
|         t = ProgressTask("test", 0) |         t = ProgressTask("test", 0) | ||||||
|         t.bind('update_progress', notify_me) |         t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me) | ||||||
|         flo = lf.Flow("test") |         flo = lf.Flow("test") | ||||||
|         flo.add(t) |         flo.add(t) | ||||||
|         e = self._make_engine(flo) |         e = self._make_engine(flo) | ||||||
| @@ -121,12 +126,12 @@ class TestProgress(test.TestCase): | |||||||
|     def test_dual_storage_progress(self): |     def test_dual_storage_progress(self): | ||||||
|         fired_events = [] |         fired_events = [] | ||||||
|  |  | ||||||
|         def notify_me(task, event_data, progress): |         def notify_me(event_type, details): | ||||||
|             fired_events.append(progress) |             fired_events.append(details.pop('progress')) | ||||||
|  |  | ||||||
|         with contextlib.closing(impl_memory.MemoryBackend({})) as be: |         with contextlib.closing(impl_memory.MemoryBackend({})) as be: | ||||||
|             t = ProgressTask("test", 5) |             t = ProgressTask("test", 5) | ||||||
|             t.bind('update_progress', notify_me) |             t.notifier.register(task.EVENT_UPDATE_PROGRESS, notify_me) | ||||||
|             flo = lf.Flow("test") |             flo = lf.Flow("test") | ||||||
|             flo.add(t) |             flo.add(t) | ||||||
|             b, fd = p_utils.temporary_flow_detail(be) |             b, fd = p_utils.temporary_flow_detail(be) | ||||||
|   | |||||||
| @@ -17,7 +17,7 @@ | |||||||
| from taskflow import task | from taskflow import task | ||||||
| from taskflow import test | from taskflow import test | ||||||
| from taskflow.test import mock | from taskflow.test import mock | ||||||
| from taskflow.utils import reflection | from taskflow.types import notifier | ||||||
|  |  | ||||||
|  |  | ||||||
| class MyTask(task.Task): | class MyTask(task.Task): | ||||||
| @@ -198,11 +198,11 @@ class TaskTest(test.TestCase): | |||||||
|         values = [0.0, 0.5, 1.0] |         values = [0.0, 0.5, 1.0] | ||||||
|         result = [] |         result = [] | ||||||
|  |  | ||||||
|         def progress_callback(task, event_data, progress): |         def progress_callback(event_type, details): | ||||||
|             result.append(progress) |             result.append(details.pop('progress')) | ||||||
|  |  | ||||||
|         a_task = ProgressTask() |         a_task = ProgressTask() | ||||||
|         with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback) | ||||||
|         a_task.execute(values) |         a_task.execute(values) | ||||||
|         self.assertEqual(result, values) |         self.assertEqual(result, values) | ||||||
|  |  | ||||||
| @@ -210,11 +210,11 @@ class TaskTest(test.TestCase): | |||||||
|     def test_update_progress_lower_bound(self, mocked_warn): |     def test_update_progress_lower_bound(self, mocked_warn): | ||||||
|         result = [] |         result = [] | ||||||
|  |  | ||||||
|         def progress_callback(task, event_data, progress): |         def progress_callback(event_type, details): | ||||||
|             result.append(progress) |             result.append(details.pop('progress')) | ||||||
|  |  | ||||||
|         a_task = ProgressTask() |         a_task = ProgressTask() | ||||||
|         with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback) | ||||||
|         a_task.execute([-1.0, -0.5, 0.0]) |         a_task.execute([-1.0, -0.5, 0.0]) | ||||||
|         self.assertEqual(result, [0.0, 0.0, 0.0]) |         self.assertEqual(result, [0.0, 0.0, 0.0]) | ||||||
|         self.assertEqual(mocked_warn.call_count, 2) |         self.assertEqual(mocked_warn.call_count, 2) | ||||||
| @@ -223,81 +223,87 @@ class TaskTest(test.TestCase): | |||||||
|     def test_update_progress_upper_bound(self, mocked_warn): |     def test_update_progress_upper_bound(self, mocked_warn): | ||||||
|         result = [] |         result = [] | ||||||
|  |  | ||||||
|         def progress_callback(task, event_data, progress): |         def progress_callback(event_type, details): | ||||||
|             result.append(progress) |             result.append(details.pop('progress')) | ||||||
|  |  | ||||||
|         a_task = ProgressTask() |         a_task = ProgressTask() | ||||||
|         with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback) | ||||||
|         a_task.execute([1.0, 1.5, 2.0]) |         a_task.execute([1.0, 1.5, 2.0]) | ||||||
|         self.assertEqual(result, [1.0, 1.0, 1.0]) |         self.assertEqual(result, [1.0, 1.0, 1.0]) | ||||||
|         self.assertEqual(mocked_warn.call_count, 2) |         self.assertEqual(mocked_warn.call_count, 2) | ||||||
|  |  | ||||||
|     @mock.patch.object(task.LOG, 'warn') |     @mock.patch.object(notifier.LOG, 'warn') | ||||||
|     def test_update_progress_handler_failure(self, mocked_warn): |     def test_update_progress_handler_failure(self, mocked_warn): | ||||||
|  |  | ||||||
|         def progress_callback(*args, **kwargs): |         def progress_callback(*args, **kwargs): | ||||||
|             raise Exception('Woot!') |             raise Exception('Woot!') | ||||||
|  |  | ||||||
|         a_task = ProgressTask() |         a_task = ProgressTask() | ||||||
|         with a_task.autobind(task.EVENT_UPDATE_PROGRESS, progress_callback): |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, progress_callback) | ||||||
|         a_task.execute([0.5]) |         a_task.execute([0.5]) | ||||||
|         mocked_warn.assert_called_once_with( |         mocked_warn.assert_called_once() | ||||||
|             mock.ANY, reflection.get_callable_name(progress_callback), |  | ||||||
|             task.EVENT_UPDATE_PROGRESS, exc_info=mock.ANY) |  | ||||||
|  |  | ||||||
|     def test_autobind_handler_is_none(self): |     def test_register_handler_is_none(self): | ||||||
|         a_task = MyTask() |         a_task = MyTask() | ||||||
|         with a_task.autobind(task.EVENT_UPDATE_PROGRESS, None): |         self.assertRaises(ValueError, a_task.notifier.register, | ||||||
|             self.assertEqual(len(list(a_task.listeners_iter())), 0) |                           task.EVENT_UPDATE_PROGRESS, None) | ||||||
|  |         self.assertEqual(len(a_task.notifier), 0) | ||||||
|  |  | ||||||
|     def test_unbind_any_handler(self): |     def test_deregister_any_handler(self): | ||||||
|         a_task = MyTask() |         a_task = MyTask() | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 0) |         self.assertEqual(len(a_task.notifier), 0) | ||||||
|         a_task.bind(task.EVENT_UPDATE_PROGRESS, lambda: None) |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 1) |                                  lambda event_type, details: None) | ||||||
|         self.assertTrue(a_task.unbind(task.EVENT_UPDATE_PROGRESS)) |         self.assertEqual(len(a_task.notifier), 1) | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 0) |         a_task.notifier.deregister_event(task.EVENT_UPDATE_PROGRESS) | ||||||
|  |         self.assertEqual(len(a_task.notifier), 0) | ||||||
|  |  | ||||||
|     def test_unbind_any_handler_empty_listeners(self): |     def test_deregister_any_handler_empty_listeners(self): | ||||||
|         a_task = MyTask() |         a_task = MyTask() | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 0) |         self.assertEqual(len(a_task.notifier), 0) | ||||||
|         self.assertFalse(a_task.unbind(task.EVENT_UPDATE_PROGRESS)) |         self.assertFalse(a_task.notifier.deregister_event( | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 0) |             task.EVENT_UPDATE_PROGRESS)) | ||||||
|  |         self.assertEqual(len(a_task.notifier), 0) | ||||||
|  |  | ||||||
|     def test_unbind_non_existent_listener(self): |     def test_deregister_non_existent_listener(self): | ||||||
|         handler1 = lambda: None |         handler1 = lambda event_type, details: None | ||||||
|         handler2 = lambda: None |         handler2 = lambda event_type, details: None | ||||||
|         a_task = MyTask() |         a_task = MyTask() | ||||||
|         a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1) | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 1) |         self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1) | ||||||
|         self.assertFalse(a_task.unbind(task.EVENT_UPDATE_PROGRESS, handler2)) |         a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler2) | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 1) |         self.assertEqual(len(list(a_task.notifier.listeners_iter())), 1) | ||||||
|  |         a_task.notifier.deregister(task.EVENT_UPDATE_PROGRESS, handler1) | ||||||
|  |         self.assertEqual(len(list(a_task.notifier.listeners_iter())), 0) | ||||||
|  |  | ||||||
|     def test_bind_not_callable(self): |     def test_bind_not_callable(self): | ||||||
|         task = MyTask() |         a_task = MyTask() | ||||||
|         self.assertRaises(ValueError, task.bind, 'update_progress', 2) |         self.assertRaises(ValueError, a_task.notifier.register, | ||||||
|  |                           task.EVENT_UPDATE_PROGRESS, 2) | ||||||
|  |  | ||||||
|     def test_copy_no_listeners(self): |     def test_copy_no_listeners(self): | ||||||
|         handler1 = lambda: None |         handler1 = lambda event_type, details: None | ||||||
|         a_task = MyTask() |         a_task = MyTask() | ||||||
|         a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1) | ||||||
|         b_task = a_task.copy(retain_listeners=False) |         b_task = a_task.copy(retain_listeners=False) | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 1) |         self.assertEqual(len(a_task.notifier), 1) | ||||||
|         self.assertEqual(len(list(b_task.listeners_iter())), 0) |         self.assertEqual(len(b_task.notifier), 0) | ||||||
|  |  | ||||||
|     def test_copy_listeners(self): |     def test_copy_listeners(self): | ||||||
|         handler1 = lambda: None |         handler1 = lambda event_type, details: None | ||||||
|         handler2 = lambda: None |         handler2 = lambda event_type, details: None | ||||||
|         a_task = MyTask() |         a_task = MyTask() | ||||||
|         a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) |         a_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler1) | ||||||
|         b_task = a_task.copy() |         b_task = a_task.copy() | ||||||
|         self.assertEqual(len(list(b_task.listeners_iter())), 1) |         self.assertEqual(len(b_task.notifier), 1) | ||||||
|         self.assertTrue(a_task.unbind(task.EVENT_UPDATE_PROGRESS)) |         self.assertTrue(a_task.notifier.deregister_event( | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 0) |             task.EVENT_UPDATE_PROGRESS)) | ||||||
|         self.assertEqual(len(list(b_task.listeners_iter())), 1) |         self.assertEqual(len(a_task.notifier), 0) | ||||||
|         b_task.bind(task.EVENT_UPDATE_PROGRESS, handler2) |         self.assertEqual(len(b_task.notifier), 1) | ||||||
|         listeners = dict(list(b_task.listeners_iter())) |         b_task.notifier.register(task.EVENT_UPDATE_PROGRESS, handler2) | ||||||
|  |         listeners = dict(list(b_task.notifier.listeners_iter())) | ||||||
|         self.assertEqual(len(listeners[task.EVENT_UPDATE_PROGRESS]), 2) |         self.assertEqual(len(listeners[task.EVENT_UPDATE_PROGRESS]), 2) | ||||||
|         self.assertEqual(len(list(a_task.listeners_iter())), 0) |         self.assertEqual(len(a_task.notifier), 0) | ||||||
|  |  | ||||||
|  |  | ||||||
| class FunctorTaskTest(test.TestCase): | class FunctorTaskTest(test.TestCase): | ||||||
|   | |||||||
| @@ -42,14 +42,14 @@ class TestEndpoint(test.TestCase): | |||||||
|         self.task_result = 1 |         self.task_result = 1 | ||||||
|  |  | ||||||
|     def test_creation(self): |     def test_creation(self): | ||||||
|         task = self.task_ep._get_task() |         task = self.task_ep.generate() | ||||||
|         self.assertEqual(self.task_ep.name, self.task_cls_name) |         self.assertEqual(self.task_ep.name, self.task_cls_name) | ||||||
|         self.assertIsInstance(task, self.task_cls) |         self.assertIsInstance(task, self.task_cls) | ||||||
|         self.assertEqual(task.name, self.task_cls_name) |         self.assertEqual(task.name, self.task_cls_name) | ||||||
|  |  | ||||||
|     def test_creation_with_task_name(self): |     def test_creation_with_task_name(self): | ||||||
|         task_name = 'test' |         task_name = 'test' | ||||||
|         task = self.task_ep._get_task(name=task_name) |         task = self.task_ep.generate(name=task_name) | ||||||
|         self.assertEqual(self.task_ep.name, self.task_cls_name) |         self.assertEqual(self.task_ep.name, self.task_cls_name) | ||||||
|         self.assertIsInstance(task, self.task_cls) |         self.assertIsInstance(task, self.task_cls) | ||||||
|         self.assertEqual(task.name, task_name) |         self.assertEqual(task.name, task_name) | ||||||
| @@ -58,20 +58,22 @@ class TestEndpoint(test.TestCase): | |||||||
|         # NOTE(skudriashev): Exception is expected here since task |         # NOTE(skudriashev): Exception is expected here since task | ||||||
|         # is created without any arguments passing to its constructor. |         # is created without any arguments passing to its constructor. | ||||||
|         endpoint = ep.Endpoint(Task) |         endpoint = ep.Endpoint(Task) | ||||||
|         self.assertRaises(TypeError, endpoint._get_task) |         self.assertRaises(TypeError, endpoint.generate) | ||||||
|  |  | ||||||
|     def test_to_str(self): |     def test_to_str(self): | ||||||
|         self.assertEqual(str(self.task_ep), self.task_cls_name) |         self.assertEqual(str(self.task_ep), self.task_cls_name) | ||||||
|  |  | ||||||
|     def test_execute(self): |     def test_execute(self): | ||||||
|         result = self.task_ep.execute(task_name=self.task_cls_name, |         task = self.task_ep.generate(self.task_cls_name) | ||||||
|  |         result = self.task_ep.execute(task, | ||||||
|                                       task_uuid=self.task_uuid, |                                       task_uuid=self.task_uuid, | ||||||
|                                       arguments=self.task_args, |                                       arguments=self.task_args, | ||||||
|                                       progress_callback=None) |                                       progress_callback=None) | ||||||
|         self.assertEqual(result, self.task_result) |         self.assertEqual(result, self.task_result) | ||||||
|  |  | ||||||
|     def test_revert(self): |     def test_revert(self): | ||||||
|         result = self.task_ep.revert(task_name=self.task_cls_name, |         task = self.task_ep.generate(self.task_cls_name) | ||||||
|  |         result = self.task_ep.revert(task, | ||||||
|                                      task_uuid=self.task_uuid, |                                      task_uuid=self.task_uuid, | ||||||
|                                      arguments=self.task_args, |                                      arguments=self.task_args, | ||||||
|                                      progress_callback=None, |                                      progress_callback=None, | ||||||
|   | |||||||
| @@ -21,6 +21,7 @@ from oslo.utils import timeutils | |||||||
|  |  | ||||||
| from taskflow.engines.worker_based import executor | from taskflow.engines.worker_based import executor | ||||||
| from taskflow.engines.worker_based import protocol as pr | from taskflow.engines.worker_based import protocol as pr | ||||||
|  | from taskflow import task as task_atom | ||||||
| from taskflow import test | from taskflow import test | ||||||
| from taskflow.test import mock | from taskflow.test import mock | ||||||
| from taskflow.tests import utils as test_utils | from taskflow.tests import utils as test_utils | ||||||
| @@ -102,13 +103,18 @@ class TestWorkerTaskExecutor(test.MockTestCase): | |||||||
|         self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) |         self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) | ||||||
|  |  | ||||||
|     def test_on_message_response_state_progress(self): |     def test_on_message_response_state_progress(self): | ||||||
|         response = pr.Response(pr.PROGRESS, progress=1.0) |         response = pr.Response(pr.EVENT, | ||||||
|  |                                event_type=task_atom.EVENT_UPDATE_PROGRESS, | ||||||
|  |                                details={'progress': 1.0}) | ||||||
|         ex = self.executor() |         ex = self.executor() | ||||||
|         ex._requests_cache[self.task_uuid] = self.request_inst_mock |         ex._requests_cache[self.task_uuid] = self.request_inst_mock | ||||||
|         ex._process_response(response.to_dict(), self.message_mock) |         ex._process_response(response.to_dict(), self.message_mock) | ||||||
|  |  | ||||||
|         self.assertEqual(self.request_inst_mock.mock_calls, |         expected_calls = [ | ||||||
|                          [mock.call.on_progress(progress=1.0)]) |             mock.call.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS, | ||||||
|  |                                       {'progress': 1.0}), | ||||||
|  |         ] | ||||||
|  |         self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) | ||||||
|  |  | ||||||
|     def test_on_message_response_state_failure(self): |     def test_on_message_response_state_failure(self): | ||||||
|         a_failure = failure.Failure.from_exception(Exception('test')) |         a_failure = failure.Failure.from_exception(Exception('test')) | ||||||
| @@ -211,7 +217,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): | |||||||
|  |  | ||||||
|         expected_calls = [ |         expected_calls = [ | ||||||
|             mock.call.Request(self.task, self.task_uuid, 'execute', |             mock.call.Request(self.task, self.task_uuid, 'execute', | ||||||
|                               self.task_args, None, self.timeout), |                               self.task_args, self.timeout), | ||||||
|             mock.call.request.transition_and_log_error(pr.PENDING, |             mock.call.request.transition_and_log_error(pr.PENDING, | ||||||
|                                                        logger=mock.ANY), |                                                        logger=mock.ANY), | ||||||
|             mock.call.proxy.publish(self.request_inst_mock, |             mock.call.proxy.publish(self.request_inst_mock, | ||||||
| @@ -231,7 +237,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): | |||||||
|  |  | ||||||
|         expected_calls = [ |         expected_calls = [ | ||||||
|             mock.call.Request(self.task, self.task_uuid, 'revert', |             mock.call.Request(self.task, self.task_uuid, 'revert', | ||||||
|                               self.task_args, None, self.timeout, |                               self.task_args, self.timeout, | ||||||
|                               failures=self.task_failures, |                               failures=self.task_failures, | ||||||
|                               result=self.task_result), |                               result=self.task_result), | ||||||
|             mock.call.request.transition_and_log_error(pr.PENDING, |             mock.call.request.transition_and_log_error(pr.PENDING, | ||||||
| @@ -250,7 +256,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): | |||||||
|  |  | ||||||
|         expected_calls = [ |         expected_calls = [ | ||||||
|             mock.call.Request(self.task, self.task_uuid, 'execute', |             mock.call.Request(self.task, self.task_uuid, 'execute', | ||||||
|                               self.task_args, None, self.timeout) |                               self.task_args, self.timeout), | ||||||
|         ] |         ] | ||||||
|         self.assertEqual(self.master_mock.mock_calls, expected_calls) |         self.assertEqual(self.master_mock.mock_calls, expected_calls) | ||||||
|  |  | ||||||
| @@ -264,7 +270,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): | |||||||
|  |  | ||||||
|         expected_calls = [ |         expected_calls = [ | ||||||
|             mock.call.Request(self.task, self.task_uuid, 'execute', |             mock.call.Request(self.task, self.task_uuid, 'execute', | ||||||
|                               self.task_args, None, self.timeout), |                               self.task_args, self.timeout), | ||||||
|             mock.call.request.transition_and_log_error(pr.PENDING, |             mock.call.request.transition_and_log_error(pr.PENDING, | ||||||
|                                                        logger=mock.ANY), |                                                        logger=mock.ANY), | ||||||
|             mock.call.proxy.publish(self.request_inst_mock, |             mock.call.proxy.publish(self.request_inst_mock, | ||||||
|   | |||||||
| @@ -118,7 +118,7 @@ class TestMessagePump(test.TestCase): | |||||||
|             else: |             else: | ||||||
|                 p.publish(pr.Request(test_utils.DummyTask("dummy_%s" % i), |                 p.publish(pr.Request(test_utils.DummyTask("dummy_%s" % i), | ||||||
|                                      uuidutils.generate_uuid(), |                                      uuidutils.generate_uuid(), | ||||||
|                                      pr.EXECUTE, [], None, None), TEST_TOPIC) |                                      pr.EXECUTE, [], None), TEST_TOPIC) | ||||||
|  |  | ||||||
|         self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT)) |         self.assertTrue(barrier.wait(test_utils.WAIT_TIMEOUT)) | ||||||
|         self.assertEqual(0, barrier.needed) |         self.assertEqual(0, barrier.needed) | ||||||
|   | |||||||
| @@ -22,7 +22,6 @@ from taskflow.engines.worker_based import protocol as pr | |||||||
| from taskflow import exceptions as excp | from taskflow import exceptions as excp | ||||||
| from taskflow.openstack.common import uuidutils | from taskflow.openstack.common import uuidutils | ||||||
| from taskflow import test | from taskflow import test | ||||||
| from taskflow.test import mock |  | ||||||
| from taskflow.tests import utils | from taskflow.tests import utils | ||||||
| from taskflow.types import failure | from taskflow.types import failure | ||||||
|  |  | ||||||
| @@ -53,7 +52,7 @@ class TestProtocolValidation(test.TestCase): | |||||||
|  |  | ||||||
|     def test_request(self): |     def test_request(self): | ||||||
|         msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(), |         msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(), | ||||||
|                          pr.EXECUTE, {}, None, 1.0) |                          pr.EXECUTE, {}, 1.0) | ||||||
|         pr.Request.validate(msg.to_dict()) |         pr.Request.validate(msg.to_dict()) | ||||||
|  |  | ||||||
|     def test_request_invalid(self): |     def test_request_invalid(self): | ||||||
| @@ -66,13 +65,14 @@ class TestProtocolValidation(test.TestCase): | |||||||
|  |  | ||||||
|     def test_request_invalid_action(self): |     def test_request_invalid_action(self): | ||||||
|         msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(), |         msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(), | ||||||
|                          pr.EXECUTE, {}, None, 1.0) |                          pr.EXECUTE, {}, 1.0) | ||||||
|         msg = msg.to_dict() |         msg = msg.to_dict() | ||||||
|         msg['action'] = 'NOTHING' |         msg['action'] = 'NOTHING' | ||||||
|         self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg) |         self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg) | ||||||
|  |  | ||||||
|     def test_response_progress(self): |     def test_response_progress(self): | ||||||
|         msg = pr.Response(pr.PROGRESS, progress=0.5, event_data={}) |         msg = pr.Response(pr.EVENT, details={'progress': 0.5}, | ||||||
|  |                           event_type='blah') | ||||||
|         pr.Response.validate(msg.to_dict()) |         pr.Response.validate(msg.to_dict()) | ||||||
|  |  | ||||||
|     def test_response_completion(self): |     def test_response_completion(self): | ||||||
| @@ -80,7 +80,9 @@ class TestProtocolValidation(test.TestCase): | |||||||
|         pr.Response.validate(msg.to_dict()) |         pr.Response.validate(msg.to_dict()) | ||||||
|  |  | ||||||
|     def test_response_mixed_invalid(self): |     def test_response_mixed_invalid(self): | ||||||
|         msg = pr.Response(pr.PROGRESS, progress=0.5, event_data={}, result=1) |         msg = pr.Response(pr.EVENT, | ||||||
|  |                           details={'progress': 0.5}, | ||||||
|  |                           event_type='blah', result=1) | ||||||
|         self.assertRaises(excp.InvalidFormat, pr.Response.validate, msg) |         self.assertRaises(excp.InvalidFormat, pr.Response.validate, msg) | ||||||
|  |  | ||||||
|     def test_response_bad_state(self): |     def test_response_bad_state(self): | ||||||
| @@ -184,16 +186,3 @@ class TestProtocol(test.TestCase): | |||||||
|         request.set_result(111) |         request.set_result(111) | ||||||
|         result = request.result.result() |         result = request.result.result() | ||||||
|         self.assertEqual(result, (executor.EXECUTED, 111)) |         self.assertEqual(result, (executor.EXECUTED, 111)) | ||||||
|  |  | ||||||
|     def test_on_progress(self): |  | ||||||
|         progress_callback = mock.MagicMock(name='progress_callback') |  | ||||||
|         request = self.request(task=self.task, |  | ||||||
|                                progress_callback=progress_callback) |  | ||||||
|         request.on_progress('event_data', 0.0) |  | ||||||
|         request.on_progress('event_data', 1.0) |  | ||||||
|  |  | ||||||
|         expected_calls = [ |  | ||||||
|             mock.call(self.task, 'event_data', 0.0), |  | ||||||
|             mock.call(self.task, 'event_data', 1.0) |  | ||||||
|         ] |  | ||||||
|         self.assertEqual(progress_callback.mock_calls, expected_calls) |  | ||||||
|   | |||||||
| @@ -19,6 +19,7 @@ import six | |||||||
| from taskflow.engines.worker_based import endpoint as ep | from taskflow.engines.worker_based import endpoint as ep | ||||||
| from taskflow.engines.worker_based import protocol as pr | from taskflow.engines.worker_based import protocol as pr | ||||||
| from taskflow.engines.worker_based import server | from taskflow.engines.worker_based import server | ||||||
|  | from taskflow import task as task_atom | ||||||
| from taskflow import test | from taskflow import test | ||||||
| from taskflow.test import mock | from taskflow.test import mock | ||||||
| from taskflow.tests import utils | from taskflow.tests import utils | ||||||
| @@ -103,45 +104,41 @@ class TestServer(test.MockTestCase): | |||||||
|  |  | ||||||
|     def test_parse_request(self): |     def test_parse_request(self): | ||||||
|         request = self.make_request() |         request = self.make_request() | ||||||
|         task_cls, action, task_args = server.Server._parse_request(**request) |         bundle = server.Server._parse_request(**request) | ||||||
|  |         task_cls, task_name, action, task_args = bundle | ||||||
|         self.assertEqual((task_cls, action, task_args), |         self.assertEqual((task_cls, task_name, action, task_args), | ||||||
|                          (self.task.name, self.task_action, |                          (self.task.name, self.task.name, self.task_action, | ||||||
|                           dict(task_name=self.task.name, |                           dict(arguments=self.task_args))) | ||||||
|                                arguments=self.task_args))) |  | ||||||
|  |  | ||||||
|     def test_parse_request_with_success_result(self): |     def test_parse_request_with_success_result(self): | ||||||
|         request = self.make_request(action='revert', result=1) |         request = self.make_request(action='revert', result=1) | ||||||
|         task_cls, action, task_args = server.Server._parse_request(**request) |         bundle = server.Server._parse_request(**request) | ||||||
|  |         task_cls, task_name, action, task_args = bundle | ||||||
|         self.assertEqual((task_cls, action, task_args), |         self.assertEqual((task_cls, task_name, action, task_args), | ||||||
|                          (self.task.name, 'revert', |                          (self.task.name, self.task.name, 'revert', | ||||||
|                           dict(task_name=self.task.name, |                           dict(arguments=self.task_args, | ||||||
|                                arguments=self.task_args, |  | ||||||
|                                result=1))) |                                result=1))) | ||||||
|  |  | ||||||
|     def test_parse_request_with_failure_result(self): |     def test_parse_request_with_failure_result(self): | ||||||
|         a_failure = failure.Failure.from_exception(Exception('test')) |         a_failure = failure.Failure.from_exception(Exception('test')) | ||||||
|         request = self.make_request(action='revert', result=a_failure) |         request = self.make_request(action='revert', result=a_failure) | ||||||
|         task_cls, action, task_args = server.Server._parse_request(**request) |         bundle = server.Server._parse_request(**request) | ||||||
|  |         task_cls, task_name, action, task_args = bundle | ||||||
|         self.assertEqual((task_cls, action, task_args), |         self.assertEqual((task_cls, task_name, action, task_args), | ||||||
|                          (self.task.name, 'revert', |                          (self.task.name, self.task.name, 'revert', | ||||||
|                           dict(task_name=self.task.name, |                           dict(arguments=self.task_args, | ||||||
|                                arguments=self.task_args, |  | ||||||
|                                result=utils.FailureMatcher(a_failure)))) |                                result=utils.FailureMatcher(a_failure)))) | ||||||
|  |  | ||||||
|     def test_parse_request_with_failures(self): |     def test_parse_request_with_failures(self): | ||||||
|         failures = {'0': failure.Failure.from_exception(Exception('test1')), |         failures = {'0': failure.Failure.from_exception(Exception('test1')), | ||||||
|                     '1': failure.Failure.from_exception(Exception('test2'))} |                     '1': failure.Failure.from_exception(Exception('test2'))} | ||||||
|         request = self.make_request(action='revert', failures=failures) |         request = self.make_request(action='revert', failures=failures) | ||||||
|         task_cls, action, task_args = server.Server._parse_request(**request) |         bundle = server.Server._parse_request(**request) | ||||||
|  |         task_cls, task_name, action, task_args = bundle | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             (task_cls, action, task_args), |             (task_cls, task_name, action, task_args), | ||||||
|             (self.task.name, 'revert', |             (self.task.name, self.task.name, 'revert', | ||||||
|              dict(task_name=self.task.name, |              dict(arguments=self.task_args, | ||||||
|                   arguments=self.task_args, |  | ||||||
|                   failures=dict((i, utils.FailureMatcher(f)) |                   failures=dict((i, utils.FailureMatcher(f)) | ||||||
|                                 for i, f in six.iteritems(failures))))) |                                 for i, f in six.iteritems(failures))))) | ||||||
|  |  | ||||||
| @@ -182,17 +179,19 @@ class TestServer(test.MockTestCase): | |||||||
|             mock.call.Response(pr.RUNNING), |             mock.call.Response(pr.RUNNING), | ||||||
|             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, |             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, | ||||||
|                                     correlation_id=self.task_uuid), |                                     correlation_id=self.task_uuid), | ||||||
|             mock.call.Response(pr.PROGRESS, progress=0.0, event_data={}), |             mock.call.Response(pr.EVENT, details={'progress': 0.0}, | ||||||
|  |                                event_type=task_atom.EVENT_UPDATE_PROGRESS), | ||||||
|             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, |             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, | ||||||
|                                     correlation_id=self.task_uuid), |                                     correlation_id=self.task_uuid), | ||||||
|             mock.call.Response(pr.PROGRESS, progress=1.0, event_data={}), |             mock.call.Response(pr.EVENT, details={'progress': 1.0}, | ||||||
|  |                                event_type=task_atom.EVENT_UPDATE_PROGRESS), | ||||||
|             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, |             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, | ||||||
|                                     correlation_id=self.task_uuid), |                                     correlation_id=self.task_uuid), | ||||||
|             mock.call.Response(pr.SUCCESS, result=5), |             mock.call.Response(pr.SUCCESS, result=5), | ||||||
|             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, |             mock.call.proxy.publish(self.response_inst_mock, self.reply_to, | ||||||
|                                     correlation_id=self.task_uuid) |                                     correlation_id=self.task_uuid) | ||||||
|         ] |         ] | ||||||
|         self.assertEqual(self.master_mock.mock_calls, master_mock_calls) |         self.assertEqual(master_mock_calls, self.master_mock.mock_calls) | ||||||
|  |  | ||||||
|     def test_process_request(self): |     def test_process_request(self): | ||||||
|         # create server and process request |         # create server and process request | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ | |||||||
| #    under the License. | #    under the License. | ||||||
|  |  | ||||||
| import collections | import collections | ||||||
|  | import copy | ||||||
| import logging | import logging | ||||||
|  |  | ||||||
| import six | import six | ||||||
| @@ -39,6 +40,14 @@ class _Listener(object): | |||||||
|         else: |         else: | ||||||
|             self._kwargs = kwargs.copy() |             self._kwargs = kwargs.copy() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def kwargs(self): | ||||||
|  |         return self._kwargs | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def args(self): | ||||||
|  |         return self._args | ||||||
|  |  | ||||||
|     def __call__(self, event_type, details): |     def __call__(self, event_type, details): | ||||||
|         if self._details_filter is not None: |         if self._details_filter is not None: | ||||||
|             if not self._details_filter(details): |             if not self._details_filter(details): | ||||||
| @@ -117,17 +126,18 @@ class Notifier(object): | |||||||
|         event type will be called. |         event type will be called. | ||||||
|  |  | ||||||
|         :param event_type: event type that occurred |         :param event_type: event type that occurred | ||||||
|         :param details: addition event details |         :param details: additional event details *dictionary* passed to | ||||||
|  |                         callback keyword argument with the same name. | ||||||
|         """ |         """ | ||||||
|         listeners = list(self._listeners.get(self.ANY, [])) |         listeners = list(self._listeners.get(self.ANY, [])) | ||||||
|         for listener in self._listeners[event_type]: |         listeners.extend(self._listeners.get(event_type, [])) | ||||||
|             if listener not in listeners: |  | ||||||
|                 listeners.append(listener) |  | ||||||
|         if not listeners: |         if not listeners: | ||||||
|             return |             return | ||||||
|  |         if not details: | ||||||
|  |             details = {} | ||||||
|         for listener in listeners: |         for listener in listeners: | ||||||
|             try: |             try: | ||||||
|                 listener(event_type, details) |                 listener(event_type, details.copy()) | ||||||
|             except Exception: |             except Exception: | ||||||
|                 LOG.warn("Failure calling listener %s to notify about event" |                 LOG.warn("Failure calling listener %s to notify about event" | ||||||
|                          " %s, details: %s", listener, event_type, |                          " %s, details: %s", listener, event_type, | ||||||
| @@ -150,6 +160,9 @@ class Notifier(object): | |||||||
|         if details_filter is not None: |         if details_filter is not None: | ||||||
|             if not six.callable(details_filter): |             if not six.callable(details_filter): | ||||||
|                 raise ValueError("Details filter must be callable") |                 raise ValueError("Details filter must be callable") | ||||||
|  |         if not self.can_be_registered(event_type): | ||||||
|  |             raise ValueError("Disallowed event type '%s' can not have a" | ||||||
|  |                              " callback registered" % event_type) | ||||||
|         if self.is_registered(event_type, callback, |         if self.is_registered(event_type, callback, | ||||||
|                               details_filter=details_filter): |                               details_filter=details_filter): | ||||||
|             raise ValueError("Event callback already registered with" |             raise ValueError("Event callback already registered with" | ||||||
| @@ -165,10 +178,61 @@ class Notifier(object): | |||||||
|                       details_filter=details_filter)) |                       details_filter=details_filter)) | ||||||
|  |  | ||||||
|     def deregister(self, event_type, callback, details_filter=None): |     def deregister(self, event_type, callback, details_filter=None): | ||||||
|         """Remove a single callback from listening to event ``event_type``.""" |         """Remove a single listener bound to event ``event_type``.""" | ||||||
|         if event_type not in self._listeners: |         if event_type not in self._listeners: | ||||||
|             return |             return False | ||||||
|         for i, listener in enumerate(self._listeners[event_type]): |         for i, listener in enumerate(self._listeners.get(event_type, [])): | ||||||
|             if listener.is_equivalent(callback, details_filter=details_filter): |             if listener.is_equivalent(callback, details_filter=details_filter): | ||||||
|                 self._listeners[event_type].pop(i) |                 self._listeners[event_type].pop(i) | ||||||
|                 break |                 return True | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     def deregister_event(self, event_type): | ||||||
|  |         """Remove a group of listeners bound to event ``event_type``.""" | ||||||
|  |         return len(self._listeners.pop(event_type, [])) | ||||||
|  |  | ||||||
|  |     def copy(self): | ||||||
|  |         c = copy.copy(self) | ||||||
|  |         c._listeners = collections.defaultdict(list) | ||||||
|  |         for event_type, listeners in six.iteritems(self._listeners): | ||||||
|  |             c._listeners[event_type] = listeners[:] | ||||||
|  |         return c | ||||||
|  |  | ||||||
|  |     def listeners_iter(self): | ||||||
|  |         """Return an iterator over the mapping of event => listeners bound.""" | ||||||
|  |         for event_type, listeners in six.iteritems(self._listeners): | ||||||
|  |             if listeners: | ||||||
|  |                 yield (event_type, listeners) | ||||||
|  |  | ||||||
|  |     def can_be_registered(self, event_type): | ||||||
|  |         """Checks if the event can be registered/subscribed to.""" | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RestrictedNotifier(Notifier): | ||||||
|  |     """A notification class that restricts events registered/triggered. | ||||||
|  |  | ||||||
|  |     NOTE(harlowja): This class unlike :class:`.Notifier` restricts and | ||||||
|  |     disallows registering callbacks for event types that are not declared | ||||||
|  |     when constructing the notifier. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, watchable_events, allow_any=True): | ||||||
|  |         super(RestrictedNotifier, self).__init__() | ||||||
|  |         self._watchable_events = frozenset(watchable_events) | ||||||
|  |         self._allow_any = allow_any | ||||||
|  |  | ||||||
|  |     def events_iter(self): | ||||||
|  |         """Returns iterator of events that can be registered/subscribed to. | ||||||
|  |  | ||||||
|  |         NOTE(harlowja): does not include back the ``ANY`` event type as that | ||||||
|  |         meta-type is not a specific event but is a capture-all that does not | ||||||
|  |         imply the same meaning as specific event types. | ||||||
|  |         """ | ||||||
|  |         for event_type in self._watchable_events: | ||||||
|  |             yield event_type | ||||||
|  |  | ||||||
|  |     def can_be_registered(self, event_type): | ||||||
|  |         """Checks if the event can be registered/subscribed to.""" | ||||||
|  |         return (event_type in self._watchable_events or | ||||||
|  |                 (event_type == self.ANY and self._allow_any)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user