Message-oriented worker-based flow with kombu

* Implemented Worker to be started on remote host for
  handling tasks request.
* Implemented WorkerTaskExecutor that proxies tasks
  requests to remote workers.
* Implemented Proxy that is used for consuming and
  publishing messages by Worker and Executor.
* Added worker-based engine and worker task executor.
* Added kombu dependency to requirements.
* Added worker-based flow example.
* Added unit-tests for worker-based flow components.

Implements: blueprint worker-based-engine
Change-Id: I8c6859ba4a1a56c2592e3d67cdfb8968b13ee99c
This commit is contained in:
Stanislav Kudriashev 2013-12-19 17:48:06 +02:00
parent 5acb0832df
commit 32e8c3da61
30 changed files with 2742 additions and 9 deletions

View File

@ -21,5 +21,8 @@ psycopg2
# ZooKeeper backends
kazoo>=1.3.1
# Eventlet may be used with parallel engine
# Eventlet may be used with parallel engine:
eventlet>=0.13.0
# Needed for the worker-based engine:
kombu>=2.4.8

View File

@ -46,6 +46,7 @@ taskflow.engines =
default = taskflow.engines.action_engine.engine:SingleThreadedActionEngine
serial = taskflow.engines.action_engine.engine:SingleThreadedActionEngine
parallel = taskflow.engines.action_engine.engine:MultiThreadedActionEngine
worker-based = taskflow.engines.worker_based.engine:WorkerBasedActionEngine
[nosetests]
cover-erase = true

View File

@ -65,11 +65,11 @@ class TaskExecutorBase(object):
"""
@abc.abstractmethod
def execute_task(self, task, arguments, progress_callback=None):
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
"""Schedules task execution."""
@abc.abstractmethod
def revert_task(self, task, arguments, result, failures,
def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None):
"""Schedules task reversion."""
@ -89,11 +89,11 @@ class TaskExecutorBase(object):
class SerialTaskExecutor(TaskExecutorBase):
"""Execute task one after another."""
def execute_task(self, task, arguments, progress_callback=None):
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
return async_utils.make_completed_future(
_execute_task(task, arguments, progress_callback))
def revert_task(self, task, arguments, result, failures,
def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None):
return async_utils.make_completed_future(
_revert_task(task, arguments, result,
@ -115,11 +115,11 @@ class ParallelTaskExecutor(TaskExecutorBase):
self._executor = executor
self._own_executor = executor is None
def execute_task(self, task, arguments, progress_callback=None):
def execute_task(self, task, task_uuid, arguments, progress_callback=None):
return self._executor.submit(
_execute_task, task, arguments, progress_callback)
def revert_task(self, task, arguments, result, failures,
def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None):
return self._executor.submit(
_revert_task, task,

View File

@ -27,6 +27,7 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
class TaskAction(object):
def __init__(self, storage, task_executor, notifier):
self._storage = storage
self._task_executor = task_executor
@ -66,7 +67,8 @@ class TaskAction(object):
if not self._change_state(task, states.RUNNING, progress=0.0):
return
kwargs = self._storage.fetch_mapped_args(task.rebind)
return self._task_executor.execute_task(task, kwargs,
task_uuid = self._storage.get_task_uuid(task.name)
return self._task_executor.execute_task(task, task_uuid, kwargs,
self._on_update_progress)
def complete_execution(self, task, result):
@ -80,9 +82,10 @@ class TaskAction(object):
if not self._change_state(task, states.REVERTING, progress=0.0):
return
kwargs = self._storage.fetch_mapped_args(task.rebind)
task_uuid = self._storage.get_task_uuid(task.name)
task_result = self._storage.get(task.name)
failures = self._storage.get_failures()
future = self._task_executor.revert_task(task, kwargs,
future = self._task_executor.revert_task(task, task_uuid, kwargs,
task_result, failures,
self._on_update_progress)
return future

View File

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from taskflow.engines.action_engine import executor
from taskflow.utils import reflection
class Endpoint(object):
"""Represents a single task with execute/revert methods."""
def __init__(self, task_cls):
self._task_cls = task_cls
self._task_cls_name = reflection.get_class_name(task_cls)
self._executor = executor.SerialTaskExecutor()
def __str__(self):
return self._task_cls_name
@property
def name(self):
return self._task_cls_name
def _get_task(self, name=None):
# NOTE(skudriashev): Note that task is created here with the 'name'
# argument passed to its constructor. This will be a problem
# when task's constructor requires some other arguments.
return self._task_cls(name=name)
def execute(self, task_name, **kwargs):
task, event, result = self._executor.execute_task(
self._get_task(task_name), **kwargs).result()
return result
def revert(self, task_name, **kwargs):
task, event, result = self._executor.revert_task(
self._get_task(task_name), **kwargs).result()
return result

View File

@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from taskflow.engines.action_engine import engine
from taskflow.engines.worker_based import executor
from taskflow import storage as t_storage
class WorkerBasedActionEngine(engine.ActionEngine):
_storage_cls = t_storage.SingleThreadedStorage
def _task_executor_cls(self):
return executor.WorkerTaskExecutor(**self._executor_config)
def __init__(self, flow, flow_detail, backend, conf):
self._executor_config = {
'uuid': flow_detail.uuid,
'url': conf.get('url'),
'exchange': conf.get('exchange', 'default'),
'workers_info': conf.get('workers_info', {}),
'transport': conf.get('transport'),
'transport_options': conf.get('transport_options')
}
super(WorkerBasedActionEngine, self).__init__(
flow, flow_detail, backend, conf)

View File

@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
import six
import threading
from kombu import exceptions as kombu_exc
from taskflow.engines.action_engine import executor
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
from taskflow.engines.worker_based import remote_task as rt
from taskflow import exceptions as exc
from taskflow.utils import async_utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
LOG = logging.getLogger(__name__)
class WorkerTaskExecutor(executor.TaskExecutorBase):
"""Executes tasks on remote workers."""
def __init__(self, uuid, exchange, workers_info, **kwargs):
self._uuid = uuid
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
self._on_wait, **kwargs)
self._proxy_thread = None
self._remote_tasks = {}
# TODO(skudriashev): This data should be collected from workers
# using broadcast messages directly.
self._workers_info = {}
for topic, tasks in workers_info.items():
for task in tasks:
self._workers_info[task] = topic
def _get_proxy_thread(self):
proxy_thread = threading.Thread(target=self._proxy.start)
# NOTE(skudriashev): When the main thread is terminated unexpectedly
# and proxy thread is still alive - it will prevent main thread from
# exiting unless the daemon property is set to True.
proxy_thread.daemon = True
return proxy_thread
def _on_message(self, response, message):
"""This method is called on incoming response."""
LOG.debug("Got response: %s" % response)
try:
# acknowledge message
message.ack()
except kombu_exc.MessageStateError as e:
LOG.warning("Failed to acknowledge AMQP message: %s" % e)
else:
LOG.debug("AMQP message acknowledged.")
# get task uuid from message correlation id parameter
try:
task_uuid = message.properties['correlation_id']
except KeyError:
LOG.warning("Got message with no 'correlation_id' property.")
else:
LOG.debug("Task uuid: '%s'" % task_uuid)
self._process_response(task_uuid, response)
def _process_response(self, task_uuid, response):
"""Process response from remote side."""
try:
task = self._remote_tasks[task_uuid]
except KeyError:
LOG.debug("Task with id='%s' not found." % task_uuid)
else:
state = response.pop('state')
if state == pr.RUNNING:
task.set_running()
elif state == pr.PROGRESS:
task.on_progress(**response)
elif state == pr.FAILURE:
response['result'] = pu.failure_from_dict(
response['result'])
task.set_result(**response)
self._remove_remote_task(task)
elif state == pr.SUCCESS:
task.set_result(**response)
self._remove_remote_task(task)
else:
LOG.warning("Unexpected response status: '%s'" % state)
def _on_wait(self):
"""This function is called cyclically between draining events
iterations to clean-up expired task requests.
"""
expired_tasks = [task for task in six.itervalues(self._remote_tasks)
if task.expired]
for task in expired_tasks:
LOG.debug("Task request '%s' is expired." % task)
task.set_result(misc.Failure.from_exception(
exc.Timeout("Task request '%s' is expired" % task)))
del self._remote_tasks[task.uuid]
def _store_remote_task(self, task):
"""Store task in the remote tasks map."""
self._remote_tasks[task.uuid] = task
return task
def _remove_remote_task(self, task):
"""Remove remote task from the tasks map."""
if task.uuid in self._remote_tasks:
del self._remote_tasks[task.uuid]
def _submit_task(self, task, task_uuid, action, arguments,
progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
"""Submit task request to workers."""
remote_task = self._store_remote_task(
rt.RemoteTask(task, task_uuid, action, arguments,
progress_callback, timeout, **kwargs)
)
try:
# get task's workers topic to send request to
try:
topic = self._workers_info[remote_task.name]
except KeyError:
raise exc.NotFound("Workers topic not found for the '%s'"
"task." % remote_task.name)
else:
# publish request
request = remote_task.request
LOG.debug("Sending request: %s" % request)
self._proxy.publish(request, remote_task.uuid,
routing_key=topic, reply_to=self._uuid)
except Exception as e:
LOG.error("Failed to submit the '%s' task: %s" % (remote_task, e))
self._remove_remote_task(remote_task)
remote_task.set_result(misc.Failure())
return remote_task.result
def execute_task(self, task, task_uuid, arguments,
progress_callback=None):
return self._submit_task(task, task_uuid, pr.EXECUTE, arguments,
progress_callback)
def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None):
return self._submit_task(task, task_uuid, pr.REVERT, arguments,
progress_callback, result=result,
failures=failures)
def wait_for_any(self, fs, timeout=None):
"""Wait for futures returned by this executor to complete."""
return async_utils.wait_for_any(fs, timeout)
def start(self):
"""Start proxy thread."""
if self._proxy_thread is None:
self._proxy_thread = self._get_proxy_thread()
self._proxy_thread.start()
self._proxy.wait()
def stop(self):
"""Stop proxy, so its thread would be gracefully terminated."""
if self._proxy_thread is not None:
if self._proxy_thread.is_alive():
self._proxy.stop()
self._proxy_thread.join()
self._proxy_thread = None

View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from taskflow.engines.action_engine import executor
# NOTE(skudriashev): This is protocol events, not related to the task states.
PENDING = 'PENDING'
RUNNING = 'RUNNING'
SUCCESS = 'SUCCESS'
FAILURE = 'FAILURE'
PROGRESS = 'PROGRESS'
# Remote task actions
EXECUTE = 'execute'
REVERT = 'revert'
# Remote task action to event map
ACTION_TO_EVENT = {
EXECUTE: executor.EXECUTED,
REVERT: executor.REVERTED
}
# NOTE(skudriashev): A timeout which specifies request expiration period.
REQUEST_TIMEOUT = 60
# NOTE(skudriashev): A timeout which controls for how long a queue can be
# unused before it is automatically deleted. Unused means the queue has no
# consumers, the queue has not been redeclared, the `queue.get` has not been
# invoked for a duration of at least the expiration period. In our case this
# period is equal to the request timeout, once request is expired - queue is
# no longer needed.
QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT

View File

@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import kombu
import logging
import socket
import threading
from amqp import exceptions as amqp_exc
from taskflow.engines.worker_based import protocol as pr
LOG = logging.getLogger(__name__)
# NOTE(skudriashev): A timeout of 1 is often used in environments where
# the socket can get "stuck", and is a best practice for Kombu consumers.
DRAIN_EVENTS_PERIOD = 1
class Proxy(object):
"""Proxy picks up messages from the named exchange, calls on_message
callback when new message received and is used to publish messages.
"""
def __init__(self, uuid, exchange_name, on_message, on_wait=None,
**kwargs):
self._uuid = uuid
self._exchange_name = exchange_name
self._on_message = on_message
self._on_wait = on_wait
self._running = threading.Event()
self._url = kwargs.get('url')
self._transport = kwargs.get('transport')
self._transport_opts = kwargs.get('transport_options')
# create connection
self._conn = kombu.Connection(self._url, transport=self._transport,
transport_options=self._transport_opts)
# create exchange
self._exchange = kombu.Exchange(name=self._exchange_name,
channel=self._conn,
durable=False,
auto_delete=True)
@property
def is_running(self):
"""Return whether proxy is running."""
return self._running.is_set()
def _make_queue(self, name, exchange, **kwargs):
"""Make named queue for the given exchange."""
queue_arguments = {'x-expires': pr.QUEUE_EXPIRE_TIMEOUT * 1000}
return kombu.Queue(name="%s_%s" % (self._exchange_name, name),
exchange=exchange,
routing_key=name,
durable=False,
queue_arguments=queue_arguments,
**kwargs)
def publish(self, msg, task_uuid, routing_key, **kwargs):
"""Publish message to the named exchange with routing key."""
with kombu.producers[self._conn].acquire(block=True) as producer:
queue = self._make_queue(routing_key, self._exchange)
producer.publish(body=msg,
routing_key=routing_key,
exchange=self._exchange,
correlation_id=task_uuid,
declare=[queue],
**kwargs)
def start(self):
"""Start proxy."""
LOG.info("Starting to consume from the '%s' exchange." %
self._exchange_name)
with kombu.connections[self._conn].acquire(block=True) as conn:
queue = self._make_queue(self._uuid, self._exchange, channel=conn)
try:
with conn.Consumer(queues=queue,
callbacks=[self._on_message]):
self._running.set()
while self.is_running:
try:
conn.drain_events(timeout=DRAIN_EVENTS_PERIOD)
except socket.timeout:
pass
if self._on_wait is not None:
self._on_wait()
finally:
try:
queue.delete(if_unused=True)
except (amqp_exc.PreconditionFailed, amqp_exc.NotFound):
pass
except Exception as e:
LOG.error("Failed to delete the '%s' queue: %s" %
(queue.name, e))
try:
self._exchange.delete(if_unused=True)
except (amqp_exc.PreconditionFailed, amqp_exc.NotFound):
pass
except Exception as e:
LOG.error("Failed to delete the '%s' exchange: %s" %
(self._exchange.name, e))
def wait(self):
"""Wait until proxy is started."""
self._running.wait()
def stop(self):
"""Stop proxy."""
self._running.clear()

View File

@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import time
from concurrent import futures
from taskflow.engines.worker_based import protocol as pr
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
from taskflow.utils import reflection
class RemoteTask(object):
"""Represents remote task with its request data and execution results.
Every remote task is created in the PENDING state and will be expired
within the given timeout.
"""
def __init__(self, task, uuid, action, arguments, progress_callback,
timeout, **kwargs):
self._task = task
self._name = reflection.get_class_name(task)
self._uuid = uuid
self._action = action
self._event = pr.ACTION_TO_EVENT[action]
self._arguments = arguments
self._progress_callback = progress_callback
self._timeout = timeout
self._kwargs = kwargs
self._time = time.time()
self._state = pr.PENDING
self.result = futures.Future()
def __repr__(self):
return "%s:%s" % (self._name, self._action)
@property
def uuid(self):
return self._uuid
@property
def name(self):
return self._name
@property
def request(self):
"""Return json-serializable task request, converting all `misc.Failure`
objects into dictionaries.
"""
request = dict(task=self._name, task_name=self._task.name,
task_version=self._task.version, action=self._action,
arguments=self._arguments)
if 'result' in self._kwargs:
result = self._kwargs['result']
if isinstance(result, misc.Failure):
request['result'] = ('failure', pu.failure_to_dict(result))
else:
request['result'] = ('success', result)
if 'failures' in self._kwargs:
failures = self._kwargs['failures']
request['failures'] = {}
for task, failure in failures.items():
request['failures'][task] = pu.failure_to_dict(failure)
return request
@property
def expired(self):
"""Check if task is expired.
When new remote task is created its state is set to the PENDING,
creation time is stored and timeout is given via constructor arguments.
Remote task is considered to be expired when it is in the PENDING state
for more then the given timeout (task is not considered to be expired
in any other state). After remote task is expired - the `Timeout`
exception is raised and task is removed from the remote tasks map.
"""
if self._state == pr.PENDING:
return time.time() - self._time > self._timeout
return False
def set_result(self, result):
self.result.set_result((self._task, self._event, result))
def set_running(self):
self._state = pr.RUNNING
def on_progress(self, event_data, progress):
self._progress_callback(self._task, event_data, progress)

View File

@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import logging
from kombu import exceptions as kombu_exc
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
LOG = logging.getLogger(__name__)
class Server(object):
"""Server implementation that waits for incoming tasks requests."""
def __init__(self, uuid, exchange, executor, endpoints, **kwargs):
self._proxy = proxy.Proxy(uuid, exchange, self._on_message, **kwargs)
self._executor = executor
self._endpoints = dict([(endpoint.name, endpoint)
for endpoint in endpoints])
def _on_message(self, request, message):
"""This method is called on incoming request."""
LOG.debug("Got request: %s" % request)
# NOTE(skudriashev): Process all incoming requests only if proxy is
# running, otherwise reject and requeue them.
if self._proxy.is_running:
# NOTE(skudriashev): Process request only if message has been
# acknowledged successfully.
try:
# acknowledge message
message.ack()
except kombu_exc.MessageStateError as e:
LOG.warning("Failed to acknowledge AMQP message: %s" % e)
else:
LOG.debug("AMQP message acknowledged.")
# spawn new thread to process request
self._executor.submit(self._process_request, request, message)
else:
try:
# reject and requeue message
message.reject(requeue=True)
except kombu_exc.MessageStateError as e:
LOG.warning("Failed to reject/requeue AMQP message: %s" % e)
else:
LOG.debug("AMQP message rejected and requeued.")
@staticmethod
def _parse_request(task, task_name, action, arguments, result=None,
failures=None, **kwargs):
"""Parse request before it can be processed. All `misc.Failure` objects
that have been converted to dict on the remote side to be serializable
are now converted back to objects.
"""
action_args = dict(arguments=arguments, task_name=task_name)
if result is not None:
data_type, data = result
if data_type == 'failure':
action_args['result'] = pu.failure_from_dict(data)
else:
action_args['result'] = data
if failures is not None:
action_args['failures'] = {}
for k, v in failures.items():
action_args['failures'][k] = pu.failure_from_dict(v)
return task, action, action_args
@staticmethod
def _parse_message(message):
"""Parse broker message to get the `reply_to` and the `correlation_id`
properties. If required properties are missing - the `ValueError` is
raised.
"""
properties = []
for prop in ('reply_to', 'correlation_id'):
try:
properties.append(message.properties[prop])
except KeyError:
raise ValueError("The '%s' message property is missing." %
prop)
return properties
def _reply(self, reply_to, task_uuid, state=pr.FAILURE, **kwargs):
"""Send reply to the `reply_to` queue."""
response = dict(state=state, **kwargs)
LOG.debug("Send reply: %s" % response)
try:
self._proxy.publish(response, task_uuid, reply_to)
except Exception as e:
LOG.error("Failed to send reply: %s" % e)
def _on_update_progress(self, reply_to, task_uuid, task, event_data,
progress):
"""Send task update progress notification."""
self._reply(reply_to, task_uuid, pr.PROGRESS, event_data=event_data,
progress=progress)
def _process_request(self, request, message):
"""Process request in separate thread and reply back."""
# parse broker message first to get the `reply_to` and the `task_uuid`
# parameters to have possibility to reply back
try:
reply_to, task_uuid = self._parse_message(message)
except ValueError as e:
LOG.error("Failed to parse broker message: %s" % e)
return
else:
# prepare task progress callback
progress_callback = functools.partial(
self._on_update_progress, reply_to, task_uuid)
# prepare reply callback
reply_callback = functools.partial(
self._reply, reply_to, task_uuid)
# parse request to get task name, action and action arguments
try:
task, action, action_args = self._parse_request(**request)
action_args.update(task_uuid=task_uuid,
progress_callback=progress_callback)
except ValueError as e:
LOG.error("Failed to parse request: %s" % e)
reply_callback(result=pu.failure_to_dict(misc.Failure()))
return
# get task endpoint
try:
endpoint = self._endpoints[task]
except KeyError:
LOG.error("The '%s' task endpoint does not exist." % task)
reply_callback(result=pu.failure_to_dict(misc.Failure()))
return
else:
reply_callback(state=pr.RUNNING)
# perform task action
try:
result = getattr(endpoint, action)(**action_args)
except Exception as e:
LOG.error("The %s task execution failed: %s" % (endpoint, e))
reply_callback(result=pu.failure_to_dict(misc.Failure()))
else:
if isinstance(result, misc.Failure):
reply_callback(result=pu.failure_to_dict(result))
else:
reply_callback(state=pr.SUCCESS, result=result)
def start(self):
"""Start processing incoming requests."""
self._proxy.start()
def wait(self):
"""Wait until server is started."""
self._proxy.wait()
def stop(self):
"""Stop processing incoming requests."""
self._proxy.stop()
self._executor.shutdown()

View File

@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import inspect
import logging
import types
from concurrent import futures
import six
from taskflow.engines.worker_based import endpoint
from taskflow.engines.worker_based import server
from taskflow.openstack.common import importutils
from taskflow import task as t_task
from taskflow.utils import reflection
from taskflow.utils import threading_utils as tu
LOG = logging.getLogger(__name__)
class Worker(object):
"""Worker that can be started on a remote host for handling tasks requests.
:param url: broker url
:param exchange: broker exchange name
:param topic: topic name under which worker is stated
:param tasks: tasks list that worker is capable to perform
Tasks list item can be one of the following types:
1. String:
1.1 Python module name:
> tasks=['taskflow.tests.utils']
1.2. Task class (BaseTask subclass) name:
> tasks=['taskflow.test.utils.DummyTask']
3. Python module:
> from taskflow.tests import utils
> tasks=[utils]
4. Task class (BaseTask subclass):
> from taskflow.tests import utils
> tasks=[utils.DummyTask]
:param executor: custom executor object that is used for processing
requests in separate threads
:keyword threads_count: threads count to be passed to the default executor
:keyword transport: broker transport to be used (e.g. amqp, memory, etc.)
:keyword transport_options: broker transport options
"""
def __init__(self, exchange, topic, tasks, executor=None, **kwargs):
self._topic = topic
self._executor = executor
self._threads_count = kwargs.pop('threads_count',
tu.get_optimal_thread_count())
if self._executor is None:
self._executor = futures.ThreadPoolExecutor(self._threads_count)
self._endpoints = self._derive_endpoints(tasks)
self._server = server.Server(topic, exchange, self._executor,
self._endpoints, **kwargs)
@staticmethod
def _derive_endpoints(tasks):
"""Derive endpoints from list of strings, classes or packages."""
derived_tasks = set()
for item in tasks:
module = None
if isinstance(item, six.string_types):
try:
pkg, cls = item.split(':')
except ValueError:
module = importutils.import_module(item)
else:
obj = importutils.import_class('%s.%s' % (pkg, cls))
if not reflection.is_subclass(obj, t_task.BaseTask):
raise TypeError("Item %s is not a BaseTask subclass" %
item)
derived_tasks.add(obj)
elif isinstance(item, types.ModuleType):
module = item
elif reflection.is_subclass(item, t_task.BaseTask):
derived_tasks.add(item)
else:
raise TypeError("Item %s unexpected type: %s" %
(item, type(item)))
# derive tasks
if module is not None:
for name, obj in inspect.getmembers(module):
if reflection.is_subclass(obj, t_task.BaseTask):
derived_tasks.add(obj)
return [endpoint.Endpoint(task) for task in derived_tasks]
def run(self):
"""Run worker."""
LOG.info("Starting the '%s' topic worker in %s threads." %
(self._topic, self._threads_count))
LOG.info("Tasks list:")
for endpoint in self._endpoints:
LOG.info("|-- %s" % endpoint)
self._server.start()
def wait(self):
"""Wait until worker is started."""
self._server.wait()
def stop(self):
"""Stop worker."""
self._server.stop()

View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json
import logging
import sys
import taskflow.engines
from taskflow.patterns import linear_flow as lf
from taskflow.tests import utils
LOG = logging.getLogger(__name__)
if __name__ == "__main__":
logging.basicConfig(level=logging.ERROR)
engine_conf = {
'engine': 'worker-based',
'exchange': 'taskflow',
'workers_info': {
'topic': [
'taskflow.tests.utils.TaskOneArgOneReturn',
'taskflow.tests.utils.TaskMultiArgOneReturn'
]
}
}
# parse command line
try:
arg = sys.argv[1]
except IndexError:
pass
else:
try:
cfg = json.loads(arg)
except ValueError:
engine_conf.update(url=arg)
else:
engine_conf.update(cfg)
finally:
LOG.debug("Worker configuration: %s\n" %
json.dumps(engine_conf, sort_keys=True, indent=4))
# create and run flow
flow = lf.Flow('simple-linear').add(
utils.TaskOneArgOneReturn(provides='result1'),
utils.TaskMultiArgOneReturn(provides='result2')
)
eng = taskflow.engines.load(flow,
store=dict(x=111, y=222, z=333),
engine_conf=engine_conf)
eng.run()
print(json.dumps(eng.storage.fetch_all(), sort_keys=True))

View File

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json
import logging
import sys
from taskflow.engines.worker_based import worker as w
LOG = logging.getLogger(__name__)
if __name__ == "__main__":
logging.basicConfig(level=logging.ERROR)
worker_conf = {
'exchange': 'taskflow',
'topic': 'topic',
'tasks': [
'taskflow.tests.utils:TaskOneArgOneReturn',
'taskflow.tests.utils:TaskMultiArgOneReturn'
]
}
# parse command line
try:
arg = sys.argv[1]
except IndexError:
pass
else:
try:
cfg = json.loads(arg)
except ValueError:
worker_conf.update(url=arg)
else:
worker_conf.update(cfg)
finally:
LOG.debug("Worker configuration: %s\n" %
json.dumps(worker_conf, sort_keys=True, indent=4))
# run worker
worker = w.Worker(**worker_conf)
try:
worker.run()
except KeyboardInterrupt:
pass

View File

@ -0,0 +1,6 @@
Run worker.
Run flow.
{"result1": 1, "result2": 666, "x": 111, "y": 222, "z": 333}
Flow finished.
Stop worker.

View File

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json
import os
import subprocess
import sys
import tempfile
self_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, self_dir)
import example_utils # noqa
def _path_to(name):
return os.path.abspath(os.path.join(os.path.dirname(__file__),
'worker_based', name))
def run_test(name, config):
cmd = [sys.executable, _path_to(name), config]
process = subprocess.Popen(cmd, stdin=None, stdout=subprocess.PIPE,
stderr=sys.stderr)
return process, cmd
def main():
tmp_path = None
try:
tmp_path = tempfile.mkdtemp(prefix='worker-based-example-')
config = json.dumps({
'transport': 'filesystem',
'transport_options': {
'data_folder_in': tmp_path,
'data_folder_out': tmp_path
}
})
print('Run worker.')
worker_process, _ = run_test('worker.py', config)
print('Run flow.')
flow_process, flow_cmd = run_test('flow.py', config)
stdout, _ = flow_process.communicate()
rc = flow_process.returncode
if rc != 0:
raise RuntimeError("Could not run %s [%s]" % (flow_cmd, rc))
print(stdout.decode())
print('Flow finished.')
print('Stop worker.')
worker_process.terminate()
finally:
if tmp_path is not None:
example_utils.rm_path(tmp_path)
if __name__ == '__main__':
main()

View File

@ -151,3 +151,7 @@ def exception_message(exc):
return six.text_type(exc)
except UnicodeError:
return str(exc)
class Timeout(TaskFlowException):
"""Raised when something was not finished within the given timeout."""

View File

@ -16,6 +16,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import mock
from testtools import compat
from testtools import matchers
from testtools import testcase
@ -88,3 +90,51 @@ class TestCase(testcase.TestCase):
self.fail(msg)
else:
current_tail = current_tail[super_index + 1:]
class MockTestCase(TestCase):
def setUp(self):
super(MockTestCase, self).setUp()
self.master_mock = mock.Mock(name='master_mock')
def _patch(self, target, autospec=True, **kwargs):
"""Patch target and attach it to the master mock."""
patcher = mock.patch(target, autospec=autospec, **kwargs)
mocked = patcher.start()
self.addCleanup(patcher.stop)
attach_as = kwargs.pop('attach_as', None)
if attach_as is not None:
self.master_mock.attach_mock(mocked, attach_as)
return mocked
def _patch_class(self, module, name, autospec=True, attach_as=None):
"""Patch class, create class instance mock and attach them to
the master mock.
"""
if autospec:
instance_mock = mock.Mock(spec_set=getattr(module, name))
else:
instance_mock = mock.Mock()
patcher = mock.patch.object(module, name, autospec=autospec)
class_mock = patcher.start()
self.addCleanup(patcher.stop)
class_mock.return_value = instance_mock
if attach_as is None:
attach_class_as = name
attach_instance_as = name.lower()
else:
attach_class_as = attach_as + '_class'
attach_instance_as = attach_as
self.master_mock.attach_mock(class_mock, attach_class_as)
self.master_mock.attach_mock(instance_mock, attach_instance_as)
return class_mock, instance_mock
def _reset_master_mock(self):
self.master_mock.reset_mock()

View File

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from taskflow.engines.worker_based import endpoint as ep
from taskflow import task
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import reflection
class Task(task.Task):
def __init__(self, a, *args, **kwargs):
super(Task, self).__init__(*args, **kwargs)
def execute(self, *args, **kwargs):
pass
class TestEndpoint(test.TestCase):
def setUp(self):
super(TestEndpoint, self).setUp()
self.task_cls = utils.TaskOneReturn
self.task_uuid = 'task-uuid'
self.task_args = {'context': 'context'}
self.task_cls_name = reflection.get_class_name(self.task_cls)
self.task_ep = ep.Endpoint(self.task_cls)
self.task_result = 1
def test_creation(self):
task = self.task_ep._get_task()
self.assertEqual(self.task_ep.name, self.task_cls_name)
self.assertIsInstance(task, self.task_cls)
self.assertEqual(task.name, self.task_cls_name)
def test_creation_with_task_name(self):
task_name = 'test'
task = self.task_ep._get_task(name=task_name)
self.assertEqual(self.task_ep.name, self.task_cls_name)
self.assertIsInstance(task, self.task_cls)
self.assertEqual(task.name, task_name)
def test_creation_task_with_constructor_args(self):
# NOTE(skudriashev): Exception is expected here since task
# is created without any arguments passing to its constructor.
endpoint = ep.Endpoint(Task)
self.assertRaises(TypeError, endpoint._get_task)
def test_to_str(self):
self.assertEqual(str(self.task_ep), self.task_cls_name)
def test_execute(self):
result = self.task_ep.execute(task_name=self.task_cls_name,
task_uuid=self.task_uuid,
arguments=self.task_args,
progress_callback=None)
self.assertEqual(result, self.task_result)
def test_revert(self):
result = self.task_ep.revert(task_name=self.task_cls_name,
task_uuid=self.task_uuid,
arguments=self.task_args,
progress_callback=None,
result=self.task_result,
failures={})
self.assertEqual(result, None)

View File

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from taskflow.engines.worker_based import engine
from taskflow.patterns import linear_flow as lf
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import persistence_utils as pu
class TestWorkerBasedActionEngine(test.MockTestCase):
def setUp(self):
super(TestWorkerBasedActionEngine, self).setUp()
self.broker_url = 'test-url'
self.exchange = 'test-exchange'
self.workers_info = {'test-topic': ['task1', 'task2']}
# patch classes
self.executor_mock, self.executor_inst_mock = self._patch_class(
engine.executor, 'WorkerTaskExecutor', attach_as='executor')
def test_creation_default(self):
flow = lf.Flow('test-flow').add(utils.DummyTask())
_, flow_detail = pu.temporary_flow_detail()
engine.WorkerBasedActionEngine(flow, flow_detail, None, {}).compile()
expected_calls = [
mock.call.executor_class(uuid=flow_detail.uuid,
url=None,
exchange='default',
workers_info={},
transport=None,
transport_options=None)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
def test_creation_custom(self):
flow = lf.Flow('test-flow').add(utils.DummyTask())
_, flow_detail = pu.temporary_flow_detail()
config = {'url': self.broker_url, 'exchange': self.exchange,
'workers_info': self.workers_info, 'transport': 'memory',
'transport_options': {}}
engine.WorkerBasedActionEngine(
flow, flow_detail, None, config).compile()
expected_calls = [
mock.call.executor_class(uuid=flow_detail.uuid,
url=self.broker_url,
exchange=self.exchange,
workers_info=self.workers_info,
transport='memory',
transport_options={})
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)

View File

@ -0,0 +1,386 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
import threading
import time
from concurrent import futures
from kombu import exceptions as kombu_exc
from taskflow.engines.worker_based import executor
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import remote_task as rt
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
class TestWorkerTaskExecutor(test.MockTestCase):
def setUp(self):
super(TestWorkerTaskExecutor, self).setUp()
self.task = utils.DummyTask()
self.task_uuid = 'task-uuid'
self.task_args = {'context': 'context'}
self.task_result = 'task-result'
self.task_failures = {}
self.timeout = 60
self.broker_url = 'test-url'
self.executor_uuid = 'executor-uuid'
self.executor_exchange = 'executor-exchange'
self.executor_topic = 'executor-topic'
self.executor_workers_info = {self.executor_topic: [self.task.name]}
self.proxy_started_event = threading.Event()
# patch classes
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
executor.proxy, 'Proxy')
# other mocking
self.proxy_inst_mock.start.side_effect = self._fake_proxy_start
self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop
self.wait_for_any_mock = self._patch(
'taskflow.engines.worker_based.executor.async_utils.wait_for_any')
self.message_mock = mock.MagicMock(name='message')
self.message_mock.properties = {'correlation_id': self.task_uuid}
self.remote_task_mock = mock.MagicMock(uuid=self.task_uuid)
def _fake_proxy_start(self):
self.proxy_started_event.set()
while self.proxy_started_event.is_set():
time.sleep(0.01)
def _fake_proxy_stop(self):
self.proxy_started_event.clear()
def executor(self, reset_master_mock=True, **kwargs):
executor_kwargs = dict(uuid=self.executor_uuid,
exchange=self.executor_exchange,
workers_info=self.executor_workers_info,
url=self.broker_url)
executor_kwargs.update(kwargs)
ex = executor.WorkerTaskExecutor(**executor_kwargs)
if reset_master_mock:
self._reset_master_mock()
return ex
def request(self, **kwargs):
request = dict(task=self.task.name, task_name=self.task.name,
task_version=self.task.version,
arguments=self.task_args)
request.update(kwargs)
return request
def remote_task(self, **kwargs):
remote_task_kwargs = dict(task=self.task, uuid=self.task_uuid,
action='execute', arguments=self.task_args,
progress_callback=None, timeout=self.timeout)
remote_task_kwargs.update(kwargs)
return rt.RemoteTask(**remote_task_kwargs)
def test_creation(self):
ex = self.executor(reset_master_mock=False)
master_mock_calls = [
mock.call.Proxy(self.executor_uuid, self.executor_exchange,
ex._on_message, ex._on_wait, url=self.broker_url)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_on_message_state_running(self):
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
[mock.call.set_running()])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_state_progress(self):
response = dict(state=pr.PROGRESS, progress=1.0)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
[mock.call.on_progress(progress=1.0)])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_state_failure(self):
failure = misc.Failure.from_exception(Exception('test'))
failure_dict = pu.failure_to_dict(failure)
response = dict(state=pr.FAILURE, result=failure_dict)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(len(ex._remote_tasks), 0)
self.assertEqual(self.remote_task_mock.mock_calls, [
mock.call.set_result(result=utils.FailureMatcher(failure))
])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_state_success(self):
response = dict(state=pr.SUCCESS, result=self.task_result,
event='executed')
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
[mock.call.set_result(result=self.task_result,
event='executed')])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_unknown_state(self):
response = dict(state='unknown')
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_non_existent_task(self):
self.message_mock.properties = {'correlation_id': 'non-existent'}
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_no_correlation_id(self):
self.message_mock.properties = {}
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._store_remote_task(self.remote_task_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
def test_on_message_acknowledge_raises(self, mocked_warning):
self.message_mock.ack.side_effect = kombu_exc.MessageStateError()
self.executor()._on_message({}, self.message_mock)
self.assertTrue(mocked_warning.called)
@mock.patch('taskflow.engines.worker_based.remote_task.time.time')
def test_on_wait_task_not_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout]
ex = self.executor()
ex._store_remote_task(self.remote_task())
self.assertEqual(len(ex._remote_tasks), 1)
ex._on_wait()
self.assertEqual(len(ex._remote_tasks), 1)
@mock.patch('taskflow.engines.worker_based.remote_task.time.time')
def test_on_wait_task_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
ex = self.executor()
ex._store_remote_task(self.remote_task())
self.assertEqual(len(ex._remote_tasks), 1)
ex._on_wait()
self.assertEqual(len(ex._remote_tasks), 0)
def test_remove_task_non_existent(self):
task = self.remote_task()
ex = self.executor()
ex._store_remote_task(task)
self.assertEqual(len(ex._remote_tasks), 1)
ex._remove_remote_task(task)
self.assertEqual(len(ex._remote_tasks), 0)
# remove non-existent
ex._remove_remote_task(task)
self.assertEqual(len(ex._remote_tasks), 0)
def test_execute_task(self):
request = self.request(action='execute')
ex = self.executor()
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [
mock.call.proxy.publish(request, self.task_uuid,
routing_key=self.executor_topic,
reply_to=self.executor_uuid)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
self.assertIsInstance(result, futures.Future)
def test_revert_task(self):
request = self.request(action='revert',
result=('success', self.task_result),
failures=self.task_failures)
ex = self.executor()
result = ex.revert_task(self.task, self.task_uuid, self.task_args,
self.task_result, self.task_failures)
expected_calls = [
mock.call.proxy.publish(request, self.task_uuid,
routing_key=self.executor_topic,
reply_to=self.executor_uuid)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
self.assertIsInstance(result, futures.Future)
def test_execute_task_topic_not_found(self):
workers_info = {self.executor_topic: ['non-existent-task']}
ex = self.executor(workers_info=workers_info)
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
self.assertFalse(self.proxy_inst_mock.publish.called)
# check execute result
task, event, res = result.result()
self.assertEqual(task, self.task)
self.assertEqual(event, 'executed')
self.assertIsInstance(res, misc.Failure)
def test_execute_task_publish_error(self):
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
request = self.request(action='execute')
ex = self.executor()
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [
mock.call.proxy.publish(request, self.task_uuid,
routing_key=self.executor_topic,
reply_to=self.executor_uuid)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
# check execute result
task, event, res = result.result()
self.assertEqual(task, self.task)
self.assertEqual(event, 'executed')
self.assertIsInstance(res, misc.Failure)
def test_wait_for_any(self):
fs = [futures.Future(), futures.Future()]
ex = self.executor()
ex.wait_for_any(fs)
expected_calls = [
mock.call(fs, None)
]
self.assertEqual(self.wait_for_any_mock.mock_calls, expected_calls)
def test_wait_for_any_with_timeout(self):
timeout = 30
fs = [futures.Future(), futures.Future()]
ex = self.executor()
ex.wait_for_any(fs, timeout)
master_mock_calls = [
mock.call(fs, timeout)
]
self.assertEqual(self.wait_for_any_mock.mock_calls, master_mock_calls)
def test_start_stop(self):
ex = self.executor()
ex.start()
# make sure proxy thread started
self.proxy_started_event.wait()
# stop executor
ex.stop()
self.master_mock.assert_has_calls([
mock.call.proxy.start(),
mock.call.proxy.wait(),
mock.call.proxy.stop()
], any_order=True)
def test_start_already_running(self):
ex = self.executor()
ex.start()
# make sure proxy thread started
self.proxy_started_event.wait()
# start executor again
ex.start()
# stop executor
ex.stop()
self.master_mock.assert_has_calls([
mock.call.proxy.start(),
mock.call.proxy.wait(),
mock.call.proxy.stop()
], any_order=True)
def test_stop_not_running(self):
self.executor().stop()
self.assertEqual(self.master_mock.mock_calls, [])
def test_stop_not_alive(self):
self.proxy_inst_mock.start.side_effect = None
# start executor
ex = self.executor()
ex.start()
# wait until executor thread is done
ex._proxy_thread.join()
# stop executor
ex.stop()
# since proxy thread is already done - stop is not called
self.master_mock.assert_has_calls([
mock.call.proxy.start(),
mock.call.proxy.wait()
], any_order=True)
def test_restart(self):
ex = self.executor()
ex.start()
# make sure thread started
self.proxy_started_event.wait()
# restart executor
ex.stop()
ex.start()
# make sure thread started
self.proxy_started_event.wait()
# stop executor
ex.stop()
self.master_mock.assert_has_calls([
mock.call.proxy.start(),
mock.call.proxy.wait(),
mock.call.proxy.stop(),
mock.call.proxy.start(),
mock.call.proxy.wait(),
mock.call.proxy.stop()
], any_order=True)

View File

@ -0,0 +1,301 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
import socket
import threading
from amqp import exceptions as amqp_exc
from taskflow.engines.worker_based import proxy
from taskflow import test
class TestProxy(test.MockTestCase):
def setUp(self):
super(TestProxy, self).setUp()
self.uuid = 'test-uuid'
self.broker_url = 'test-url'
self.exchange_name = 'test-exchange'
self.timeout = 5
self.queue_arguments = {
'x-expires': proxy.pr.QUEUE_EXPIRE_TIMEOUT * 1000
}
self.de_period = proxy.DRAIN_EVENTS_PERIOD
# patch classes
self.conn_mock, self.conn_inst_mock = self._patch_class(
proxy.kombu, 'Connection')
self.exchange_mock, self.exchange_inst_mock = self._patch_class(
proxy.kombu, 'Exchange')
self.queue_mock, self.queue_inst_mock = self._patch_class(
proxy.kombu, 'Queue')
self.producer_mock, self.producer_inst_mock = self._patch_class(
proxy.kombu, 'Producer')
# connection mocking
self.conn_inst_mock.drain_events.side_effect = [
socket.timeout, socket.timeout, KeyboardInterrupt]
# connections mocking
self.connections_mock = self._patch(
"taskflow.engines.worker_based.proxy.kombu.connections",
attach_as='connections')
self.connections_mock.__getitem__().acquire().__enter__.return_value =\
self.conn_inst_mock
# producers mocking
self.producers_mock = self._patch(
"taskflow.engines.worker_based.proxy.kombu.producers",
attach_as='producers')
self.producers_mock.__getitem__().acquire().__enter__.return_value =\
self.producer_inst_mock
# consumer mocking
self.conn_inst_mock.Consumer.return_value.__enter__ = mock.MagicMock()
self.conn_inst_mock.Consumer.return_value.__exit__ = mock.MagicMock()
# other mocking
self.on_message_mock = mock.MagicMock(name='on_message')
self.on_wait_mock = mock.MagicMock(name='on_wait')
self.master_mock.attach_mock(self.on_wait_mock, 'on_wait')
# reset master mock
self._reset_master_mock()
def _queue_name(self, uuid):
return "%s_%s" % (self.exchange_name, uuid)
def proxy_start_calls(self, calls, exc_type=mock.ANY):
return [
mock.call.Queue(name=self._queue_name(self.uuid),
exchange=self.exchange_inst_mock,
routing_key=self.uuid,
durable=False,
queue_arguments=self.queue_arguments,
channel=self.conn_inst_mock),
mock.call.connection.Consumer(queues=self.queue_inst_mock,
callbacks=[self.on_message_mock]),
mock.call.connection.Consumer().__enter__(),
] + calls + [
mock.call.connection.Consumer().__exit__(exc_type, mock.ANY,
mock.ANY),
mock.ANY,
mock.call.queue.delete(if_unused=True)
]
def proxy(self, reset_master_mock=False, **kwargs):
proxy_kwargs = dict(uuid=self.uuid,
exchange_name=self.exchange_name,
on_message=self.on_message_mock,
url=self.broker_url)
proxy_kwargs.update(kwargs)
p = proxy.Proxy(**proxy_kwargs)
if reset_master_mock:
self._reset_master_mock()
return p
def test_creation(self):
self.proxy()
master_mock_calls = [
mock.call.Connection(self.broker_url, transport=None,
transport_options=None),
mock.call.Exchange(name=self.exchange_name,
channel=self.conn_inst_mock,
durable=False,
auto_delete=True)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_creation_custom(self):
transport_opts = {'context': 'context'}
self.proxy(transport='memory', transport_options=transport_opts)
master_mock_calls = [
mock.call.Connection(self.broker_url, transport='memory',
transport_options=transport_opts),
mock.call.Exchange(name=self.exchange_name,
channel=self.conn_inst_mock,
durable=False,
auto_delete=True)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_publish(self):
task_data = 'task-data'
task_uuid = 'task-uuid'
routing_key = 'routing-key'
kwargs = dict(a='a', b='b')
self.proxy(reset_master_mock=True).publish(
task_data, task_uuid, routing_key, **kwargs)
master_mock_calls = [
mock.call.Queue(name=self._queue_name(routing_key),
exchange=self.exchange_inst_mock,
routing_key=routing_key,
durable=False,
queue_arguments=self.queue_arguments),
mock.call.producer.publish(body=task_data,
routing_key=routing_key,
exchange=self.exchange_inst_mock,
correlation_id=task_uuid,
declare=[self.queue_inst_mock],
**kwargs)
]
self.master_mock.assert_has_calls(master_mock_calls)
def test_start(self):
try:
# KeyboardInterrupt will be raised after two iterations
self.proxy(reset_master_mock=True).start()
except KeyboardInterrupt:
pass
master_calls = self.proxy_start_calls([
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
], exc_type=KeyboardInterrupt)
self.master_mock.assert_has_calls(master_calls)
def test_start_with_on_wait(self):
try:
# KeyboardInterrupt will be raised after two iterations
self.proxy(reset_master_mock=True,
on_wait=self.on_wait_mock).start()
except KeyboardInterrupt:
pass
master_calls = self.proxy_start_calls([
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.on_wait(),
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.on_wait(),
mock.call.connection.drain_events(timeout=self.de_period),
], exc_type=KeyboardInterrupt)
self.master_mock.assert_has_calls(master_calls)
def test_start_with_on_wait_raises(self):
self.on_wait_mock.side_effect = RuntimeError('Woot!')
try:
# KeyboardInterrupt will be raised after two iterations
self.proxy(reset_master_mock=True,
on_wait=self.on_wait_mock).start()
except KeyboardInterrupt:
pass
master_calls = self.proxy_start_calls([
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.on_wait(),
], exc_type=RuntimeError)
self.master_mock.assert_has_calls(master_calls)
def test_start_queue_delete_not_found(self):
self.queue_inst_mock.delete.side_effect = amqp_exc.NotFound('Woot!')
try:
# KeyboardInterrupt will be raised after two iterations
self.proxy(reset_master_mock=True).start()
except KeyboardInterrupt:
pass
master_calls = self.proxy_start_calls([
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
], exc_type=KeyboardInterrupt)
self.master_mock.assert_has_calls(master_calls)
@mock.patch("taskflow.engines.worker_based.proxy.LOG.error")
def test_start_queue_delete_raises(self, mocked_error):
self.queue_inst_mock.delete.side_effect = RuntimeError('Woot!')
try:
# KeyboardInterrupt will be raised after two iterations
self.proxy(reset_master_mock=True).start()
except KeyboardInterrupt:
pass
master_calls = self.proxy_start_calls([
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
], exc_type=KeyboardInterrupt)
self.master_mock.assert_has_calls(master_calls)
self.assertTrue(mocked_error.called)
def test_start_exchange_delete_not_found(self):
self.exchange_inst_mock.delete.side_effect = amqp_exc.NotFound('Woot!')
try:
# KeyboardInterrupt will be raised after two iterations
self.proxy(reset_master_mock=True).start()
except KeyboardInterrupt:
pass
master_calls = self.proxy_start_calls([
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
], exc_type=KeyboardInterrupt)
self.master_mock.assert_has_calls(master_calls)
@mock.patch("taskflow.engines.worker_based.proxy.LOG.error")
def test_start_exchange_delete_raises(self, mocked_error):
self.exchange_inst_mock.delete.side_effect = RuntimeError('Woot!')
try:
# KeyboardInterrupt will be raised after two iterations
self.proxy(reset_master_mock=True).start()
except KeyboardInterrupt:
pass
master_calls = self.proxy_start_calls([
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
mock.call.connection.drain_events(timeout=self.de_period),
], exc_type=KeyboardInterrupt)
self.master_mock.assert_has_calls(master_calls)
self.assertTrue(mocked_error.called)
def test_stop(self):
self.conn_inst_mock.drain_events.side_effect = socket.timeout
# create proxy
pr = self.proxy(reset_master_mock=True)
# check that proxy is not running yes
self.assertFalse(pr.is_running)
# start proxy in separate thread
t = threading.Thread(target=pr.start)
t.daemon = True
t.start()
# make sure proxy is started
pr.wait()
# check that proxy is running now
self.assertTrue(pr.is_running)
# stop proxy and wait for thread to finish
pr.stop()
# wait for thread to finish
t.join()
self.assertFalse(pr.is_running)

View File

@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from concurrent import futures
from taskflow.engines.worker_based import remote_task as rt
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
class TestRemoteTask(test.TestCase):
def setUp(self):
super(TestRemoteTask, self).setUp()
self.task = utils.DummyTask()
self.task_uuid = 'task-uuid'
self.task_action = 'execute'
self.task_args = {'context': 'context'}
self.timeout = 60
def remote_task(self, **kwargs):
task_kwargs = dict(task=self.task,
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=self.timeout)
task_kwargs.update(kwargs)
return rt.RemoteTask(**task_kwargs)
def remote_task_request(self, **kwargs):
request = dict(task=self.task.name,
task_name=self.task.name,
task_version=self.task.version,
action=self.task_action,
arguments=self.task_args)
request.update(kwargs)
return request
def test_creation(self):
remote_task = self.remote_task()
self.assertEqual(remote_task.uuid, self.task_uuid)
self.assertEqual(remote_task.name, self.task.name)
self.assertIsInstance(remote_task.result, futures.Future)
self.assertFalse(remote_task.result.done())
def test_repr(self):
expected_name = '%s:%s' % (self.task.name, self.task_action)
self.assertEqual(repr(self.remote_task()), expected_name)
def test_request(self):
remote_task = self.remote_task()
request = self.remote_task_request()
self.assertEqual(remote_task.request, request)
def test_request_with_result(self):
remote_task = self.remote_task(result=333)
request = self.remote_task_request(result=('success', 333))
self.assertEqual(remote_task.request, request)
def test_request_with_result_none(self):
remote_task = self.remote_task(result=None)
request = self.remote_task_request(result=('success', None))
self.assertEqual(remote_task.request, request)
def test_request_with_result_failure(self):
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
remote_task = self.remote_task(result=failure)
request = self.remote_task_request(
result=('failure', pu.failure_to_dict(failure)))
self.assertEqual(remote_task.request, request)
def test_request_with_failures(self):
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
remote_task = self.remote_task(failures={self.task.name: failure})
request = self.remote_task_request(
failures={self.task.name: pu.failure_to_dict(failure)})
self.assertEqual(remote_task.request, request)
@mock.patch('time.time')
def test_pending_not_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout]
remote_task = self.remote_task()
self.assertFalse(remote_task.expired)
@mock.patch('time.time')
def test_pending_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout + 2]
remote_task = self.remote_task()
self.assertTrue(remote_task.expired)
@mock.patch('time.time')
def test_running_not_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout]
remote_task = self.remote_task()
remote_task.set_running()
self.assertFalse(remote_task.expired)
def test_set_result(self):
remote_task = self.remote_task()
remote_task.set_result(111)
result = remote_task.result.result()
self.assertEqual(result, (self.task, 'executed', 111))
def test_on_progress(self):
progress_callback = mock.MagicMock(name='progress_callback')
remote_task = self.remote_task(task=self.task,
progress_callback=progress_callback)
remote_task.on_progress('event_data', 0.0)
remote_task.on_progress('event_data', 1.0)
expected_calls = [
mock.call(self.task, 'event_data', 0.0),
mock.call(self.task, 'event_data', 1.0)
]
self.assertEqual(progress_callback.mock_calls, expected_calls)

View File

@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from kombu import exceptions as exc
from taskflow.engines.worker_based import endpoint as ep
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import server
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
class TestServer(test.MockTestCase):
def setUp(self):
super(TestServer, self).setUp()
self.server_uuid = 'server-uuid'
self.server_exchange = 'server-exchange'
self.broker_url = 'test-url'
self.task_uuid = 'task-uuid'
self.task_args = {'x': 1}
self.task_action = 'execute'
self.task_name = 'taskflow.tests.utils.TaskOneArgOneReturn'
self.task_version = (1, 0)
self.reply_to = 'reply-to'
self.endpoints = [ep.Endpoint(task_cls=utils.TaskOneArgOneReturn),
ep.Endpoint(task_cls=utils.TaskWithFailure),
ep.Endpoint(task_cls=utils.ProgressingTask)]
self.resp_running = dict(state=pr.RUNNING)
# patch classes
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
server.proxy, 'Proxy')
# other mocking
self.proxy_inst_mock.is_running = True
self.executor_mock = mock.MagicMock(name='executor')
self.message_mock = mock.MagicMock(name='message')
self.message_mock.properties = {'correlation_id': self.task_uuid,
'reply_to': self.reply_to}
self.master_mock.attach_mock(self.executor_mock, 'executor')
self.master_mock.attach_mock(self.message_mock, 'message')
def server(self, reset_master_mock=False, **kwargs):
server_kwargs = dict(uuid=self.server_uuid,
exchange=self.server_exchange,
executor=self.executor_mock,
endpoints=self.endpoints,
url=self.broker_url)
server_kwargs.update(kwargs)
s = server.Server(**server_kwargs)
if reset_master_mock:
self._reset_master_mock()
return s
def request(self, **kwargs):
request = dict(task=self.task_name,
task_name=self.task_name,
action=self.task_action,
task_version=self.task_version,
arguments=self.task_args)
request.update(kwargs)
return request
@staticmethod
def resp_progress(progress):
return dict(state=pr.PROGRESS, progress=progress, event_data={})
@staticmethod
def resp_success(result):
return dict(state=pr.SUCCESS, result=result)
@staticmethod
def resp_failure(result, **kwargs):
response = dict(state=pr.FAILURE, result=result)
response.update(kwargs)
return response
def test_creation(self):
s = self.server()
# check calls
master_mock_calls = [
mock.call.Proxy(self.server_uuid, self.server_exchange,
s._on_message, url=self.broker_url)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
self.assertEqual(len(s._endpoints), 3)
def test_creation_with_endpoints(self):
s = self.server(endpoints=self.endpoints)
# check calls
master_mock_calls = [
mock.call.Proxy(self.server_uuid, self.server_exchange,
s._on_message, url=self.broker_url)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
self.assertEqual(len(s._endpoints), len(self.endpoints))
def test_on_message_proxy_running_ack_success(self):
request = self.request()
s = self.server(reset_master_mock=True)
s._on_message(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.message.ack(),
mock.call.executor.submit(s._process_request, request,
self.message_mock)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_on_message_proxy_running_ack_failure(self):
self.message_mock.ack.side_effect = exc.MessageStateError('Woot!')
s = self.server(reset_master_mock=True)
s._on_message({}, self.message_mock)
# check calls
master_mock_calls = [
mock.call.message.ack()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_on_message_proxy_not_running_reject_success(self):
self.proxy_inst_mock.is_running = False
s = self.server(reset_master_mock=True)
s._on_message({}, self.message_mock)
# check calls
master_mock_calls = [
mock.call.message.reject(requeue=True)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_on_message_proxy_not_running_reject_failure(self):
self.message_mock.reject.side_effect = exc.MessageStateError('Woot!')
self.proxy_inst_mock.is_running = False
s = self.server(reset_master_mock=True)
s._on_message({}, self.message_mock)
# check calls
master_mock_calls = [
mock.call.message.reject(requeue=True)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_parse_request(self):
request = self.request()
task, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task, action, task_args),
(self.task_name, self.task_action,
dict(task_name=self.task_name,
arguments=self.task_args)))
def test_parse_request_with_success_result(self):
request = self.request(action='revert', result=('success', 1))
task, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
arguments=self.task_args,
result=1)))
def test_parse_request_with_failure_result(self):
failure = misc.Failure.from_exception(Exception('test'))
failure_dict = pu.failure_to_dict(failure)
request = self.request(action='revert',
result=('failure', failure_dict))
task, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
arguments=self.task_args,
result=utils.FailureMatcher(failure))))
def test_parse_request_with_failures(self):
failures = [misc.Failure.from_exception(Exception('test1')),
misc.Failure.from_exception(Exception('test2'))]
failures_dict = dict((str(i), pu.failure_to_dict(f))
for i, f in enumerate(failures))
request = self.request(action='revert', failures=failures_dict)
task, action, task_args = server.Server._parse_request(**request)
self.assertEqual(
(task, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
arguments=self.task_args,
failures=dict((str(i), utils.FailureMatcher(f))
for i, f in enumerate(failures)))))
@mock.patch("taskflow.engines.worker_based.server.LOG.error")
def test_reply_publish_failure(self, mocked_error):
self.proxy_inst_mock.publish.side_effect = RuntimeError('Woot!')
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._reply(self.reply_to, self.task_uuid)
self.assertEqual(self.master_mock.mock_calls, [
mock.call.proxy.publish({'state': 'FAILURE'}, self.task_uuid,
self.reply_to)
])
self.assertEqual(mocked_error.mock_calls, [
mock.call("Failed to send reply: Woot!")
])
def test_on_update_progress(self):
request = self.request(task='taskflow.tests.utils.ProgressingTask',
arguments={})
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.task_uuid,
self.reply_to),
mock.call.proxy.publish(self.resp_progress(0.0), self.task_uuid,
self.reply_to),
mock.call.proxy.publish(self.resp_progress(1.0), self.task_uuid,
self.reply_to),
mock.call.proxy.publish(self.resp_success(5), self.task_uuid,
self.reply_to)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_process_request(self):
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._process_request(self.request(), self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.task_uuid,
self.reply_to),
mock.call.proxy.publish(self.resp_success(1), self.task_uuid,
self.reply_to)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
@mock.patch("taskflow.engines.worker_based.server.LOG.error")
def test_process_request_parse_message_failure(self, mocked_error):
self.message_mock.properties = {}
request = self.request()
s = self.server(reset_master_mock=True)
s._process_request(request, self.message_mock)
self.assertEqual(self.master_mock.mock_calls, [])
self.assertTrue(mocked_error.called)
@mock.patch('taskflow.engines.worker_based.server.pu')
def test_process_request_parse_failure(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
pu_mock.failure_from_dict.side_effect = ValueError('Woot!')
request = self.request(result=('failure', 1))
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_failure(failure_dict),
self.task_uuid, self.reply_to)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
@mock.patch('taskflow.engines.worker_based.server.pu')
def test_process_request_endpoint_not_found(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
request = self.request(task='<unknown>')
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_failure(failure_dict),
self.task_uuid, self.reply_to)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
@mock.patch('taskflow.engines.worker_based.server.pu')
def test_process_request_execution_failure(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
request = self.request(action='<unknown>')
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.task_uuid,
self.reply_to),
mock.call.proxy.publish(self.resp_failure(failure_dict),
self.task_uuid, self.reply_to)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
@mock.patch('taskflow.engines.worker_based.server.pu')
def test_process_request_task_failure(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
request = self.request(task='taskflow.tests.utils.TaskWithFailure',
arguments={})
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.task_uuid,
self.reply_to),
mock.call.proxy.publish(self.resp_failure(failure_dict),
self.task_uuid, self.reply_to)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_start(self):
self.server(reset_master_mock=True).start()
# check calls
master_mock_calls = [
mock.call.proxy.start()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_wait(self):
server = self.server(reset_master_mock=True)
server.start()
server.wait()
# check calls
master_mock_calls = [
mock.call.proxy.start(),
mock.call.proxy.wait()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_stop(self):
self.server(reset_master_mock=True).stop()
# check calls
master_mock_calls = [
mock.call.proxy.stop(),
mock.call.executor.shutdown()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)

View File

@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from taskflow.engines.worker_based import endpoint
from taskflow.engines.worker_based import worker
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import reflection
class TestWorker(test.MockTestCase):
def setUp(self):
super(TestWorker, self).setUp()
self.task_cls = utils.DummyTask
self.task_name = reflection.get_class_name(self.task_cls)
self.broker_url = 'test-url'
self.exchange = 'test-exchange'
self.topic = 'test-topic'
self.threads_count = 5
self.endpoint_count = 18
# patch classes
self.executor_mock, self.executor_inst_mock = self._patch_class(
worker.futures, 'ThreadPoolExecutor', attach_as='executor')
self.server_mock, self.server_inst_mock = self._patch_class(
worker.server, 'Server')
# other mocking
self.threads_count_mock = self._patch(
'taskflow.engines.worker_based.worker.tu.get_optimal_thread_count')
self.threads_count_mock.return_value = self.threads_count
def worker(self, reset_master_mock=False, **kwargs):
worker_kwargs = dict(exchange=self.exchange,
topic=self.topic,
tasks=[],
url=self.broker_url)
worker_kwargs.update(kwargs)
w = worker.Worker(**worker_kwargs)
if reset_master_mock:
self._reset_master_mock()
return w
def test_creation(self):
self.worker()
master_mock_calls = [
mock.call.executor_class(self.threads_count),
mock.call.Server(self.topic, self.exchange,
self.executor_inst_mock, [], url=self.broker_url)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_creation_with_custom_threads_count(self):
self.worker(threads_count=10)
master_mock_calls = [
mock.call.executor_class(10),
mock.call.Server(self.topic, self.exchange,
self.executor_inst_mock, [], url=self.broker_url)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_creation_with_custom_executor(self):
executor_mock = mock.MagicMock(name='executor')
self.worker(executor=executor_mock)
master_mock_calls = [
mock.call.Server(self.topic, self.exchange, executor_mock, [],
url=self.broker_url)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_run_with_no_tasks(self):
self.worker(reset_master_mock=True).run()
master_mock_calls = [
mock.call.server.start()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_run_with_tasks(self):
self.worker(reset_master_mock=True,
tasks=['taskflow.tests.utils:DummyTask']).run()
master_mock_calls = [
mock.call.server.start()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_wait(self):
w = self.worker(reset_master_mock=True)
w.run()
w.wait()
master_mock_calls = [
mock.call.server.start(),
mock.call.server.wait()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_stop(self):
self.worker(reset_master_mock=True).stop()
master_mock_calls = [
mock.call.server.stop()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_derive_endpoints_from_string_tasks(self):
endpoints = worker.Worker._derive_endpoints(
['taskflow.tests.utils:DummyTask'])
self.assertEqual(len(endpoints), 1)
self.assertIsInstance(endpoints[0], endpoint.Endpoint)
self.assertEqual(endpoints[0].name, self.task_name)
def test_derive_endpoints_from_string_modules(self):
endpoints = worker.Worker._derive_endpoints(['taskflow.tests.utils'])
self.assertEqual(len(endpoints), self.endpoint_count)
def test_derive_endpoints_from_string_non_existent_module(self):
tasks = ['non.existent.module']
self.assertRaises(ImportError, worker.Worker._derive_endpoints, tasks)
def test_derive_endpoints_from_string_non_existent_task(self):
tasks = ['non.existent.module:Task']
self.assertRaises(ImportError, worker.Worker._derive_endpoints, tasks)
def test_derive_endpoints_from_string_non_task_class(self):
tasks = ['taskflow.tests.utils:FakeTask']
self.assertRaises(TypeError, worker.Worker._derive_endpoints, tasks)
def test_derive_endpoints_from_tasks(self):
endpoints = worker.Worker._derive_endpoints([self.task_cls])
self.assertEqual(len(endpoints), 1)
self.assertIsInstance(endpoints[0], endpoint.Endpoint)
self.assertEqual(endpoints[0].name, self.task_name)
def test_derive_endpoints_from_non_task_class(self):
self.assertRaises(TypeError, worker.Worker._derive_endpoints,
[utils.FakeTask])
def test_derive_endpoints_from_modules(self):
endpoints = worker.Worker._derive_endpoints([utils])
self.assertEqual(len(endpoints), self.endpoint_count)
def test_derive_endpoints_unexpected_task_type(self):
self.assertRaises(TypeError, worker.Worker._derive_endpoints, [111])

View File

@ -49,10 +49,17 @@ def make_reverting_task(token, blowup=False):
class DummyTask(task.Task):
def execute(self, context, *args, **kwargs):
pass
class FakeTask(object):
def execute(self, **kwargs):
pass
if six.PY3:
RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception',
'BaseException', 'object']
@ -101,7 +108,22 @@ class FailingTask(SaveOrderTask):
raise RuntimeError('Woot!')
class TaskWithFailure(task.Task):
def execute(self, **kwargs):
raise RuntimeError('Woot!')
class ProgressingTask(task.Task):
def execute(self, *args, **kwargs):
self.update_progress(0.0)
self.update_progress(1.0)
return 5
class NastyTask(task.Task):
def execute(self, **kwargs):
pass
@ -220,3 +242,16 @@ class EngineTestBase(object):
def _make_engine(self, flow, flow_detail=None):
raise NotImplementedError()
class FailureMatcher(object):
"""Needed for failure objects comparison."""
def __init__(self, failure):
self._failure = failure
def __repr__(self):
return str(self._failure)
def __eq__(self, other):
return self._failure.matches(other)

View File

@ -122,6 +122,11 @@ def is_bound_method(method):
return bool(get_method_self(method))
def is_subclass(obj, cls):
"""Returns if the object is class and it is subclass of a given class."""
return inspect.isclass(obj) and issubclass(obj, cls)
def _get_arg_spec(function):
if isinstance(function, type):
bound = True

View File

@ -43,6 +43,7 @@ deps = -r{toxinidir}/requirements.txt
alembic>=0.4.1
psycopg2
kazoo>=1.3.1
kombu>=2.4.8
commands = python setup.py testr --slowest --testr-args='{posargs}'
[tox:jenkins]