Message-oriented worker-based flow with kombu
* Implemented Worker to be started on remote host for handling tasks request. * Implemented WorkerTaskExecutor that proxies tasks requests to remote workers. * Implemented Proxy that is used for consuming and publishing messages by Worker and Executor. * Added worker-based engine and worker task executor. * Added kombu dependency to requirements. * Added worker-based flow example. * Added unit-tests for worker-based flow components. Implements: blueprint worker-based-engine Change-Id: I8c6859ba4a1a56c2592e3d67cdfb8968b13ee99c
This commit is contained in:
parent
5acb0832df
commit
32e8c3da61
@ -21,5 +21,8 @@ psycopg2
|
||||
# ZooKeeper backends
|
||||
kazoo>=1.3.1
|
||||
|
||||
# Eventlet may be used with parallel engine
|
||||
# Eventlet may be used with parallel engine:
|
||||
eventlet>=0.13.0
|
||||
|
||||
# Needed for the worker-based engine:
|
||||
kombu>=2.4.8
|
||||
|
@ -46,6 +46,7 @@ taskflow.engines =
|
||||
default = taskflow.engines.action_engine.engine:SingleThreadedActionEngine
|
||||
serial = taskflow.engines.action_engine.engine:SingleThreadedActionEngine
|
||||
parallel = taskflow.engines.action_engine.engine:MultiThreadedActionEngine
|
||||
worker-based = taskflow.engines.worker_based.engine:WorkerBasedActionEngine
|
||||
|
||||
[nosetests]
|
||||
cover-erase = true
|
||||
|
@ -65,11 +65,11 @@ class TaskExecutorBase(object):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def execute_task(self, task, arguments, progress_callback=None):
|
||||
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
|
||||
"""Schedules task execution."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def revert_task(self, task, arguments, result, failures,
|
||||
def revert_task(self, task, task_uuid, arguments, result, failures,
|
||||
progress_callback=None):
|
||||
"""Schedules task reversion."""
|
||||
|
||||
@ -89,11 +89,11 @@ class TaskExecutorBase(object):
|
||||
class SerialTaskExecutor(TaskExecutorBase):
|
||||
"""Execute task one after another."""
|
||||
|
||||
def execute_task(self, task, arguments, progress_callback=None):
|
||||
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
|
||||
return async_utils.make_completed_future(
|
||||
_execute_task(task, arguments, progress_callback))
|
||||
|
||||
def revert_task(self, task, arguments, result, failures,
|
||||
def revert_task(self, task, task_uuid, arguments, result, failures,
|
||||
progress_callback=None):
|
||||
return async_utils.make_completed_future(
|
||||
_revert_task(task, arguments, result,
|
||||
@ -115,11 +115,11 @@ class ParallelTaskExecutor(TaskExecutorBase):
|
||||
self._executor = executor
|
||||
self._own_executor = executor is None
|
||||
|
||||
def execute_task(self, task, arguments, progress_callback=None):
|
||||
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
|
||||
return self._executor.submit(
|
||||
_execute_task, task, arguments, progress_callback)
|
||||
|
||||
def revert_task(self, task, arguments, result, failures,
|
||||
def revert_task(self, task, task_uuid, arguments, result, failures,
|
||||
progress_callback=None):
|
||||
return self._executor.submit(
|
||||
_revert_task, task,
|
||||
|
@ -27,6 +27,7 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
|
||||
|
||||
|
||||
class TaskAction(object):
|
||||
|
||||
def __init__(self, storage, task_executor, notifier):
|
||||
self._storage = storage
|
||||
self._task_executor = task_executor
|
||||
@ -66,7 +67,8 @@ class TaskAction(object):
|
||||
if not self._change_state(task, states.RUNNING, progress=0.0):
|
||||
return
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind)
|
||||
return self._task_executor.execute_task(task, kwargs,
|
||||
task_uuid = self._storage.get_task_uuid(task.name)
|
||||
return self._task_executor.execute_task(task, task_uuid, kwargs,
|
||||
self._on_update_progress)
|
||||
|
||||
def complete_execution(self, task, result):
|
||||
@ -80,9 +82,10 @@ class TaskAction(object):
|
||||
if not self._change_state(task, states.REVERTING, progress=0.0):
|
||||
return
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind)
|
||||
task_uuid = self._storage.get_task_uuid(task.name)
|
||||
task_result = self._storage.get(task.name)
|
||||
failures = self._storage.get_failures()
|
||||
future = self._task_executor.revert_task(task, kwargs,
|
||||
future = self._task_executor.revert_task(task, task_uuid, kwargs,
|
||||
task_result, failures,
|
||||
self._on_update_progress)
|
||||
return future
|
||||
|
17
taskflow/engines/worker_based/__init__.py
Normal file
17
taskflow/engines/worker_based/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
52
taskflow/engines/worker_based/endpoint.py
Normal file
52
taskflow/engines/worker_based/endpoint.py
Normal file
@ -0,0 +1,52 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow.utils import reflection
|
||||
|
||||
|
||||
class Endpoint(object):
|
||||
"""Represents a single task with execute/revert methods."""
|
||||
|
||||
def __init__(self, task_cls):
|
||||
self._task_cls = task_cls
|
||||
self._task_cls_name = reflection.get_class_name(task_cls)
|
||||
self._executor = executor.SerialTaskExecutor()
|
||||
|
||||
def __str__(self):
|
||||
return self._task_cls_name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._task_cls_name
|
||||
|
||||
def _get_task(self, name=None):
|
||||
# NOTE(skudriashev): Note that task is created here with the 'name'
|
||||
# argument passed to its constructor. This will be a problem
|
||||
# when task's constructor requires some other arguments.
|
||||
return self._task_cls(name=name)
|
||||
|
||||
def execute(self, task_name, **kwargs):
|
||||
task, event, result = self._executor.execute_task(
|
||||
self._get_task(task_name), **kwargs).result()
|
||||
return result
|
||||
|
||||
def revert(self, task_name, **kwargs):
|
||||
task, event, result = self._executor.revert_task(
|
||||
self._get_task(task_name), **kwargs).result()
|
||||
return result
|
40
taskflow/engines/worker_based/engine.py
Normal file
40
taskflow/engines/worker_based/engine.py
Normal file
@ -0,0 +1,40 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from taskflow.engines.action_engine import engine
|
||||
from taskflow.engines.worker_based import executor
|
||||
from taskflow import storage as t_storage
|
||||
|
||||
|
||||
class WorkerBasedActionEngine(engine.ActionEngine):
|
||||
_storage_cls = t_storage.SingleThreadedStorage
|
||||
|
||||
def _task_executor_cls(self):
|
||||
return executor.WorkerTaskExecutor(**self._executor_config)
|
||||
|
||||
def __init__(self, flow, flow_detail, backend, conf):
|
||||
self._executor_config = {
|
||||
'uuid': flow_detail.uuid,
|
||||
'url': conf.get('url'),
|
||||
'exchange': conf.get('exchange', 'default'),
|
||||
'workers_info': conf.get('workers_info', {}),
|
||||
'transport': conf.get('transport'),
|
||||
'transport_options': conf.get('transport_options')
|
||||
}
|
||||
super(WorkerBasedActionEngine, self).__init__(
|
||||
flow, flow_detail, backend, conf)
|
180
taskflow/engines/worker_based/executor.py
Normal file
180
taskflow/engines/worker_based/executor.py
Normal file
@ -0,0 +1,180 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import six
|
||||
import threading
|
||||
|
||||
from kombu import exceptions as kombu_exc
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow.engines.worker_based import remote_task as rt
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow.utils import async_utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
"""Executes tasks on remote workers."""
|
||||
|
||||
def __init__(self, uuid, exchange, workers_info, **kwargs):
|
||||
self._uuid = uuid
|
||||
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
|
||||
self._on_wait, **kwargs)
|
||||
self._proxy_thread = None
|
||||
self._remote_tasks = {}
|
||||
|
||||
# TODO(skudriashev): This data should be collected from workers
|
||||
# using broadcast messages directly.
|
||||
self._workers_info = {}
|
||||
for topic, tasks in workers_info.items():
|
||||
for task in tasks:
|
||||
self._workers_info[task] = topic
|
||||
|
||||
def _get_proxy_thread(self):
|
||||
proxy_thread = threading.Thread(target=self._proxy.start)
|
||||
# NOTE(skudriashev): When the main thread is terminated unexpectedly
|
||||
# and proxy thread is still alive - it will prevent main thread from
|
||||
# exiting unless the daemon property is set to True.
|
||||
proxy_thread.daemon = True
|
||||
return proxy_thread
|
||||
|
||||
def _on_message(self, response, message):
|
||||
"""This method is called on incoming response."""
|
||||
LOG.debug("Got response: %s" % response)
|
||||
try:
|
||||
# acknowledge message
|
||||
message.ack()
|
||||
except kombu_exc.MessageStateError as e:
|
||||
LOG.warning("Failed to acknowledge AMQP message: %s" % e)
|
||||
else:
|
||||
LOG.debug("AMQP message acknowledged.")
|
||||
# get task uuid from message correlation id parameter
|
||||
try:
|
||||
task_uuid = message.properties['correlation_id']
|
||||
except KeyError:
|
||||
LOG.warning("Got message with no 'correlation_id' property.")
|
||||
else:
|
||||
LOG.debug("Task uuid: '%s'" % task_uuid)
|
||||
self._process_response(task_uuid, response)
|
||||
|
||||
def _process_response(self, task_uuid, response):
|
||||
"""Process response from remote side."""
|
||||
try:
|
||||
task = self._remote_tasks[task_uuid]
|
||||
except KeyError:
|
||||
LOG.debug("Task with id='%s' not found." % task_uuid)
|
||||
else:
|
||||
state = response.pop('state')
|
||||
if state == pr.RUNNING:
|
||||
task.set_running()
|
||||
elif state == pr.PROGRESS:
|
||||
task.on_progress(**response)
|
||||
elif state == pr.FAILURE:
|
||||
response['result'] = pu.failure_from_dict(
|
||||
response['result'])
|
||||
task.set_result(**response)
|
||||
self._remove_remote_task(task)
|
||||
elif state == pr.SUCCESS:
|
||||
task.set_result(**response)
|
||||
self._remove_remote_task(task)
|
||||
else:
|
||||
LOG.warning("Unexpected response status: '%s'" % state)
|
||||
|
||||
def _on_wait(self):
|
||||
"""This function is called cyclically between draining events
|
||||
iterations to clean-up expired task requests.
|
||||
"""
|
||||
expired_tasks = [task for task in six.itervalues(self._remote_tasks)
|
||||
if task.expired]
|
||||
for task in expired_tasks:
|
||||
LOG.debug("Task request '%s' is expired." % task)
|
||||
task.set_result(misc.Failure.from_exception(
|
||||
exc.Timeout("Task request '%s' is expired" % task)))
|
||||
del self._remote_tasks[task.uuid]
|
||||
|
||||
def _store_remote_task(self, task):
|
||||
"""Store task in the remote tasks map."""
|
||||
self._remote_tasks[task.uuid] = task
|
||||
return task
|
||||
|
||||
def _remove_remote_task(self, task):
|
||||
"""Remove remote task from the tasks map."""
|
||||
if task.uuid in self._remote_tasks:
|
||||
del self._remote_tasks[task.uuid]
|
||||
|
||||
def _submit_task(self, task, task_uuid, action, arguments,
|
||||
progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
|
||||
"""Submit task request to workers."""
|
||||
remote_task = self._store_remote_task(
|
||||
rt.RemoteTask(task, task_uuid, action, arguments,
|
||||
progress_callback, timeout, **kwargs)
|
||||
)
|
||||
try:
|
||||
# get task's workers topic to send request to
|
||||
try:
|
||||
topic = self._workers_info[remote_task.name]
|
||||
except KeyError:
|
||||
raise exc.NotFound("Workers topic not found for the '%s'"
|
||||
"task." % remote_task.name)
|
||||
else:
|
||||
# publish request
|
||||
request = remote_task.request
|
||||
LOG.debug("Sending request: %s" % request)
|
||||
self._proxy.publish(request, remote_task.uuid,
|
||||
routing_key=topic, reply_to=self._uuid)
|
||||
except Exception as e:
|
||||
LOG.error("Failed to submit the '%s' task: %s" % (remote_task, e))
|
||||
self._remove_remote_task(remote_task)
|
||||
remote_task.set_result(misc.Failure())
|
||||
return remote_task.result
|
||||
|
||||
def execute_task(self, task, task_uuid, arguments,
|
||||
progress_callback=None):
|
||||
return self._submit_task(task, task_uuid, pr.EXECUTE, arguments,
|
||||
progress_callback)
|
||||
|
||||
def revert_task(self, task, task_uuid, arguments, result, failures,
|
||||
progress_callback=None):
|
||||
return self._submit_task(task, task_uuid, pr.REVERT, arguments,
|
||||
progress_callback, result=result,
|
||||
failures=failures)
|
||||
|
||||
def wait_for_any(self, fs, timeout=None):
|
||||
"""Wait for futures returned by this executor to complete."""
|
||||
return async_utils.wait_for_any(fs, timeout)
|
||||
|
||||
def start(self):
|
||||
"""Start proxy thread."""
|
||||
if self._proxy_thread is None:
|
||||
self._proxy_thread = self._get_proxy_thread()
|
||||
self._proxy_thread.start()
|
||||
self._proxy.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stop proxy, so its thread would be gracefully terminated."""
|
||||
if self._proxy_thread is not None:
|
||||
if self._proxy_thread.is_alive():
|
||||
self._proxy.stop()
|
||||
self._proxy_thread.join()
|
||||
self._proxy_thread = None
|
47
taskflow/engines/worker_based/protocol.py
Normal file
47
taskflow/engines/worker_based/protocol.py
Normal file
@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
|
||||
# NOTE(skudriashev): This is protocol events, not related to the task states.
|
||||
PENDING = 'PENDING'
|
||||
RUNNING = 'RUNNING'
|
||||
SUCCESS = 'SUCCESS'
|
||||
FAILURE = 'FAILURE'
|
||||
PROGRESS = 'PROGRESS'
|
||||
|
||||
# Remote task actions
|
||||
EXECUTE = 'execute'
|
||||
REVERT = 'revert'
|
||||
|
||||
# Remote task action to event map
|
||||
ACTION_TO_EVENT = {
|
||||
EXECUTE: executor.EXECUTED,
|
||||
REVERT: executor.REVERTED
|
||||
}
|
||||
|
||||
# NOTE(skudriashev): A timeout which specifies request expiration period.
|
||||
REQUEST_TIMEOUT = 60
|
||||
|
||||
# NOTE(skudriashev): A timeout which controls for how long a queue can be
|
||||
# unused before it is automatically deleted. Unused means the queue has no
|
||||
# consumers, the queue has not been redeclared, the `queue.get` has not been
|
||||
# invoked for a duration of at least the expiration period. In our case this
|
||||
# period is equal to the request timeout, once request is expired - queue is
|
||||
# no longer needed.
|
||||
QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT
|
126
taskflow/engines/worker_based/proxy.py
Normal file
126
taskflow/engines/worker_based/proxy.py
Normal file
@ -0,0 +1,126 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import kombu
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
|
||||
from amqp import exceptions as amqp_exc
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# NOTE(skudriashev): A timeout of 1 is often used in environments where
|
||||
# the socket can get "stuck", and is a best practice for Kombu consumers.
|
||||
DRAIN_EVENTS_PERIOD = 1
|
||||
|
||||
|
||||
class Proxy(object):
|
||||
"""Proxy picks up messages from the named exchange, calls on_message
|
||||
callback when new message received and is used to publish messages.
|
||||
"""
|
||||
|
||||
def __init__(self, uuid, exchange_name, on_message, on_wait=None,
|
||||
**kwargs):
|
||||
self._uuid = uuid
|
||||
self._exchange_name = exchange_name
|
||||
self._on_message = on_message
|
||||
self._on_wait = on_wait
|
||||
self._running = threading.Event()
|
||||
self._url = kwargs.get('url')
|
||||
self._transport = kwargs.get('transport')
|
||||
self._transport_opts = kwargs.get('transport_options')
|
||||
|
||||
# create connection
|
||||
self._conn = kombu.Connection(self._url, transport=self._transport,
|
||||
transport_options=self._transport_opts)
|
||||
|
||||
# create exchange
|
||||
self._exchange = kombu.Exchange(name=self._exchange_name,
|
||||
channel=self._conn,
|
||||
durable=False,
|
||||
auto_delete=True)
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
"""Return whether proxy is running."""
|
||||
return self._running.is_set()
|
||||
|
||||
def _make_queue(self, name, exchange, **kwargs):
|
||||
"""Make named queue for the given exchange."""
|
||||
queue_arguments = {'x-expires': pr.QUEUE_EXPIRE_TIMEOUT * 1000}
|
||||
return kombu.Queue(name="%s_%s" % (self._exchange_name, name),
|
||||
exchange=exchange,
|
||||
routing_key=name,
|
||||
durable=False,
|
||||
queue_arguments=queue_arguments,
|
||||
**kwargs)
|
||||
|
||||
def publish(self, msg, task_uuid, routing_key, **kwargs):
|
||||
"""Publish message to the named exchange with routing key."""
|
||||
with kombu.producers[self._conn].acquire(block=True) as producer:
|
||||
queue = self._make_queue(routing_key, self._exchange)
|
||||
producer.publish(body=msg,
|
||||
routing_key=routing_key,
|
||||
exchange=self._exchange,
|
||||
correlation_id=task_uuid,
|
||||
declare=[queue],
|
||||
**kwargs)
|
||||
|
||||
def start(self):
|
||||
"""Start proxy."""
|
||||
LOG.info("Starting to consume from the '%s' exchange." %
|
||||
self._exchange_name)
|
||||
with kombu.connections[self._conn].acquire(block=True) as conn:
|
||||
queue = self._make_queue(self._uuid, self._exchange, channel=conn)
|
||||
try:
|
||||
with conn.Consumer(queues=queue,
|
||||
callbacks=[self._on_message]):
|
||||
self._running.set()
|
||||
while self.is_running:
|
||||
try:
|
||||
conn.drain_events(timeout=DRAIN_EVENTS_PERIOD)
|
||||
except socket.timeout:
|
||||
pass
|
||||
if self._on_wait is not None:
|
||||
self._on_wait()
|
||||
finally:
|
||||
try:
|
||||
queue.delete(if_unused=True)
|
||||
except (amqp_exc.PreconditionFailed, amqp_exc.NotFound):
|
||||
pass
|
||||
except Exception as e:
|
||||
LOG.error("Failed to delete the '%s' queue: %s" %
|
||||
(queue.name, e))
|
||||
try:
|
||||
self._exchange.delete(if_unused=True)
|
||||
except (amqp_exc.PreconditionFailed, amqp_exc.NotFound):
|
||||
pass
|
||||
except Exception as e:
|
||||
LOG.error("Failed to delete the '%s' exchange: %s" %
|
||||
(self._exchange.name, e))
|
||||
|
||||
def wait(self):
|
||||
"""Wait until proxy is started."""
|
||||
self._running.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stop proxy."""
|
||||
self._running.clear()
|
105
taskflow/engines/worker_based/remote_task.py
Normal file
105
taskflow/engines/worker_based/remote_task.py
Normal file
@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import time
|
||||
|
||||
from concurrent import futures
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
from taskflow.utils import reflection
|
||||
|
||||
|
||||
class RemoteTask(object):
|
||||
"""Represents remote task with its request data and execution results.
|
||||
Every remote task is created in the PENDING state and will be expired
|
||||
within the given timeout.
|
||||
"""
|
||||
|
||||
def __init__(self, task, uuid, action, arguments, progress_callback,
|
||||
timeout, **kwargs):
|
||||
self._task = task
|
||||
self._name = reflection.get_class_name(task)
|
||||
self._uuid = uuid
|
||||
self._action = action
|
||||
self._event = pr.ACTION_TO_EVENT[action]
|
||||
self._arguments = arguments
|
||||
self._progress_callback = progress_callback
|
||||
self._timeout = timeout
|
||||
self._kwargs = kwargs
|
||||
self._time = time.time()
|
||||
self._state = pr.PENDING
|
||||
self.result = futures.Future()
|
||||
|
||||
def __repr__(self):
|
||||
return "%s:%s" % (self._name, self._action)
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def request(self):
|
||||
"""Return json-serializable task request, converting all `misc.Failure`
|
||||
objects into dictionaries.
|
||||
"""
|
||||
request = dict(task=self._name, task_name=self._task.name,
|
||||
task_version=self._task.version, action=self._action,
|
||||
arguments=self._arguments)
|
||||
if 'result' in self._kwargs:
|
||||
result = self._kwargs['result']
|
||||
if isinstance(result, misc.Failure):
|
||||
request['result'] = ('failure', pu.failure_to_dict(result))
|
||||
else:
|
||||
request['result'] = ('success', result)
|
||||
if 'failures' in self._kwargs:
|
||||
failures = self._kwargs['failures']
|
||||
request['failures'] = {}
|
||||
for task, failure in failures.items():
|
||||
request['failures'][task] = pu.failure_to_dict(failure)
|
||||
return request
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
"""Check if task is expired.
|
||||
|
||||
When new remote task is created its state is set to the PENDING,
|
||||
creation time is stored and timeout is given via constructor arguments.
|
||||
|
||||
Remote task is considered to be expired when it is in the PENDING state
|
||||
for more then the given timeout (task is not considered to be expired
|
||||
in any other state). After remote task is expired - the `Timeout`
|
||||
exception is raised and task is removed from the remote tasks map.
|
||||
"""
|
||||
if self._state == pr.PENDING:
|
||||
return time.time() - self._time > self._timeout
|
||||
return False
|
||||
|
||||
def set_result(self, result):
|
||||
self.result.set_result((self._task, self._event, result))
|
||||
|
||||
def set_running(self):
|
||||
self._state = pr.RUNNING
|
||||
|
||||
def on_progress(self, event_data, progress):
|
||||
self._progress_callback(self._task, event_data, progress)
|
178
taskflow/engines/worker_based/server.py
Normal file
178
taskflow/engines/worker_based/server.py
Normal file
@ -0,0 +1,178 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import functools
|
||||
import logging
|
||||
|
||||
from kombu import exceptions as kombu_exc
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Server(object):
|
||||
"""Server implementation that waits for incoming tasks requests."""
|
||||
|
||||
def __init__(self, uuid, exchange, executor, endpoints, **kwargs):
|
||||
self._proxy = proxy.Proxy(uuid, exchange, self._on_message, **kwargs)
|
||||
self._executor = executor
|
||||
self._endpoints = dict([(endpoint.name, endpoint)
|
||||
for endpoint in endpoints])
|
||||
|
||||
def _on_message(self, request, message):
|
||||
"""This method is called on incoming request."""
|
||||
LOG.debug("Got request: %s" % request)
|
||||
# NOTE(skudriashev): Process all incoming requests only if proxy is
|
||||
# running, otherwise reject and requeue them.
|
||||
if self._proxy.is_running:
|
||||
# NOTE(skudriashev): Process request only if message has been
|
||||
# acknowledged successfully.
|
||||
try:
|
||||
# acknowledge message
|
||||
message.ack()
|
||||
except kombu_exc.MessageStateError as e:
|
||||
LOG.warning("Failed to acknowledge AMQP message: %s" % e)
|
||||
else:
|
||||
LOG.debug("AMQP message acknowledged.")
|
||||
# spawn new thread to process request
|
||||
self._executor.submit(self._process_request, request, message)
|
||||
else:
|
||||
try:
|
||||
# reject and requeue message
|
||||
message.reject(requeue=True)
|
||||
except kombu_exc.MessageStateError as e:
|
||||
LOG.warning("Failed to reject/requeue AMQP message: %s" % e)
|
||||
else:
|
||||
LOG.debug("AMQP message rejected and requeued.")
|
||||
|
||||
@staticmethod
|
||||
def _parse_request(task, task_name, action, arguments, result=None,
|
||||
failures=None, **kwargs):
|
||||
"""Parse request before it can be processed. All `misc.Failure` objects
|
||||
that have been converted to dict on the remote side to be serializable
|
||||
are now converted back to objects.
|
||||
"""
|
||||
action_args = dict(arguments=arguments, task_name=task_name)
|
||||
if result is not None:
|
||||
data_type, data = result
|
||||
if data_type == 'failure':
|
||||
action_args['result'] = pu.failure_from_dict(data)
|
||||
else:
|
||||
action_args['result'] = data
|
||||
if failures is not None:
|
||||
action_args['failures'] = {}
|
||||
for k, v in failures.items():
|
||||
action_args['failures'][k] = pu.failure_from_dict(v)
|
||||
return task, action, action_args
|
||||
|
||||
@staticmethod
|
||||
def _parse_message(message):
|
||||
"""Parse broker message to get the `reply_to` and the `correlation_id`
|
||||
properties. If required properties are missing - the `ValueError` is
|
||||
raised.
|
||||
"""
|
||||
properties = []
|
||||
for prop in ('reply_to', 'correlation_id'):
|
||||
try:
|
||||
properties.append(message.properties[prop])
|
||||
except KeyError:
|
||||
raise ValueError("The '%s' message property is missing." %
|
||||
prop)
|
||||
|
||||
return properties
|
||||
|
||||
def _reply(self, reply_to, task_uuid, state=pr.FAILURE, **kwargs):
|
||||
"""Send reply to the `reply_to` queue."""
|
||||
response = dict(state=state, **kwargs)
|
||||
LOG.debug("Send reply: %s" % response)
|
||||
try:
|
||||
self._proxy.publish(response, task_uuid, reply_to)
|
||||
except Exception as e:
|
||||
LOG.error("Failed to send reply: %s" % e)
|
||||
|
||||
def _on_update_progress(self, reply_to, task_uuid, task, event_data,
|
||||
progress):
|
||||
"""Send task update progress notification."""
|
||||
self._reply(reply_to, task_uuid, pr.PROGRESS, event_data=event_data,
|
||||
progress=progress)
|
||||
|
||||
def _process_request(self, request, message):
|
||||
"""Process request in separate thread and reply back."""
|
||||
# parse broker message first to get the `reply_to` and the `task_uuid`
|
||||
# parameters to have possibility to reply back
|
||||
try:
|
||||
reply_to, task_uuid = self._parse_message(message)
|
||||
except ValueError as e:
|
||||
LOG.error("Failed to parse broker message: %s" % e)
|
||||
return
|
||||
else:
|
||||
# prepare task progress callback
|
||||
progress_callback = functools.partial(
|
||||
self._on_update_progress, reply_to, task_uuid)
|
||||
# prepare reply callback
|
||||
reply_callback = functools.partial(
|
||||
self._reply, reply_to, task_uuid)
|
||||
|
||||
# parse request to get task name, action and action arguments
|
||||
try:
|
||||
task, action, action_args = self._parse_request(**request)
|
||||
action_args.update(task_uuid=task_uuid,
|
||||
progress_callback=progress_callback)
|
||||
except ValueError as e:
|
||||
LOG.error("Failed to parse request: %s" % e)
|
||||
reply_callback(result=pu.failure_to_dict(misc.Failure()))
|
||||
return
|
||||
|
||||
# get task endpoint
|
||||
try:
|
||||
endpoint = self._endpoints[task]
|
||||
except KeyError:
|
||||
LOG.error("The '%s' task endpoint does not exist." % task)
|
||||
reply_callback(result=pu.failure_to_dict(misc.Failure()))
|
||||
return
|
||||
else:
|
||||
reply_callback(state=pr.RUNNING)
|
||||
|
||||
# perform task action
|
||||
try:
|
||||
result = getattr(endpoint, action)(**action_args)
|
||||
except Exception as e:
|
||||
LOG.error("The %s task execution failed: %s" % (endpoint, e))
|
||||
reply_callback(result=pu.failure_to_dict(misc.Failure()))
|
||||
else:
|
||||
if isinstance(result, misc.Failure):
|
||||
reply_callback(result=pu.failure_to_dict(result))
|
||||
else:
|
||||
reply_callback(state=pr.SUCCESS, result=result)
|
||||
|
||||
def start(self):
|
||||
"""Start processing incoming requests."""
|
||||
self._proxy.start()
|
||||
|
||||
def wait(self):
|
||||
"""Wait until server is started."""
|
||||
self._proxy.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stop processing incoming requests."""
|
||||
self._proxy.stop()
|
||||
self._executor.shutdown()
|
132
taskflow/engines/worker_based/worker.py
Normal file
132
taskflow/engines/worker_based/worker.py
Normal file
@ -0,0 +1,132 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
|
||||
from concurrent import futures
|
||||
|
||||
import six
|
||||
|
||||
from taskflow.engines.worker_based import endpoint
|
||||
from taskflow.engines.worker_based import server
|
||||
from taskflow.openstack.common import importutils
|
||||
from taskflow import task as t_task
|
||||
from taskflow.utils import reflection
|
||||
from taskflow.utils import threading_utils as tu
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Worker(object):
|
||||
"""Worker that can be started on a remote host for handling tasks requests.
|
||||
|
||||
:param url: broker url
|
||||
:param exchange: broker exchange name
|
||||
:param topic: topic name under which worker is stated
|
||||
:param tasks: tasks list that worker is capable to perform
|
||||
|
||||
Tasks list item can be one of the following types:
|
||||
1. String:
|
||||
|
||||
1.1 Python module name:
|
||||
|
||||
> tasks=['taskflow.tests.utils']
|
||||
|
||||
1.2. Task class (BaseTask subclass) name:
|
||||
|
||||
> tasks=['taskflow.test.utils.DummyTask']
|
||||
|
||||
3. Python module:
|
||||
|
||||
> from taskflow.tests import utils
|
||||
> tasks=[utils]
|
||||
|
||||
4. Task class (BaseTask subclass):
|
||||
|
||||
> from taskflow.tests import utils
|
||||
> tasks=[utils.DummyTask]
|
||||
|
||||
:param executor: custom executor object that is used for processing
|
||||
requests in separate threads
|
||||
:keyword threads_count: threads count to be passed to the default executor
|
||||
:keyword transport: broker transport to be used (e.g. amqp, memory, etc.)
|
||||
:keyword transport_options: broker transport options
|
||||
"""
|
||||
|
||||
def __init__(self, exchange, topic, tasks, executor=None, **kwargs):
|
||||
self._topic = topic
|
||||
self._executor = executor
|
||||
self._threads_count = kwargs.pop('threads_count',
|
||||
tu.get_optimal_thread_count())
|
||||
if self._executor is None:
|
||||
self._executor = futures.ThreadPoolExecutor(self._threads_count)
|
||||
self._endpoints = self._derive_endpoints(tasks)
|
||||
self._server = server.Server(topic, exchange, self._executor,
|
||||
self._endpoints, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _derive_endpoints(tasks):
|
||||
"""Derive endpoints from list of strings, classes or packages."""
|
||||
derived_tasks = set()
|
||||
for item in tasks:
|
||||
module = None
|
||||
if isinstance(item, six.string_types):
|
||||
try:
|
||||
pkg, cls = item.split(':')
|
||||
except ValueError:
|
||||
module = importutils.import_module(item)
|
||||
else:
|
||||
obj = importutils.import_class('%s.%s' % (pkg, cls))
|
||||
if not reflection.is_subclass(obj, t_task.BaseTask):
|
||||
raise TypeError("Item %s is not a BaseTask subclass" %
|
||||
item)
|
||||
derived_tasks.add(obj)
|
||||
elif isinstance(item, types.ModuleType):
|
||||
module = item
|
||||
elif reflection.is_subclass(item, t_task.BaseTask):
|
||||
derived_tasks.add(item)
|
||||
else:
|
||||
raise TypeError("Item %s unexpected type: %s" %
|
||||
(item, type(item)))
|
||||
|
||||
# derive tasks
|
||||
if module is not None:
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if reflection.is_subclass(obj, t_task.BaseTask):
|
||||
derived_tasks.add(obj)
|
||||
|
||||
return [endpoint.Endpoint(task) for task in derived_tasks]
|
||||
|
||||
def run(self):
|
||||
"""Run worker."""
|
||||
LOG.info("Starting the '%s' topic worker in %s threads." %
|
||||
(self._topic, self._threads_count))
|
||||
LOG.info("Tasks list:")
|
||||
for endpoint in self._endpoints:
|
||||
LOG.info("|-- %s" % endpoint)
|
||||
self._server.start()
|
||||
|
||||
def wait(self):
|
||||
"""Wait until worker is started."""
|
||||
self._server.wait()
|
||||
|
||||
def stop(self):
|
||||
"""Stop worker."""
|
||||
self._server.stop()
|
68
taskflow/examples/worker_based/flow.py
Normal file
68
taskflow/examples/worker_based/flow.py
Normal file
@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import taskflow.engines
|
||||
from taskflow.patterns import linear_flow as lf
|
||||
from taskflow.tests import utils
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
engine_conf = {
|
||||
'engine': 'worker-based',
|
||||
'exchange': 'taskflow',
|
||||
'workers_info': {
|
||||
'topic': [
|
||||
'taskflow.tests.utils.TaskOneArgOneReturn',
|
||||
'taskflow.tests.utils.TaskMultiArgOneReturn'
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# parse command line
|
||||
try:
|
||||
arg = sys.argv[1]
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
cfg = json.loads(arg)
|
||||
except ValueError:
|
||||
engine_conf.update(url=arg)
|
||||
else:
|
||||
engine_conf.update(cfg)
|
||||
finally:
|
||||
LOG.debug("Worker configuration: %s\n" %
|
||||
json.dumps(engine_conf, sort_keys=True, indent=4))
|
||||
|
||||
# create and run flow
|
||||
flow = lf.Flow('simple-linear').add(
|
||||
utils.TaskOneArgOneReturn(provides='result1'),
|
||||
utils.TaskMultiArgOneReturn(provides='result2')
|
||||
)
|
||||
eng = taskflow.engines.load(flow,
|
||||
store=dict(x=111, y=222, z=333),
|
||||
engine_conf=engine_conf)
|
||||
eng.run()
|
||||
print(json.dumps(eng.storage.fetch_all(), sort_keys=True))
|
60
taskflow/examples/worker_based/worker.py
Normal file
60
taskflow/examples/worker_based/worker.py
Normal file
@ -0,0 +1,60 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from taskflow.engines.worker_based import worker as w
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
worker_conf = {
|
||||
'exchange': 'taskflow',
|
||||
'topic': 'topic',
|
||||
'tasks': [
|
||||
'taskflow.tests.utils:TaskOneArgOneReturn',
|
||||
'taskflow.tests.utils:TaskMultiArgOneReturn'
|
||||
]
|
||||
}
|
||||
|
||||
# parse command line
|
||||
try:
|
||||
arg = sys.argv[1]
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
cfg = json.loads(arg)
|
||||
except ValueError:
|
||||
worker_conf.update(url=arg)
|
||||
else:
|
||||
worker_conf.update(cfg)
|
||||
finally:
|
||||
LOG.debug("Worker configuration: %s\n" %
|
||||
json.dumps(worker_conf, sort_keys=True, indent=4))
|
||||
|
||||
# run worker
|
||||
worker = w.Worker(**worker_conf)
|
||||
try:
|
||||
worker.run()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
6
taskflow/examples/worker_based_flow.out.txt
Normal file
6
taskflow/examples/worker_based_flow.out.txt
Normal file
@ -0,0 +1,6 @@
|
||||
Run worker.
|
||||
Run flow.
|
||||
{"result1": 1, "result2": 666, "x": 111, "y": 222, "z": 333}
|
||||
|
||||
Flow finished.
|
||||
Stop worker.
|
75
taskflow/examples/worker_based_flow.py
Normal file
75
taskflow/examples/worker_based_flow.py
Normal file
@ -0,0 +1,75 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
self_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
sys.path.insert(0, self_dir)
|
||||
|
||||
import example_utils # noqa
|
||||
|
||||
|
||||
def _path_to(name):
|
||||
return os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
'worker_based', name))
|
||||
|
||||
|
||||
def run_test(name, config):
|
||||
cmd = [sys.executable, _path_to(name), config]
|
||||
process = subprocess.Popen(cmd, stdin=None, stdout=subprocess.PIPE,
|
||||
stderr=sys.stderr)
|
||||
return process, cmd
|
||||
|
||||
|
||||
def main():
|
||||
tmp_path = None
|
||||
try:
|
||||
tmp_path = tempfile.mkdtemp(prefix='worker-based-example-')
|
||||
config = json.dumps({
|
||||
'transport': 'filesystem',
|
||||
'transport_options': {
|
||||
'data_folder_in': tmp_path,
|
||||
'data_folder_out': tmp_path
|
||||
}
|
||||
})
|
||||
|
||||
print('Run worker.')
|
||||
worker_process, _ = run_test('worker.py', config)
|
||||
|
||||
print('Run flow.')
|
||||
flow_process, flow_cmd = run_test('flow.py', config)
|
||||
stdout, _ = flow_process.communicate()
|
||||
rc = flow_process.returncode
|
||||
if rc != 0:
|
||||
raise RuntimeError("Could not run %s [%s]" % (flow_cmd, rc))
|
||||
print(stdout.decode())
|
||||
print('Flow finished.')
|
||||
|
||||
print('Stop worker.')
|
||||
worker_process.terminate()
|
||||
|
||||
finally:
|
||||
if tmp_path is not None:
|
||||
example_utils.rm_path(tmp_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -151,3 +151,7 @@ def exception_message(exc):
|
||||
return six.text_type(exc)
|
||||
except UnicodeError:
|
||||
return str(exc)
|
||||
|
||||
|
||||
class Timeout(TaskFlowException):
|
||||
"""Raised when something was not finished within the given timeout."""
|
||||
|
@ -16,6 +16,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from testtools import compat
|
||||
from testtools import matchers
|
||||
from testtools import testcase
|
||||
@ -88,3 +90,51 @@ class TestCase(testcase.TestCase):
|
||||
self.fail(msg)
|
||||
else:
|
||||
current_tail = current_tail[super_index + 1:]
|
||||
|
||||
|
||||
class MockTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(MockTestCase, self).setUp()
|
||||
self.master_mock = mock.Mock(name='master_mock')
|
||||
|
||||
def _patch(self, target, autospec=True, **kwargs):
|
||||
"""Patch target and attach it to the master mock."""
|
||||
patcher = mock.patch(target, autospec=autospec, **kwargs)
|
||||
mocked = patcher.start()
|
||||
self.addCleanup(patcher.stop)
|
||||
|
||||
attach_as = kwargs.pop('attach_as', None)
|
||||
if attach_as is not None:
|
||||
self.master_mock.attach_mock(mocked, attach_as)
|
||||
|
||||
return mocked
|
||||
|
||||
def _patch_class(self, module, name, autospec=True, attach_as=None):
|
||||
"""Patch class, create class instance mock and attach them to
|
||||
the master mock.
|
||||
"""
|
||||
if autospec:
|
||||
instance_mock = mock.Mock(spec_set=getattr(module, name))
|
||||
else:
|
||||
instance_mock = mock.Mock()
|
||||
|
||||
patcher = mock.patch.object(module, name, autospec=autospec)
|
||||
class_mock = patcher.start()
|
||||
self.addCleanup(patcher.stop)
|
||||
class_mock.return_value = instance_mock
|
||||
|
||||
if attach_as is None:
|
||||
attach_class_as = name
|
||||
attach_instance_as = name.lower()
|
||||
else:
|
||||
attach_class_as = attach_as + '_class'
|
||||
attach_instance_as = attach_as
|
||||
|
||||
self.master_mock.attach_mock(class_mock, attach_class_as)
|
||||
self.master_mock.attach_mock(instance_mock, attach_instance_as)
|
||||
|
||||
return class_mock, instance_mock
|
||||
|
||||
def _reset_master_mock(self):
|
||||
self.master_mock.reset_mock()
|
||||
|
17
taskflow/tests/unit/worker_based/__init__.py
Normal file
17
taskflow/tests/unit/worker_based/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
82
taskflow/tests/unit/worker_based/test_endpoint.py
Normal file
82
taskflow/tests/unit/worker_based/test_endpoint.py
Normal file
@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from taskflow.engines.worker_based import endpoint as ep
|
||||
from taskflow import task
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import reflection
|
||||
|
||||
|
||||
class Task(task.Task):
|
||||
|
||||
def __init__(self, a, *args, **kwargs):
|
||||
super(Task, self).__init__(*args, **kwargs)
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class TestEndpoint(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestEndpoint, self).setUp()
|
||||
self.task_cls = utils.TaskOneReturn
|
||||
self.task_uuid = 'task-uuid'
|
||||
self.task_args = {'context': 'context'}
|
||||
self.task_cls_name = reflection.get_class_name(self.task_cls)
|
||||
self.task_ep = ep.Endpoint(self.task_cls)
|
||||
self.task_result = 1
|
||||
|
||||
def test_creation(self):
|
||||
task = self.task_ep._get_task()
|
||||
self.assertEqual(self.task_ep.name, self.task_cls_name)
|
||||
self.assertIsInstance(task, self.task_cls)
|
||||
self.assertEqual(task.name, self.task_cls_name)
|
||||
|
||||
def test_creation_with_task_name(self):
|
||||
task_name = 'test'
|
||||
task = self.task_ep._get_task(name=task_name)
|
||||
self.assertEqual(self.task_ep.name, self.task_cls_name)
|
||||
self.assertIsInstance(task, self.task_cls)
|
||||
self.assertEqual(task.name, task_name)
|
||||
|
||||
def test_creation_task_with_constructor_args(self):
|
||||
# NOTE(skudriashev): Exception is expected here since task
|
||||
# is created without any arguments passing to its constructor.
|
||||
endpoint = ep.Endpoint(Task)
|
||||
self.assertRaises(TypeError, endpoint._get_task)
|
||||
|
||||
def test_to_str(self):
|
||||
self.assertEqual(str(self.task_ep), self.task_cls_name)
|
||||
|
||||
def test_execute(self):
|
||||
result = self.task_ep.execute(task_name=self.task_cls_name,
|
||||
task_uuid=self.task_uuid,
|
||||
arguments=self.task_args,
|
||||
progress_callback=None)
|
||||
self.assertEqual(result, self.task_result)
|
||||
|
||||
def test_revert(self):
|
||||
result = self.task_ep.revert(task_name=self.task_cls_name,
|
||||
task_uuid=self.task_uuid,
|
||||
arguments=self.task_args,
|
||||
progress_callback=None,
|
||||
result=self.task_result,
|
||||
failures={})
|
||||
self.assertEqual(result, None)
|
72
taskflow/tests/unit/worker_based/test_engine.py
Normal file
72
taskflow/tests/unit/worker_based/test_engine.py
Normal file
@ -0,0 +1,72 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from taskflow.engines.worker_based import engine
|
||||
from taskflow.patterns import linear_flow as lf
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
|
||||
class TestWorkerBasedActionEngine(test.MockTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestWorkerBasedActionEngine, self).setUp()
|
||||
self.broker_url = 'test-url'
|
||||
self.exchange = 'test-exchange'
|
||||
self.workers_info = {'test-topic': ['task1', 'task2']}
|
||||
|
||||
# patch classes
|
||||
self.executor_mock, self.executor_inst_mock = self._patch_class(
|
||||
engine.executor, 'WorkerTaskExecutor', attach_as='executor')
|
||||
|
||||
def test_creation_default(self):
|
||||
flow = lf.Flow('test-flow').add(utils.DummyTask())
|
||||
_, flow_detail = pu.temporary_flow_detail()
|
||||
engine.WorkerBasedActionEngine(flow, flow_detail, None, {}).compile()
|
||||
|
||||
expected_calls = [
|
||||
mock.call.executor_class(uuid=flow_detail.uuid,
|
||||
url=None,
|
||||
exchange='default',
|
||||
workers_info={},
|
||||
transport=None,
|
||||
transport_options=None)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
|
||||
def test_creation_custom(self):
|
||||
flow = lf.Flow('test-flow').add(utils.DummyTask())
|
||||
_, flow_detail = pu.temporary_flow_detail()
|
||||
config = {'url': self.broker_url, 'exchange': self.exchange,
|
||||
'workers_info': self.workers_info, 'transport': 'memory',
|
||||
'transport_options': {}}
|
||||
engine.WorkerBasedActionEngine(
|
||||
flow, flow_detail, None, config).compile()
|
||||
|
||||
expected_calls = [
|
||||
mock.call.executor_class(uuid=flow_detail.uuid,
|
||||
url=self.broker_url,
|
||||
exchange=self.exchange,
|
||||
workers_info=self.workers_info,
|
||||
transport='memory',
|
||||
transport_options={})
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
386
taskflow/tests/unit/worker_based/test_executor.py
Normal file
386
taskflow/tests/unit/worker_based/test_executor.py
Normal file
@ -0,0 +1,386 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
import threading
|
||||
import time
|
||||
|
||||
from concurrent import futures
|
||||
from kombu import exceptions as kombu_exc
|
||||
|
||||
from taskflow.engines.worker_based import executor
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import remote_task as rt
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
|
||||
class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestWorkerTaskExecutor, self).setUp()
|
||||
self.task = utils.DummyTask()
|
||||
self.task_uuid = 'task-uuid'
|
||||
self.task_args = {'context': 'context'}
|
||||
self.task_result = 'task-result'
|
||||
self.task_failures = {}
|
||||
self.timeout = 60
|
||||
self.broker_url = 'test-url'
|
||||
self.executor_uuid = 'executor-uuid'
|
||||
self.executor_exchange = 'executor-exchange'
|
||||
self.executor_topic = 'executor-topic'
|
||||
self.executor_workers_info = {self.executor_topic: [self.task.name]}
|
||||
self.proxy_started_event = threading.Event()
|
||||
|
||||
# patch classes
|
||||
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
|
||||
executor.proxy, 'Proxy')
|
||||
|
||||
# other mocking
|
||||
self.proxy_inst_mock.start.side_effect = self._fake_proxy_start
|
||||
self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop
|
||||
self.wait_for_any_mock = self._patch(
|
||||
'taskflow.engines.worker_based.executor.async_utils.wait_for_any')
|
||||
self.message_mock = mock.MagicMock(name='message')
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid}
|
||||
self.remote_task_mock = mock.MagicMock(uuid=self.task_uuid)
|
||||
|
||||
def _fake_proxy_start(self):
|
||||
self.proxy_started_event.set()
|
||||
while self.proxy_started_event.is_set():
|
||||
time.sleep(0.01)
|
||||
|
||||
def _fake_proxy_stop(self):
|
||||
self.proxy_started_event.clear()
|
||||
|
||||
def executor(self, reset_master_mock=True, **kwargs):
|
||||
executor_kwargs = dict(uuid=self.executor_uuid,
|
||||
exchange=self.executor_exchange,
|
||||
workers_info=self.executor_workers_info,
|
||||
url=self.broker_url)
|
||||
executor_kwargs.update(kwargs)
|
||||
ex = executor.WorkerTaskExecutor(**executor_kwargs)
|
||||
if reset_master_mock:
|
||||
self._reset_master_mock()
|
||||
return ex
|
||||
|
||||
def request(self, **kwargs):
|
||||
request = dict(task=self.task.name, task_name=self.task.name,
|
||||
task_version=self.task.version,
|
||||
arguments=self.task_args)
|
||||
request.update(kwargs)
|
||||
return request
|
||||
|
||||
def remote_task(self, **kwargs):
|
||||
remote_task_kwargs = dict(task=self.task, uuid=self.task_uuid,
|
||||
action='execute', arguments=self.task_args,
|
||||
progress_callback=None, timeout=self.timeout)
|
||||
remote_task_kwargs.update(kwargs)
|
||||
return rt.RemoteTask(**remote_task_kwargs)
|
||||
|
||||
def test_creation(self):
|
||||
ex = self.executor(reset_master_mock=False)
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.executor_uuid, self.executor_exchange,
|
||||
ex._on_message, ex._on_wait, url=self.broker_url)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_state_running(self):
|
||||
response = dict(state=pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||
[mock.call.set_running()])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_state_progress(self):
|
||||
response = dict(state=pr.PROGRESS, progress=1.0)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||
[mock.call.on_progress(progress=1.0)])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_state_failure(self):
|
||||
failure = misc.Failure.from_exception(Exception('test'))
|
||||
failure_dict = pu.failure_to_dict(failure)
|
||||
response = dict(state=pr.FAILURE, result=failure_dict)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [
|
||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||
])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_state_success(self):
|
||||
response = dict(state=pr.SUCCESS, result=self.task_result,
|
||||
event='executed')
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls,
|
||||
[mock.call.set_result(result=self.task_result,
|
||||
event='executed')])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_unknown_state(self):
|
||||
response = dict(state='unknown')
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_non_existent_task(self):
|
||||
self.message_mock.properties = {'correlation_id': 'non-existent'}
|
||||
response = dict(state=pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_no_correlation_id(self):
|
||||
self.message_mock.properties = {}
|
||||
response = dict(state=pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
|
||||
self.assertEqual(self.remote_task_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
|
||||
def test_on_message_acknowledge_raises(self, mocked_warning):
|
||||
self.message_mock.ack.side_effect = kombu_exc.MessageStateError()
|
||||
self.executor()._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.remote_task.time.time')
|
||||
def test_on_wait_task_not_expired(self, mock_time):
|
||||
mock_time.side_effect = [1, self.timeout]
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task())
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.remote_task.time.time')
|
||||
def test_on_wait_task_expired(self, mock_time):
|
||||
mock_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(self.remote_task())
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
|
||||
def test_remove_task_non_existent(self):
|
||||
task = self.remote_task()
|
||||
ex = self.executor()
|
||||
ex._store_remote_task(task)
|
||||
|
||||
self.assertEqual(len(ex._remote_tasks), 1)
|
||||
ex._remove_remote_task(task)
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
|
||||
# remove non-existent
|
||||
ex._remove_remote_task(task)
|
||||
self.assertEqual(len(ex._remote_tasks), 0)
|
||||
|
||||
def test_execute_task(self):
|
||||
request = self.request(action='execute')
|
||||
ex = self.executor()
|
||||
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
mock.call.proxy.publish(request, self.task_uuid,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertIsInstance(result, futures.Future)
|
||||
|
||||
def test_revert_task(self):
|
||||
request = self.request(action='revert',
|
||||
result=('success', self.task_result),
|
||||
failures=self.task_failures)
|
||||
ex = self.executor()
|
||||
result = ex.revert_task(self.task, self.task_uuid, self.task_args,
|
||||
self.task_result, self.task_failures)
|
||||
|
||||
expected_calls = [
|
||||
mock.call.proxy.publish(request, self.task_uuid,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertIsInstance(result, futures.Future)
|
||||
|
||||
def test_execute_task_topic_not_found(self):
|
||||
workers_info = {self.executor_topic: ['non-existent-task']}
|
||||
ex = self.executor(workers_info=workers_info)
|
||||
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
self.assertFalse(self.proxy_inst_mock.publish.called)
|
||||
|
||||
# check execute result
|
||||
task, event, res = result.result()
|
||||
self.assertEqual(task, self.task)
|
||||
self.assertEqual(event, 'executed')
|
||||
self.assertIsInstance(res, misc.Failure)
|
||||
|
||||
def test_execute_task_publish_error(self):
|
||||
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
|
||||
request = self.request(action='execute')
|
||||
ex = self.executor()
|
||||
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
mock.call.proxy.publish(request, self.task_uuid,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
|
||||
# check execute result
|
||||
task, event, res = result.result()
|
||||
self.assertEqual(task, self.task)
|
||||
self.assertEqual(event, 'executed')
|
||||
self.assertIsInstance(res, misc.Failure)
|
||||
|
||||
def test_wait_for_any(self):
|
||||
fs = [futures.Future(), futures.Future()]
|
||||
ex = self.executor()
|
||||
ex.wait_for_any(fs)
|
||||
|
||||
expected_calls = [
|
||||
mock.call(fs, None)
|
||||
]
|
||||
self.assertEqual(self.wait_for_any_mock.mock_calls, expected_calls)
|
||||
|
||||
def test_wait_for_any_with_timeout(self):
|
||||
timeout = 30
|
||||
fs = [futures.Future(), futures.Future()]
|
||||
ex = self.executor()
|
||||
ex.wait_for_any(fs, timeout)
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call(fs, timeout)
|
||||
]
|
||||
self.assertEqual(self.wait_for_any_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_start_stop(self):
|
||||
ex = self.executor()
|
||||
ex.start()
|
||||
|
||||
# make sure proxy thread started
|
||||
self.proxy_started_event.wait()
|
||||
|
||||
# stop executor
|
||||
ex.stop()
|
||||
|
||||
self.master_mock.assert_has_calls([
|
||||
mock.call.proxy.start(),
|
||||
mock.call.proxy.wait(),
|
||||
mock.call.proxy.stop()
|
||||
], any_order=True)
|
||||
|
||||
def test_start_already_running(self):
|
||||
ex = self.executor()
|
||||
ex.start()
|
||||
|
||||
# make sure proxy thread started
|
||||
self.proxy_started_event.wait()
|
||||
|
||||
# start executor again
|
||||
ex.start()
|
||||
|
||||
# stop executor
|
||||
ex.stop()
|
||||
|
||||
self.master_mock.assert_has_calls([
|
||||
mock.call.proxy.start(),
|
||||
mock.call.proxy.wait(),
|
||||
mock.call.proxy.stop()
|
||||
], any_order=True)
|
||||
|
||||
def test_stop_not_running(self):
|
||||
self.executor().stop()
|
||||
|
||||
self.assertEqual(self.master_mock.mock_calls, [])
|
||||
|
||||
def test_stop_not_alive(self):
|
||||
self.proxy_inst_mock.start.side_effect = None
|
||||
|
||||
# start executor
|
||||
ex = self.executor()
|
||||
ex.start()
|
||||
|
||||
# wait until executor thread is done
|
||||
ex._proxy_thread.join()
|
||||
|
||||
# stop executor
|
||||
ex.stop()
|
||||
|
||||
# since proxy thread is already done - stop is not called
|
||||
self.master_mock.assert_has_calls([
|
||||
mock.call.proxy.start(),
|
||||
mock.call.proxy.wait()
|
||||
], any_order=True)
|
||||
|
||||
def test_restart(self):
|
||||
ex = self.executor()
|
||||
ex.start()
|
||||
|
||||
# make sure thread started
|
||||
self.proxy_started_event.wait()
|
||||
|
||||
# restart executor
|
||||
ex.stop()
|
||||
ex.start()
|
||||
|
||||
# make sure thread started
|
||||
self.proxy_started_event.wait()
|
||||
|
||||
# stop executor
|
||||
ex.stop()
|
||||
|
||||
self.master_mock.assert_has_calls([
|
||||
mock.call.proxy.start(),
|
||||
mock.call.proxy.wait(),
|
||||
mock.call.proxy.stop(),
|
||||
mock.call.proxy.start(),
|
||||
mock.call.proxy.wait(),
|
||||
mock.call.proxy.stop()
|
||||
], any_order=True)
|
301
taskflow/tests/unit/worker_based/test_proxy.py
Normal file
301
taskflow/tests/unit/worker_based/test_proxy.py
Normal file
@ -0,0 +1,301 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
import socket
|
||||
import threading
|
||||
|
||||
from amqp import exceptions as amqp_exc
|
||||
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow import test
|
||||
|
||||
|
||||
class TestProxy(test.MockTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestProxy, self).setUp()
|
||||
self.uuid = 'test-uuid'
|
||||
self.broker_url = 'test-url'
|
||||
self.exchange_name = 'test-exchange'
|
||||
self.timeout = 5
|
||||
self.queue_arguments = {
|
||||
'x-expires': proxy.pr.QUEUE_EXPIRE_TIMEOUT * 1000
|
||||
}
|
||||
self.de_period = proxy.DRAIN_EVENTS_PERIOD
|
||||
|
||||
# patch classes
|
||||
self.conn_mock, self.conn_inst_mock = self._patch_class(
|
||||
proxy.kombu, 'Connection')
|
||||
self.exchange_mock, self.exchange_inst_mock = self._patch_class(
|
||||
proxy.kombu, 'Exchange')
|
||||
self.queue_mock, self.queue_inst_mock = self._patch_class(
|
||||
proxy.kombu, 'Queue')
|
||||
self.producer_mock, self.producer_inst_mock = self._patch_class(
|
||||
proxy.kombu, 'Producer')
|
||||
|
||||
# connection mocking
|
||||
self.conn_inst_mock.drain_events.side_effect = [
|
||||
socket.timeout, socket.timeout, KeyboardInterrupt]
|
||||
|
||||
# connections mocking
|
||||
self.connections_mock = self._patch(
|
||||
"taskflow.engines.worker_based.proxy.kombu.connections",
|
||||
attach_as='connections')
|
||||
self.connections_mock.__getitem__().acquire().__enter__.return_value =\
|
||||
self.conn_inst_mock
|
||||
|
||||
# producers mocking
|
||||
self.producers_mock = self._patch(
|
||||
"taskflow.engines.worker_based.proxy.kombu.producers",
|
||||
attach_as='producers')
|
||||
self.producers_mock.__getitem__().acquire().__enter__.return_value =\
|
||||
self.producer_inst_mock
|
||||
|
||||
# consumer mocking
|
||||
self.conn_inst_mock.Consumer.return_value.__enter__ = mock.MagicMock()
|
||||
self.conn_inst_mock.Consumer.return_value.__exit__ = mock.MagicMock()
|
||||
|
||||
# other mocking
|
||||
self.on_message_mock = mock.MagicMock(name='on_message')
|
||||
self.on_wait_mock = mock.MagicMock(name='on_wait')
|
||||
self.master_mock.attach_mock(self.on_wait_mock, 'on_wait')
|
||||
|
||||
# reset master mock
|
||||
self._reset_master_mock()
|
||||
|
||||
def _queue_name(self, uuid):
|
||||
return "%s_%s" % (self.exchange_name, uuid)
|
||||
|
||||
def proxy_start_calls(self, calls, exc_type=mock.ANY):
|
||||
return [
|
||||
mock.call.Queue(name=self._queue_name(self.uuid),
|
||||
exchange=self.exchange_inst_mock,
|
||||
routing_key=self.uuid,
|
||||
durable=False,
|
||||
queue_arguments=self.queue_arguments,
|
||||
channel=self.conn_inst_mock),
|
||||
mock.call.connection.Consumer(queues=self.queue_inst_mock,
|
||||
callbacks=[self.on_message_mock]),
|
||||
mock.call.connection.Consumer().__enter__(),
|
||||
] + calls + [
|
||||
mock.call.connection.Consumer().__exit__(exc_type, mock.ANY,
|
||||
mock.ANY),
|
||||
mock.ANY,
|
||||
mock.call.queue.delete(if_unused=True)
|
||||
]
|
||||
|
||||
def proxy(self, reset_master_mock=False, **kwargs):
|
||||
proxy_kwargs = dict(uuid=self.uuid,
|
||||
exchange_name=self.exchange_name,
|
||||
on_message=self.on_message_mock,
|
||||
url=self.broker_url)
|
||||
proxy_kwargs.update(kwargs)
|
||||
p = proxy.Proxy(**proxy_kwargs)
|
||||
if reset_master_mock:
|
||||
self._reset_master_mock()
|
||||
return p
|
||||
|
||||
def test_creation(self):
|
||||
self.proxy()
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.Connection(self.broker_url, transport=None,
|
||||
transport_options=None),
|
||||
mock.call.Exchange(name=self.exchange_name,
|
||||
channel=self.conn_inst_mock,
|
||||
durable=False,
|
||||
auto_delete=True)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_creation_custom(self):
|
||||
transport_opts = {'context': 'context'}
|
||||
self.proxy(transport='memory', transport_options=transport_opts)
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.Connection(self.broker_url, transport='memory',
|
||||
transport_options=transport_opts),
|
||||
mock.call.Exchange(name=self.exchange_name,
|
||||
channel=self.conn_inst_mock,
|
||||
durable=False,
|
||||
auto_delete=True)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_publish(self):
|
||||
task_data = 'task-data'
|
||||
task_uuid = 'task-uuid'
|
||||
routing_key = 'routing-key'
|
||||
kwargs = dict(a='a', b='b')
|
||||
|
||||
self.proxy(reset_master_mock=True).publish(
|
||||
task_data, task_uuid, routing_key, **kwargs)
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.Queue(name=self._queue_name(routing_key),
|
||||
exchange=self.exchange_inst_mock,
|
||||
routing_key=routing_key,
|
||||
durable=False,
|
||||
queue_arguments=self.queue_arguments),
|
||||
mock.call.producer.publish(body=task_data,
|
||||
routing_key=routing_key,
|
||||
exchange=self.exchange_inst_mock,
|
||||
correlation_id=task_uuid,
|
||||
declare=[self.queue_inst_mock],
|
||||
**kwargs)
|
||||
]
|
||||
self.master_mock.assert_has_calls(master_mock_calls)
|
||||
|
||||
def test_start(self):
|
||||
try:
|
||||
# KeyboardInterrupt will be raised after two iterations
|
||||
self.proxy(reset_master_mock=True).start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
master_calls = self.proxy_start_calls([
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
], exc_type=KeyboardInterrupt)
|
||||
self.master_mock.assert_has_calls(master_calls)
|
||||
|
||||
def test_start_with_on_wait(self):
|
||||
try:
|
||||
# KeyboardInterrupt will be raised after two iterations
|
||||
self.proxy(reset_master_mock=True,
|
||||
on_wait=self.on_wait_mock).start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
master_calls = self.proxy_start_calls([
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.on_wait(),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.on_wait(),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
], exc_type=KeyboardInterrupt)
|
||||
self.master_mock.assert_has_calls(master_calls)
|
||||
|
||||
def test_start_with_on_wait_raises(self):
|
||||
self.on_wait_mock.side_effect = RuntimeError('Woot!')
|
||||
try:
|
||||
# KeyboardInterrupt will be raised after two iterations
|
||||
self.proxy(reset_master_mock=True,
|
||||
on_wait=self.on_wait_mock).start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
master_calls = self.proxy_start_calls([
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.on_wait(),
|
||||
], exc_type=RuntimeError)
|
||||
self.master_mock.assert_has_calls(master_calls)
|
||||
|
||||
def test_start_queue_delete_not_found(self):
|
||||
self.queue_inst_mock.delete.side_effect = amqp_exc.NotFound('Woot!')
|
||||
try:
|
||||
# KeyboardInterrupt will be raised after two iterations
|
||||
self.proxy(reset_master_mock=True).start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
master_calls = self.proxy_start_calls([
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
], exc_type=KeyboardInterrupt)
|
||||
self.master_mock.assert_has_calls(master_calls)
|
||||
|
||||
@mock.patch("taskflow.engines.worker_based.proxy.LOG.error")
|
||||
def test_start_queue_delete_raises(self, mocked_error):
|
||||
self.queue_inst_mock.delete.side_effect = RuntimeError('Woot!')
|
||||
try:
|
||||
# KeyboardInterrupt will be raised after two iterations
|
||||
self.proxy(reset_master_mock=True).start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
master_calls = self.proxy_start_calls([
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
], exc_type=KeyboardInterrupt)
|
||||
self.master_mock.assert_has_calls(master_calls)
|
||||
self.assertTrue(mocked_error.called)
|
||||
|
||||
def test_start_exchange_delete_not_found(self):
|
||||
self.exchange_inst_mock.delete.side_effect = amqp_exc.NotFound('Woot!')
|
||||
try:
|
||||
# KeyboardInterrupt will be raised after two iterations
|
||||
self.proxy(reset_master_mock=True).start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
master_calls = self.proxy_start_calls([
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
], exc_type=KeyboardInterrupt)
|
||||
self.master_mock.assert_has_calls(master_calls)
|
||||
|
||||
@mock.patch("taskflow.engines.worker_based.proxy.LOG.error")
|
||||
def test_start_exchange_delete_raises(self, mocked_error):
|
||||
self.exchange_inst_mock.delete.side_effect = RuntimeError('Woot!')
|
||||
try:
|
||||
# KeyboardInterrupt will be raised after two iterations
|
||||
self.proxy(reset_master_mock=True).start()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
master_calls = self.proxy_start_calls([
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
mock.call.connection.drain_events(timeout=self.de_period),
|
||||
], exc_type=KeyboardInterrupt)
|
||||
self.master_mock.assert_has_calls(master_calls)
|
||||
self.assertTrue(mocked_error.called)
|
||||
|
||||
def test_stop(self):
|
||||
self.conn_inst_mock.drain_events.side_effect = socket.timeout
|
||||
|
||||
# create proxy
|
||||
pr = self.proxy(reset_master_mock=True)
|
||||
|
||||
# check that proxy is not running yes
|
||||
self.assertFalse(pr.is_running)
|
||||
|
||||
# start proxy in separate thread
|
||||
t = threading.Thread(target=pr.start)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
# make sure proxy is started
|
||||
pr.wait()
|
||||
|
||||
# check that proxy is running now
|
||||
self.assertTrue(pr.is_running)
|
||||
|
||||
# stop proxy and wait for thread to finish
|
||||
pr.stop()
|
||||
|
||||
# wait for thread to finish
|
||||
t.join()
|
||||
|
||||
self.assertFalse(pr.is_running)
|
135
taskflow/tests/unit/worker_based/test_remote_task.py
Normal file
135
taskflow/tests/unit/worker_based/test_remote_task.py
Normal file
@ -0,0 +1,135 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from concurrent import futures
|
||||
|
||||
from taskflow.engines.worker_based import remote_task as rt
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
|
||||
class TestRemoteTask(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestRemoteTask, self).setUp()
|
||||
self.task = utils.DummyTask()
|
||||
self.task_uuid = 'task-uuid'
|
||||
self.task_action = 'execute'
|
||||
self.task_args = {'context': 'context'}
|
||||
self.timeout = 60
|
||||
|
||||
def remote_task(self, **kwargs):
|
||||
task_kwargs = dict(task=self.task,
|
||||
uuid=self.task_uuid,
|
||||
action=self.task_action,
|
||||
arguments=self.task_args,
|
||||
progress_callback=None,
|
||||
timeout=self.timeout)
|
||||
task_kwargs.update(kwargs)
|
||||
return rt.RemoteTask(**task_kwargs)
|
||||
|
||||
def remote_task_request(self, **kwargs):
|
||||
request = dict(task=self.task.name,
|
||||
task_name=self.task.name,
|
||||
task_version=self.task.version,
|
||||
action=self.task_action,
|
||||
arguments=self.task_args)
|
||||
request.update(kwargs)
|
||||
return request
|
||||
|
||||
def test_creation(self):
|
||||
remote_task = self.remote_task()
|
||||
self.assertEqual(remote_task.uuid, self.task_uuid)
|
||||
self.assertEqual(remote_task.name, self.task.name)
|
||||
self.assertIsInstance(remote_task.result, futures.Future)
|
||||
self.assertFalse(remote_task.result.done())
|
||||
|
||||
def test_repr(self):
|
||||
expected_name = '%s:%s' % (self.task.name, self.task_action)
|
||||
self.assertEqual(repr(self.remote_task()), expected_name)
|
||||
|
||||
def test_request(self):
|
||||
remote_task = self.remote_task()
|
||||
request = self.remote_task_request()
|
||||
self.assertEqual(remote_task.request, request)
|
||||
|
||||
def test_request_with_result(self):
|
||||
remote_task = self.remote_task(result=333)
|
||||
request = self.remote_task_request(result=('success', 333))
|
||||
self.assertEqual(remote_task.request, request)
|
||||
|
||||
def test_request_with_result_none(self):
|
||||
remote_task = self.remote_task(result=None)
|
||||
request = self.remote_task_request(result=('success', None))
|
||||
self.assertEqual(remote_task.request, request)
|
||||
|
||||
def test_request_with_result_failure(self):
|
||||
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
|
||||
remote_task = self.remote_task(result=failure)
|
||||
request = self.remote_task_request(
|
||||
result=('failure', pu.failure_to_dict(failure)))
|
||||
self.assertEqual(remote_task.request, request)
|
||||
|
||||
def test_request_with_failures(self):
|
||||
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
|
||||
remote_task = self.remote_task(failures={self.task.name: failure})
|
||||
request = self.remote_task_request(
|
||||
failures={self.task.name: pu.failure_to_dict(failure)})
|
||||
self.assertEqual(remote_task.request, request)
|
||||
|
||||
@mock.patch('time.time')
|
||||
def test_pending_not_expired(self, mock_time):
|
||||
mock_time.side_effect = [1, self.timeout]
|
||||
remote_task = self.remote_task()
|
||||
self.assertFalse(remote_task.expired)
|
||||
|
||||
@mock.patch('time.time')
|
||||
def test_pending_expired(self, mock_time):
|
||||
mock_time.side_effect = [1, self.timeout + 2]
|
||||
remote_task = self.remote_task()
|
||||
self.assertTrue(remote_task.expired)
|
||||
|
||||
@mock.patch('time.time')
|
||||
def test_running_not_expired(self, mock_time):
|
||||
mock_time.side_effect = [1, self.timeout]
|
||||
remote_task = self.remote_task()
|
||||
remote_task.set_running()
|
||||
self.assertFalse(remote_task.expired)
|
||||
|
||||
def test_set_result(self):
|
||||
remote_task = self.remote_task()
|
||||
remote_task.set_result(111)
|
||||
result = remote_task.result.result()
|
||||
self.assertEqual(result, (self.task, 'executed', 111))
|
||||
|
||||
def test_on_progress(self):
|
||||
progress_callback = mock.MagicMock(name='progress_callback')
|
||||
remote_task = self.remote_task(task=self.task,
|
||||
progress_callback=progress_callback)
|
||||
remote_task.on_progress('event_data', 0.0)
|
||||
remote_task.on_progress('event_data', 1.0)
|
||||
|
||||
expected_calls = [
|
||||
mock.call(self.task, 'event_data', 0.0),
|
||||
mock.call(self.task, 'event_data', 1.0)
|
||||
]
|
||||
self.assertEqual(progress_callback.mock_calls, expected_calls)
|
379
taskflow/tests/unit/worker_based/test_server.py
Normal file
379
taskflow/tests/unit/worker_based/test_server.py
Normal file
@ -0,0 +1,379 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from kombu import exceptions as exc
|
||||
|
||||
from taskflow.engines.worker_based import endpoint as ep
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import server
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
|
||||
class TestServer(test.MockTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestServer, self).setUp()
|
||||
self.server_uuid = 'server-uuid'
|
||||
self.server_exchange = 'server-exchange'
|
||||
self.broker_url = 'test-url'
|
||||
self.task_uuid = 'task-uuid'
|
||||
self.task_args = {'x': 1}
|
||||
self.task_action = 'execute'
|
||||
self.task_name = 'taskflow.tests.utils.TaskOneArgOneReturn'
|
||||
self.task_version = (1, 0)
|
||||
self.reply_to = 'reply-to'
|
||||
self.endpoints = [ep.Endpoint(task_cls=utils.TaskOneArgOneReturn),
|
||||
ep.Endpoint(task_cls=utils.TaskWithFailure),
|
||||
ep.Endpoint(task_cls=utils.ProgressingTask)]
|
||||
self.resp_running = dict(state=pr.RUNNING)
|
||||
|
||||
# patch classes
|
||||
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
|
||||
server.proxy, 'Proxy')
|
||||
|
||||
# other mocking
|
||||
self.proxy_inst_mock.is_running = True
|
||||
self.executor_mock = mock.MagicMock(name='executor')
|
||||
self.message_mock = mock.MagicMock(name='message')
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid,
|
||||
'reply_to': self.reply_to}
|
||||
self.master_mock.attach_mock(self.executor_mock, 'executor')
|
||||
self.master_mock.attach_mock(self.message_mock, 'message')
|
||||
|
||||
def server(self, reset_master_mock=False, **kwargs):
|
||||
server_kwargs = dict(uuid=self.server_uuid,
|
||||
exchange=self.server_exchange,
|
||||
executor=self.executor_mock,
|
||||
endpoints=self.endpoints,
|
||||
url=self.broker_url)
|
||||
server_kwargs.update(kwargs)
|
||||
s = server.Server(**server_kwargs)
|
||||
if reset_master_mock:
|
||||
self._reset_master_mock()
|
||||
return s
|
||||
|
||||
def request(self, **kwargs):
|
||||
request = dict(task=self.task_name,
|
||||
task_name=self.task_name,
|
||||
action=self.task_action,
|
||||
task_version=self.task_version,
|
||||
arguments=self.task_args)
|
||||
request.update(kwargs)
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def resp_progress(progress):
|
||||
return dict(state=pr.PROGRESS, progress=progress, event_data={})
|
||||
|
||||
@staticmethod
|
||||
def resp_success(result):
|
||||
return dict(state=pr.SUCCESS, result=result)
|
||||
|
||||
@staticmethod
|
||||
def resp_failure(result, **kwargs):
|
||||
response = dict(state=pr.FAILURE, result=result)
|
||||
response.update(kwargs)
|
||||
return response
|
||||
|
||||
def test_creation(self):
|
||||
s = self.server()
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.server_uuid, self.server_exchange,
|
||||
s._on_message, url=self.broker_url)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
self.assertEqual(len(s._endpoints), 3)
|
||||
|
||||
def test_creation_with_endpoints(self):
|
||||
s = self.server(endpoints=self.endpoints)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.server_uuid, self.server_exchange,
|
||||
s._on_message, url=self.broker_url)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
self.assertEqual(len(s._endpoints), len(self.endpoints))
|
||||
|
||||
def test_on_message_proxy_running_ack_success(self):
|
||||
request = self.request()
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.ack(),
|
||||
mock.call.executor.submit(s._process_request, request,
|
||||
self.message_mock)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_proxy_running_ack_failure(self):
|
||||
self.message_mock.ack.side_effect = exc.MessageStateError('Woot!')
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message({}, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.ack()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_proxy_not_running_reject_success(self):
|
||||
self.proxy_inst_mock.is_running = False
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message({}, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.reject(requeue=True)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_proxy_not_running_reject_failure(self):
|
||||
self.message_mock.reject.side_effect = exc.MessageStateError('Woot!')
|
||||
self.proxy_inst_mock.is_running = False
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message({}, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.reject(requeue=True)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_parse_request(self):
|
||||
request = self.request()
|
||||
task, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task, action, task_args),
|
||||
(self.task_name, self.task_action,
|
||||
dict(task_name=self.task_name,
|
||||
arguments=self.task_args)))
|
||||
|
||||
def test_parse_request_with_success_result(self):
|
||||
request = self.request(action='revert', result=('success', 1))
|
||||
task, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task, action, task_args),
|
||||
(self.task_name, 'revert',
|
||||
dict(task_name=self.task_name,
|
||||
arguments=self.task_args,
|
||||
result=1)))
|
||||
|
||||
def test_parse_request_with_failure_result(self):
|
||||
failure = misc.Failure.from_exception(Exception('test'))
|
||||
failure_dict = pu.failure_to_dict(failure)
|
||||
request = self.request(action='revert',
|
||||
result=('failure', failure_dict))
|
||||
task, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task, action, task_args),
|
||||
(self.task_name, 'revert',
|
||||
dict(task_name=self.task_name,
|
||||
arguments=self.task_args,
|
||||
result=utils.FailureMatcher(failure))))
|
||||
|
||||
def test_parse_request_with_failures(self):
|
||||
failures = [misc.Failure.from_exception(Exception('test1')),
|
||||
misc.Failure.from_exception(Exception('test2'))]
|
||||
failures_dict = dict((str(i), pu.failure_to_dict(f))
|
||||
for i, f in enumerate(failures))
|
||||
request = self.request(action='revert', failures=failures_dict)
|
||||
task, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual(
|
||||
(task, action, task_args),
|
||||
(self.task_name, 'revert',
|
||||
dict(task_name=self.task_name,
|
||||
arguments=self.task_args,
|
||||
failures=dict((str(i), utils.FailureMatcher(f))
|
||||
for i, f in enumerate(failures)))))
|
||||
|
||||
@mock.patch("taskflow.engines.worker_based.server.LOG.error")
|
||||
def test_reply_publish_failure(self, mocked_error):
|
||||
self.proxy_inst_mock.publish.side_effect = RuntimeError('Woot!')
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._reply(self.reply_to, self.task_uuid)
|
||||
|
||||
self.assertEqual(self.master_mock.mock_calls, [
|
||||
mock.call.proxy.publish({'state': 'FAILURE'}, self.task_uuid,
|
||||
self.reply_to)
|
||||
])
|
||||
self.assertEqual(mocked_error.mock_calls, [
|
||||
mock.call("Failed to send reply: Woot!")
|
||||
])
|
||||
|
||||
def test_on_update_progress(self):
|
||||
request = self.request(task='taskflow.tests.utils.ProgressingTask',
|
||||
arguments={})
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.task_uuid,
|
||||
self.reply_to),
|
||||
mock.call.proxy.publish(self.resp_progress(0.0), self.task_uuid,
|
||||
self.reply_to),
|
||||
mock.call.proxy.publish(self.resp_progress(1.0), self.task_uuid,
|
||||
self.reply_to),
|
||||
mock.call.proxy.publish(self.resp_success(5), self.task_uuid,
|
||||
self.reply_to)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_process_request(self):
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._process_request(self.request(), self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.task_uuid,
|
||||
self.reply_to),
|
||||
mock.call.proxy.publish(self.resp_success(1), self.task_uuid,
|
||||
self.reply_to)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@mock.patch("taskflow.engines.worker_based.server.LOG.error")
|
||||
def test_process_request_parse_message_failure(self, mocked_error):
|
||||
self.message_mock.properties = {}
|
||||
request = self.request()
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
self.assertEqual(self.master_mock.mock_calls, [])
|
||||
self.assertTrue(mocked_error.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.pu')
|
||||
def test_process_request_parse_failure(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
pu_mock.failure_from_dict.side_effect = ValueError('Woot!')
|
||||
request = self.request(result=('failure', 1))
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
self.task_uuid, self.reply_to)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.pu')
|
||||
def test_process_request_endpoint_not_found(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
request = self.request(task='<unknown>')
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
self.task_uuid, self.reply_to)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.pu')
|
||||
def test_process_request_execution_failure(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
request = self.request(action='<unknown>')
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.task_uuid,
|
||||
self.reply_to),
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
self.task_uuid, self.reply_to)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.pu')
|
||||
def test_process_request_task_failure(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
request = self.request(task='taskflow.tests.utils.TaskWithFailure',
|
||||
arguments={})
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.task_uuid,
|
||||
self.reply_to),
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
self.task_uuid, self.reply_to)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_start(self):
|
||||
self.server(reset_master_mock=True).start()
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.start()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_wait(self):
|
||||
server = self.server(reset_master_mock=True)
|
||||
server.start()
|
||||
server.wait()
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.start(),
|
||||
mock.call.proxy.wait()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_stop(self):
|
||||
self.server(reset_master_mock=True).stop()
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.stop(),
|
||||
mock.call.executor.shutdown()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
173
taskflow/tests/unit/worker_based/test_worker.py
Normal file
173
taskflow/tests/unit/worker_based/test_worker.py
Normal file
@ -0,0 +1,173 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
|
||||
from taskflow.engines.worker_based import endpoint
|
||||
from taskflow.engines.worker_based import worker
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import reflection
|
||||
|
||||
|
||||
class TestWorker(test.MockTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestWorker, self).setUp()
|
||||
self.task_cls = utils.DummyTask
|
||||
self.task_name = reflection.get_class_name(self.task_cls)
|
||||
self.broker_url = 'test-url'
|
||||
self.exchange = 'test-exchange'
|
||||
self.topic = 'test-topic'
|
||||
self.threads_count = 5
|
||||
self.endpoint_count = 18
|
||||
|
||||
# patch classes
|
||||
self.executor_mock, self.executor_inst_mock = self._patch_class(
|
||||
worker.futures, 'ThreadPoolExecutor', attach_as='executor')
|
||||
self.server_mock, self.server_inst_mock = self._patch_class(
|
||||
worker.server, 'Server')
|
||||
|
||||
# other mocking
|
||||
self.threads_count_mock = self._patch(
|
||||
'taskflow.engines.worker_based.worker.tu.get_optimal_thread_count')
|
||||
self.threads_count_mock.return_value = self.threads_count
|
||||
|
||||
def worker(self, reset_master_mock=False, **kwargs):
|
||||
worker_kwargs = dict(exchange=self.exchange,
|
||||
topic=self.topic,
|
||||
tasks=[],
|
||||
url=self.broker_url)
|
||||
worker_kwargs.update(kwargs)
|
||||
w = worker.Worker(**worker_kwargs)
|
||||
if reset_master_mock:
|
||||
self._reset_master_mock()
|
||||
return w
|
||||
|
||||
def test_creation(self):
|
||||
self.worker()
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.executor_class(self.threads_count),
|
||||
mock.call.Server(self.topic, self.exchange,
|
||||
self.executor_inst_mock, [], url=self.broker_url)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_creation_with_custom_threads_count(self):
|
||||
self.worker(threads_count=10)
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.executor_class(10),
|
||||
mock.call.Server(self.topic, self.exchange,
|
||||
self.executor_inst_mock, [], url=self.broker_url)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_creation_with_custom_executor(self):
|
||||
executor_mock = mock.MagicMock(name='executor')
|
||||
self.worker(executor=executor_mock)
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.Server(self.topic, self.exchange, executor_mock, [],
|
||||
url=self.broker_url)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_run_with_no_tasks(self):
|
||||
self.worker(reset_master_mock=True).run()
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.server.start()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_run_with_tasks(self):
|
||||
self.worker(reset_master_mock=True,
|
||||
tasks=['taskflow.tests.utils:DummyTask']).run()
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.server.start()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_wait(self):
|
||||
w = self.worker(reset_master_mock=True)
|
||||
w.run()
|
||||
w.wait()
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.server.start(),
|
||||
mock.call.server.wait()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_stop(self):
|
||||
self.worker(reset_master_mock=True).stop()
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.server.stop()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_derive_endpoints_from_string_tasks(self):
|
||||
endpoints = worker.Worker._derive_endpoints(
|
||||
['taskflow.tests.utils:DummyTask'])
|
||||
|
||||
self.assertEqual(len(endpoints), 1)
|
||||
self.assertIsInstance(endpoints[0], endpoint.Endpoint)
|
||||
self.assertEqual(endpoints[0].name, self.task_name)
|
||||
|
||||
def test_derive_endpoints_from_string_modules(self):
|
||||
endpoints = worker.Worker._derive_endpoints(['taskflow.tests.utils'])
|
||||
|
||||
self.assertEqual(len(endpoints), self.endpoint_count)
|
||||
|
||||
def test_derive_endpoints_from_string_non_existent_module(self):
|
||||
tasks = ['non.existent.module']
|
||||
|
||||
self.assertRaises(ImportError, worker.Worker._derive_endpoints, tasks)
|
||||
|
||||
def test_derive_endpoints_from_string_non_existent_task(self):
|
||||
tasks = ['non.existent.module:Task']
|
||||
|
||||
self.assertRaises(ImportError, worker.Worker._derive_endpoints, tasks)
|
||||
|
||||
def test_derive_endpoints_from_string_non_task_class(self):
|
||||
tasks = ['taskflow.tests.utils:FakeTask']
|
||||
|
||||
self.assertRaises(TypeError, worker.Worker._derive_endpoints, tasks)
|
||||
|
||||
def test_derive_endpoints_from_tasks(self):
|
||||
endpoints = worker.Worker._derive_endpoints([self.task_cls])
|
||||
|
||||
self.assertEqual(len(endpoints), 1)
|
||||
self.assertIsInstance(endpoints[0], endpoint.Endpoint)
|
||||
self.assertEqual(endpoints[0].name, self.task_name)
|
||||
|
||||
def test_derive_endpoints_from_non_task_class(self):
|
||||
self.assertRaises(TypeError, worker.Worker._derive_endpoints,
|
||||
[utils.FakeTask])
|
||||
|
||||
def test_derive_endpoints_from_modules(self):
|
||||
endpoints = worker.Worker._derive_endpoints([utils])
|
||||
|
||||
self.assertEqual(len(endpoints), self.endpoint_count)
|
||||
|
||||
def test_derive_endpoints_unexpected_task_type(self):
|
||||
self.assertRaises(TypeError, worker.Worker._derive_endpoints, [111])
|
@ -49,10 +49,17 @@ def make_reverting_task(token, blowup=False):
|
||||
|
||||
|
||||
class DummyTask(task.Task):
|
||||
|
||||
def execute(self, context, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class FakeTask(object):
|
||||
|
||||
def execute(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
if six.PY3:
|
||||
RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
|
||||
'BaseException', 'object']
|
||||
@ -101,7 +108,22 @@ class FailingTask(SaveOrderTask):
|
||||
raise RuntimeError('Woot!')
|
||||
|
||||
|
||||
class TaskWithFailure(task.Task):
|
||||
|
||||
def execute(self, **kwargs):
|
||||
raise RuntimeError('Woot!')
|
||||
|
||||
|
||||
class ProgressingTask(task.Task):
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
self.update_progress(0.0)
|
||||
self.update_progress(1.0)
|
||||
return 5
|
||||
|
||||
|
||||
class NastyTask(task.Task):
|
||||
|
||||
def execute(self, **kwargs):
|
||||
pass
|
||||
|
||||
@ -220,3 +242,16 @@ class EngineTestBase(object):
|
||||
|
||||
def _make_engine(self, flow, flow_detail=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FailureMatcher(object):
|
||||
"""Needed for failure objects comparison."""
|
||||
|
||||
def __init__(self, failure):
|
||||
self._failure = failure
|
||||
|
||||
def __repr__(self):
|
||||
return str(self._failure)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._failure.matches(other)
|
||||
|
@ -122,6 +122,11 @@ def is_bound_method(method):
|
||||
return bool(get_method_self(method))
|
||||
|
||||
|
||||
def is_subclass(obj, cls):
|
||||
"""Returns if the object is class and it is subclass of a given class."""
|
||||
return inspect.isclass(obj) and issubclass(obj, cls)
|
||||
|
||||
|
||||
def _get_arg_spec(function):
|
||||
if isinstance(function, type):
|
||||
bound = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user