Merge "Some WBE protocol/executor cleanups"

This commit is contained in:
Jenkins
2016-02-15 05:22:20 +00:00
committed by Gerrit Code Review
8 changed files with 118 additions and 139 deletions

View File

@@ -421,7 +421,18 @@ Implementations
Components
----------
.. warning::
External usage of internal engine functions, components and modules should
be kept to a **minimum** as they may be altered, refactored or moved to
other locations **without** notice (and without the typical deprecation
cycle).
.. automodule:: taskflow.engines.worker_based.dispatcher
.. automodule:: taskflow.engines.worker_based.endpoint
.. automodule:: taskflow.engines.worker_based.executor
.. automodule:: taskflow.engines.worker_based.proxy
.. automodule:: taskflow.engines.worker_based.worker
.. automodule:: taskflow.engines.worker_based.types
.. _kombu: http://kombu.readthedocs.org/

View File

@@ -92,16 +92,19 @@ class WorkerTaskExecutor(executor.TaskExecutor):
if response.state == pr.RUNNING:
request.transition_and_log_error(pr.RUNNING, logger=LOG)
elif response.state == pr.EVENT:
# Proxy the event + details to the task/request notifier...
# Proxy the event + details to the task notifier so
# that it shows up in the local process (and activates
# any local callbacks...); thus making it look like
# the task is running locally (in some regards).
event_type = response.data['event_type']
details = response.data['details']
request.notifier.notify(event_type, details)
request.task.notifier.notify(event_type, details)
elif response.state in (pr.FAILURE, pr.SUCCESS):
if request.transition_and_log_error(response.state,
logger=LOG):
with self._ongoing_requests_lock:
del self._ongoing_requests[request.uuid]
request.set_result(**response.data)
request.set_result(result=response.data['result'])
else:
LOG.warning("Unexpected response status '%s'",
response.state)
@@ -184,18 +187,19 @@ class WorkerTaskExecutor(executor.TaskExecutor):
self._clean()
def _submit_task(self, task, task_uuid, action, arguments,
progress_callback=None, **kwargs):
progress_callback=None, result=pr.NO_RESULT,
failures=None):
"""Submit task request to a worker."""
request = pr.Request(task, task_uuid, action, arguments,
self._transition_timeout, **kwargs)
timeout=self._transition_timeout,
result=result, failures=failures)
# Register the callback, so that we can proxy the progress correctly.
if (progress_callback is not None and
request.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
request.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
cleaner = functools.partial(request.notifier.deregister,
EVENT_UPDATE_PROGRESS,
progress_callback)
request.result.add_done_callback(lambda fut: cleaner())
task.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
task.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
request.future.add_done_callback(
lambda _fut: task.notifier.deregister(EVENT_UPDATE_PROGRESS,
progress_callback))
# Get task's worker and publish request if worker was found.
worker = self._finder.get_worker_for_task(task)
if worker is not None:
@@ -208,7 +212,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
" worker/s available to process it", request)
with self._ongoing_requests_lock:
self._ongoing_requests[request.uuid] = request
return request.result
return request.future
def _publish_request(self, request, worker):
"""Publish request to a given topic."""
@@ -238,8 +242,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None):
return self._submit_task(task, task_uuid, pr.REVERT, arguments,
progress_callback=progress_callback,
result=result, failures=failures)
result=result, failures=failures,
progress_callback=progress_callback)
def wait_for_workers(self, workers=1, timeout=None):
"""Waits for geq workers to notify they are ready to do work.

View File

@@ -98,6 +98,9 @@ NOTIFY = 'NOTIFY'
REQUEST = 'REQUEST'
RESPONSE = 'RESPONSE'
# Object that denotes nothing (none can actually be valid).
NO_RESULT = object()
LOG = logging.getLogger(__name__)
@@ -252,44 +255,26 @@ class Request(Message):
'required': ['task_cls', 'task_name', 'task_version', 'action'],
}
def __init__(self, task, uuid, action, arguments, timeout, **kwargs):
self._task = task
self._uuid = uuid
def __init__(self, task, uuid, action,
arguments, timeout=REQUEST_TIMEOUT, result=NO_RESULT,
failures=None):
self._action = action
self._event = ACTION_TO_EVENT[action]
self._arguments = arguments
self._kwargs = kwargs
self._result = result
self._failures = failures
self._watch = timeutils.StopWatch(duration=timeout).start()
self._state = WAITING
self._lock = threading.Lock()
self._created_on = timeutils.now()
self._result = futurist.Future()
self._result.atom = task
self._notifier = task.notifier
self.state = WAITING
self.task = task
self.uuid = uuid
self.created_on = timeutils.now()
self.future = futurist.Future()
self.future.atom = task
@property
def result(self):
return self._result
@property
def notifier(self):
return self._notifier
@property
def uuid(self):
return self._uuid
@property
def task(self):
return self._task
@property
def state(self):
return self._state
@property
def created_on(self):
return self._created_on
def set_result(self, result):
"""Sets the responses futures result."""
self.future.set_result((self._event, result))
@property
def expired(self):
@@ -302,7 +287,7 @@ class Request(Message):
state for more then the given timeout (it is not considered to be
expired in any other state).
"""
if self._state in WAITING_STATES:
if self.state in WAITING_STATES:
return self._watch.expired()
return False
@@ -313,30 +298,25 @@ class Request(Message):
convert all `failure.Failure` objects into dictionaries (which will
then be reconstituted by the receiver).
"""
request = {
'task_cls': reflection.get_class_name(self._task),
'task_name': self._task.name,
'task_version': self._task.version,
'task_cls': reflection.get_class_name(self.task),
'task_name': self.task.name,
'task_version': self.task.version,
'action': self._action,
'arguments': self._arguments,
}
if 'result' in self._kwargs:
result = self._kwargs['result']
if self._result is not NO_RESULT:
result = self._result
if isinstance(result, ft.Failure):
request['result'] = ('failure', failure_to_dict(result))
else:
request['result'] = ('success', result)
if 'failures' in self._kwargs:
failures = self._kwargs['failures']
if self._failures:
request['failures'] = {}
for task, failure in six.iteritems(failures):
request['failures'][task] = failure_to_dict(failure)
for atom_name, failure in six.iteritems(self._failures):
request['failures'][atom_name] = failure_to_dict(failure)
return request
def set_result(self, result):
self.result.set_result((self._event, result))
def transition_and_log_error(self, new_state, logger=None):
"""Transitions *and* logs an error if that transitioning raises.
@@ -366,7 +346,7 @@ class Request(Message):
valid (and will not be performed), it raises an InvalidState
exception.
"""
old_state = self._state
old_state = self.state
if old_state == new_state:
return False
pair = (old_state, new_state)
@@ -375,7 +355,7 @@ class Request(Message):
" not allowed" % pair)
if new_state in _STOP_TIMER_STATES:
self._watch.stop()
self._state = new_state
self.state = new_state
LOG.debug("Transitioned '%s' from %s state to %s state", self,
old_state, new_state)
return True
@@ -504,8 +484,8 @@ class Response(Message):
}
def __init__(self, state, **data):
self._state = state
self._data = data
self.state = state
self.data = data
@classmethod
def from_dict(cls, data):
@@ -515,16 +495,8 @@ class Response(Message):
data['result'] = ft.Failure.from_dict(data['result'])
return cls(state, **data)
@property
def state(self):
return self._state
@property
def data(self):
return self._data
def to_dict(self):
return dict(state=self._state, data=self._data)
return dict(state=self.state, data=self.data)
@classmethod
def validate(cls, data):

View File

@@ -14,7 +14,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
import random
import threading
@@ -77,16 +76,24 @@ class TopicWorker(object):
return r
@six.add_metaclass(abc.ABCMeta)
class WorkerFinder(object):
"""Base class for worker finders..."""
class ProxyWorkerFinder(object):
"""Requests and receives responses about workers topic+task details."""
def __init__(self):
def __init__(self, uuid, proxy, topics,
beat_periodicity=pr.NOTIFY_PERIOD):
self._cond = threading.Condition()
self._proxy = proxy
self._topics = topics
self._workers = {}
self._uuid = uuid
self._seen_workers = 0
self._messages_processed = 0
self._messages_published = 0
self._watch = timeutils.StopWatch(duration=beat_periodicity)
@abc.abstractmethod
def _total_workers(self):
"""Returns how many workers are known."""
def total_workers(self):
"""Number of workers currently known."""
return len(self._workers)
def wait_for_workers(self, workers=1, timeout=None):
"""Waits for geq workers to notify they are ready to do work.
@@ -102,9 +109,9 @@ class WorkerFinder(object):
watch = timeutils.StopWatch(duration=timeout)
watch.start()
with self._cond:
while self._total_workers() < workers:
while len(self._workers) < workers:
if watch.expired():
return max(0, workers - self._total_workers())
return max(0, workers - len(self._workers))
self._cond.wait(watch.leftover(return_none=True))
return 0
@@ -124,28 +131,9 @@ class WorkerFinder(object):
else:
return random.choice(available_workers)
@abc.abstractmethod
def get_worker_for_task(self, task):
"""Gets a worker that can perform a given task."""
class ProxyWorkerFinder(WorkerFinder):
"""Requests and receives responses about workers topic+task details."""
def __init__(self, uuid, proxy, topics,
beat_periodicity=pr.NOTIFY_PERIOD):
super(ProxyWorkerFinder, self).__init__()
self._proxy = proxy
self._topics = topics
self._workers = {}
self._uuid = uuid
self._seen_workers = 0
self._messages_processed = 0
self._messages_published = 0
self._watch = timeutils.StopWatch(duration=beat_periodicity)
@property
def messages_processed(self):
"""How many notify response messages have been processed."""
return self._messages_processed
def _next_worker(self, topic, tasks, temporary=False):
@@ -175,9 +163,6 @@ class ProxyWorkerFinder(WorkerFinder):
self._messages_published += 1
self._watch.restart()
def _total_workers(self):
return len(self._workers)
def _add(self, topic, tasks):
"""Adds/updates a worker for the topic for the given tasks."""
try:
@@ -207,7 +192,7 @@ class ProxyWorkerFinder(WorkerFinder):
response.tasks)
if new_or_updated:
LOG.debug("Updated worker '%s' (%s total workers are"
" currently known)", worker, self._total_workers())
" currently known)", worker, len(self._workers))
self._cond.notify_all()
self._messages_processed += 1
@@ -220,6 +205,7 @@ class ProxyWorkerFinder(WorkerFinder):
self._cond.notify_all()
def get_worker_for_task(self, task):
"""Gets a worker that can perform a given task."""
available_workers = []
with self._cond:
for worker in six.itervalues(self._workers):

View File

@@ -110,8 +110,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
ex._process_response(response.to_dict(), self.message_mock)
expected_calls = [
mock.call.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS,
{'progress': 1.0}),
mock.call.task.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS,
{'progress': 1.0}),
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
@@ -139,7 +139,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.transition_and_log_error(pr.SUCCESS, logger=mock.ANY),
mock.call.set_result(result=self.task_result, event='executed')
mock.call.set_result(result=self.task_result)
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
@@ -198,7 +198,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, self.timeout),
self.task_args, timeout=self.timeout,
result=mock.ANY, failures=mock.ANY),
mock.call.request.transition_and_log_error(pr.PENDING,
logger=mock.ANY),
mock.call.proxy.publish(self.request_inst_mock,
@@ -216,7 +217,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'revert',
self.task_args, self.timeout,
self.task_args, timeout=self.timeout,
failures=self.task_failures,
result=self.task_result),
mock.call.request.transition_and_log_error(pr.PENDING,
@@ -234,7 +235,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, self.timeout),
self.task_args, timeout=self.timeout,
result=mock.ANY, failures=mock.ANY),
]
self.assertEqual(expected_calls, self.master_mock.mock_calls)
@@ -246,7 +248,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, self.timeout),
self.task_args, timeout=self.timeout,
result=mock.ANY, failures=mock.ANY),
mock.call.request.transition_and_log_error(pr.PENDING,
logger=mock.ANY),
mock.call.proxy.publish(self.request_inst_mock,

View File

@@ -14,7 +14,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import futurist
from oslo_utils import uuidutils
from taskflow.engines.action_engine import executor
@@ -51,9 +50,10 @@ class TestProtocolValidation(test.TestCase):
pr.Notify.validate, msg, True)
def test_request(self):
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
pr.Request.validate(msg.to_dict())
request = pr.Request(utils.DummyTask("hi"),
uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
pr.Request.validate(request.to_dict())
def test_request_invalid(self):
msg = {
@@ -64,11 +64,12 @@ class TestProtocolValidation(test.TestCase):
self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg)
def test_request_invalid_action(self):
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
msg = msg.to_dict()
msg['action'] = 'NOTHING'
self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg)
request = pr.Request(utils.DummyTask("hi"),
uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
request = request.to_dict()
request['action'] = 'NOTHING'
self.assertRaises(excp.InvalidFormat, pr.Request.validate, request)
def test_response_progress(self):
msg = pr.Response(pr.EVENT, details={'progress': 0.5},
@@ -105,7 +106,6 @@ class TestProtocol(test.TestCase):
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=self.timeout)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs)
@@ -135,25 +135,28 @@ class TestProtocol(test.TestCase):
request = self.request()
self.assertEqual(self.task_uuid, request.uuid)
self.assertEqual(self.task, request.task)
self.assertIsInstance(request.result, futurist.Future)
self.assertFalse(request.result.done())
self.assertFalse(request.future.done())
def test_to_dict_default(self):
self.assertEqual(self.request_to_dict(), self.request().to_dict())
request = self.request()
self.assertEqual(self.request_to_dict(), request.to_dict())
def test_to_dict_with_result(self):
request = self.request(result=333)
self.assertEqual(self.request_to_dict(result=('success', 333)),
self.request(result=333).to_dict())
request.to_dict())
def test_to_dict_with_result_none(self):
request = self.request(result=None)
self.assertEqual(self.request_to_dict(result=('success', None)),
self.request(result=None).to_dict())
request.to_dict())
def test_to_dict_with_result_failure(self):
a_failure = failure.Failure.from_exception(RuntimeError('Woot!'))
expected = self.request_to_dict(result=('failure',
a_failure.to_dict()))
self.assertEqual(expected, self.request(result=a_failure).to_dict())
request = self.request(result=a_failure)
self.assertEqual(expected, request.to_dict())
def test_to_dict_with_failures(self):
a_failure = failure.Failure.from_exception(RuntimeError('Woot!'))
@@ -173,16 +176,16 @@ class TestProtocol(test.TestCase):
@mock.patch('oslo_utils.timeutils.now')
def test_pending_not_expired(self, now):
now.return_value = 0
req = self.request()
request = self.request()
now.return_value = self.timeout - 1
self.assertFalse(req.expired)
self.assertFalse(request.expired)
@mock.patch('oslo_utils.timeutils.now')
def test_pending_expired(self, now):
now.return_value = 0
req = self.request()
request = self.request()
now.return_value = self.timeout + 1
self.assertTrue(req.expired)
self.assertTrue(request.expired)
@mock.patch('oslo_utils.timeutils.now')
def test_running_not_expired(self, now):
@@ -196,5 +199,5 @@ class TestProtocol(test.TestCase):
def test_set_result(self):
request = self.request()
request.set_result(111)
result = request.result.result()
result = request.future.result()
self.assertEqual((executor.EXECUTED, 111), result)

View File

@@ -75,10 +75,10 @@ class TestServer(test.MockTestCase):
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=60)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs).to_dict()
request = pr.Request(**request_kwargs)
return request.to_dict()
def test_creation(self):
s = self.server()

View File

@@ -38,7 +38,7 @@ class TestProxyFinder(test.TestCase):
w, emit = finder._add('dummy-topic', [utils.DummyTask])
self.assertIsNotNone(w)
self.assertTrue(emit)
self.assertEqual(1, finder._total_workers())
self.assertEqual(1, finder.total_workers())
w2 = finder.get_worker_for_task(utils.DummyTask)
self.assertEqual(w.identity, w2.identity)
@@ -60,7 +60,7 @@ class TestProxyFinder(test.TestCase):
added.append(finder._add('dummy-topic', [utils.DummyTask]))
added.append(finder._add('dummy-topic-2', [utils.DummyTask]))
added.append(finder._add('dummy-topic-3', [utils.NastyTask]))
self.assertEqual(3, finder._total_workers())
self.assertEqual(3, finder.total_workers())
w = finder.get_worker_for_task(utils.NastyTask)
self.assertEqual(added[-1][0].identity, w.identity)
w = finder.get_worker_for_task(utils.DummyTask)