Some WBE protocol/executor cleanups

Remove some of the usage of @property as none of
these objects are publicly exposed (or have docstrings
on them) to save some space/lines of code that aren't
really adding any benefit.

Use less **kwargs when we know exactly what the keyword
arguments will or will not be. Being explicit makes it
easier to understand these functions (vs not knowing what
the arguments can or can't be).

Removes base worker finder because right now we only
have one implementation (at some point we will have
two) but we can just wait to add a base class until
then.

Change-Id: I7107ff6b77a355b4c5d301948355fb6386605388
This commit is contained in:
Joshua Harlow
2016-02-05 14:30:09 -08:00
committed by Joshua Harlow
parent a70bd8a7e5
commit 1ab60b7e98
8 changed files with 118 additions and 139 deletions

View File

@@ -421,7 +421,18 @@ Implementations
Components
----------
.. warning::
External usage of internal engine functions, components and modules should
be kept to a **minimum** as they may be altered, refactored or moved to
other locations **without** notice (and without the typical deprecation
cycle).
.. automodule:: taskflow.engines.worker_based.dispatcher
.. automodule:: taskflow.engines.worker_based.endpoint
.. automodule:: taskflow.engines.worker_based.executor
.. automodule:: taskflow.engines.worker_based.proxy
.. automodule:: taskflow.engines.worker_based.worker
.. automodule:: taskflow.engines.worker_based.types
.. _kombu: http://kombu.readthedocs.org/

View File

@@ -92,16 +92,19 @@ class WorkerTaskExecutor(executor.TaskExecutor):
if response.state == pr.RUNNING:
request.transition_and_log_error(pr.RUNNING, logger=LOG)
elif response.state == pr.EVENT:
# Proxy the event + details to the task/request notifier...
# Proxy the event + details to the task notifier so
# that it shows up in the local process (and activates
# any local callbacks...); thus making it look like
# the task is running locally (in some regards).
event_type = response.data['event_type']
details = response.data['details']
request.notifier.notify(event_type, details)
request.task.notifier.notify(event_type, details)
elif response.state in (pr.FAILURE, pr.SUCCESS):
if request.transition_and_log_error(response.state,
logger=LOG):
with self._ongoing_requests_lock:
del self._ongoing_requests[request.uuid]
request.set_result(**response.data)
request.set_result(result=response.data['result'])
else:
LOG.warning("Unexpected response status '%s'",
response.state)
@@ -184,18 +187,19 @@ class WorkerTaskExecutor(executor.TaskExecutor):
self._clean()
def _submit_task(self, task, task_uuid, action, arguments,
progress_callback=None, **kwargs):
progress_callback=None, result=pr.NO_RESULT,
failures=None):
"""Submit task request to a worker."""
request = pr.Request(task, task_uuid, action, arguments,
self._transition_timeout, **kwargs)
timeout=self._transition_timeout,
result=result, failures=failures)
# Register the callback, so that we can proxy the progress correctly.
if (progress_callback is not None and
request.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
request.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
cleaner = functools.partial(request.notifier.deregister,
EVENT_UPDATE_PROGRESS,
progress_callback)
request.result.add_done_callback(lambda fut: cleaner())
task.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
task.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
request.future.add_done_callback(
lambda _fut: task.notifier.deregister(EVENT_UPDATE_PROGRESS,
progress_callback))
# Get task's worker and publish request if worker was found.
worker = self._finder.get_worker_for_task(task)
if worker is not None:
@@ -208,7 +212,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
" worker/s available to process it", request)
with self._ongoing_requests_lock:
self._ongoing_requests[request.uuid] = request
return request.result
return request.future
def _publish_request(self, request, worker):
"""Publish request to a given topic."""
@@ -238,8 +242,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
def revert_task(self, task, task_uuid, arguments, result, failures,
progress_callback=None):
return self._submit_task(task, task_uuid, pr.REVERT, arguments,
progress_callback=progress_callback,
result=result, failures=failures)
result=result, failures=failures,
progress_callback=progress_callback)
def wait_for_workers(self, workers=1, timeout=None):
"""Waits for geq workers to notify they are ready to do work.

View File

@@ -98,6 +98,9 @@ NOTIFY = 'NOTIFY'
REQUEST = 'REQUEST'
RESPONSE = 'RESPONSE'
# Object that denotes nothing (none can actually be valid).
NO_RESULT = object()
LOG = logging.getLogger(__name__)
@@ -252,44 +255,26 @@ class Request(Message):
'required': ['task_cls', 'task_name', 'task_version', 'action'],
}
def __init__(self, task, uuid, action, arguments, timeout, **kwargs):
self._task = task
self._uuid = uuid
def __init__(self, task, uuid, action,
arguments, timeout=REQUEST_TIMEOUT, result=NO_RESULT,
failures=None):
self._action = action
self._event = ACTION_TO_EVENT[action]
self._arguments = arguments
self._kwargs = kwargs
self._result = result
self._failures = failures
self._watch = timeutils.StopWatch(duration=timeout).start()
self._state = WAITING
self._lock = threading.Lock()
self._created_on = timeutils.now()
self._result = futurist.Future()
self._result.atom = task
self._notifier = task.notifier
self.state = WAITING
self.task = task
self.uuid = uuid
self.created_on = timeutils.now()
self.future = futurist.Future()
self.future.atom = task
@property
def result(self):
return self._result
@property
def notifier(self):
return self._notifier
@property
def uuid(self):
return self._uuid
@property
def task(self):
return self._task
@property
def state(self):
return self._state
@property
def created_on(self):
return self._created_on
def set_result(self, result):
"""Sets the responses futures result."""
self.future.set_result((self._event, result))
@property
def expired(self):
@@ -302,7 +287,7 @@ class Request(Message):
state for more then the given timeout (it is not considered to be
expired in any other state).
"""
if self._state in WAITING_STATES:
if self.state in WAITING_STATES:
return self._watch.expired()
return False
@@ -313,30 +298,25 @@ class Request(Message):
convert all `failure.Failure` objects into dictionaries (which will
then be reconstituted by the receiver).
"""
request = {
'task_cls': reflection.get_class_name(self._task),
'task_name': self._task.name,
'task_version': self._task.version,
'task_cls': reflection.get_class_name(self.task),
'task_name': self.task.name,
'task_version': self.task.version,
'action': self._action,
'arguments': self._arguments,
}
if 'result' in self._kwargs:
result = self._kwargs['result']
if self._result is not NO_RESULT:
result = self._result
if isinstance(result, ft.Failure):
request['result'] = ('failure', failure_to_dict(result))
else:
request['result'] = ('success', result)
if 'failures' in self._kwargs:
failures = self._kwargs['failures']
if self._failures:
request['failures'] = {}
for task, failure in six.iteritems(failures):
request['failures'][task] = failure_to_dict(failure)
for atom_name, failure in six.iteritems(self._failures):
request['failures'][atom_name] = failure_to_dict(failure)
return request
def set_result(self, result):
self.result.set_result((self._event, result))
def transition_and_log_error(self, new_state, logger=None):
"""Transitions *and* logs an error if that transitioning raises.
@@ -366,7 +346,7 @@ class Request(Message):
valid (and will not be performed), it raises an InvalidState
exception.
"""
old_state = self._state
old_state = self.state
if old_state == new_state:
return False
pair = (old_state, new_state)
@@ -375,7 +355,7 @@ class Request(Message):
" not allowed" % pair)
if new_state in _STOP_TIMER_STATES:
self._watch.stop()
self._state = new_state
self.state = new_state
LOG.debug("Transitioned '%s' from %s state to %s state", self,
old_state, new_state)
return True
@@ -504,8 +484,8 @@ class Response(Message):
}
def __init__(self, state, **data):
self._state = state
self._data = data
self.state = state
self.data = data
@classmethod
def from_dict(cls, data):
@@ -515,16 +495,8 @@ class Response(Message):
data['result'] = ft.Failure.from_dict(data['result'])
return cls(state, **data)
@property
def state(self):
return self._state
@property
def data(self):
return self._data
def to_dict(self):
return dict(state=self._state, data=self._data)
return dict(state=self.state, data=self.data)
@classmethod
def validate(cls, data):

View File

@@ -14,7 +14,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
import random
import threading
@@ -77,16 +76,24 @@ class TopicWorker(object):
return r
@six.add_metaclass(abc.ABCMeta)
class WorkerFinder(object):
"""Base class for worker finders..."""
class ProxyWorkerFinder(object):
"""Requests and receives responses about workers topic+task details."""
def __init__(self):
def __init__(self, uuid, proxy, topics,
beat_periodicity=pr.NOTIFY_PERIOD):
self._cond = threading.Condition()
self._proxy = proxy
self._topics = topics
self._workers = {}
self._uuid = uuid
self._seen_workers = 0
self._messages_processed = 0
self._messages_published = 0
self._watch = timeutils.StopWatch(duration=beat_periodicity)
@abc.abstractmethod
def _total_workers(self):
"""Returns how many workers are known."""
def total_workers(self):
"""Number of workers currently known."""
return len(self._workers)
def wait_for_workers(self, workers=1, timeout=None):
"""Waits for geq workers to notify they are ready to do work.
@@ -102,9 +109,9 @@ class WorkerFinder(object):
watch = timeutils.StopWatch(duration=timeout)
watch.start()
with self._cond:
while self._total_workers() < workers:
while len(self._workers) < workers:
if watch.expired():
return max(0, workers - self._total_workers())
return max(0, workers - len(self._workers))
self._cond.wait(watch.leftover(return_none=True))
return 0
@@ -124,28 +131,9 @@ class WorkerFinder(object):
else:
return random.choice(available_workers)
@abc.abstractmethod
def get_worker_for_task(self, task):
"""Gets a worker that can perform a given task."""
class ProxyWorkerFinder(WorkerFinder):
"""Requests and receives responses about workers topic+task details."""
def __init__(self, uuid, proxy, topics,
beat_periodicity=pr.NOTIFY_PERIOD):
super(ProxyWorkerFinder, self).__init__()
self._proxy = proxy
self._topics = topics
self._workers = {}
self._uuid = uuid
self._seen_workers = 0
self._messages_processed = 0
self._messages_published = 0
self._watch = timeutils.StopWatch(duration=beat_periodicity)
@property
def messages_processed(self):
"""How many notify response messages have been processed."""
return self._messages_processed
def _next_worker(self, topic, tasks, temporary=False):
@@ -175,9 +163,6 @@ class ProxyWorkerFinder(WorkerFinder):
self._messages_published += 1
self._watch.restart()
def _total_workers(self):
return len(self._workers)
def _add(self, topic, tasks):
"""Adds/updates a worker for the topic for the given tasks."""
try:
@@ -207,7 +192,7 @@ class ProxyWorkerFinder(WorkerFinder):
response.tasks)
if new_or_updated:
LOG.debug("Updated worker '%s' (%s total workers are"
" currently known)", worker, self._total_workers())
" currently known)", worker, len(self._workers))
self._cond.notify_all()
self._messages_processed += 1
@@ -220,6 +205,7 @@ class ProxyWorkerFinder(WorkerFinder):
self._cond.notify_all()
def get_worker_for_task(self, task):
"""Gets a worker that can perform a given task."""
available_workers = []
with self._cond:
for worker in six.itervalues(self._workers):

View File

@@ -110,8 +110,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
ex._process_response(response.to_dict(), self.message_mock)
expected_calls = [
mock.call.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS,
{'progress': 1.0}),
mock.call.task.notifier.notify(task_atom.EVENT_UPDATE_PROGRESS,
{'progress': 1.0}),
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
@@ -139,7 +139,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.transition_and_log_error(pr.SUCCESS, logger=mock.ANY),
mock.call.set_result(result=self.task_result, event='executed')
mock.call.set_result(result=self.task_result)
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
@@ -198,7 +198,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, self.timeout),
self.task_args, timeout=self.timeout,
result=mock.ANY, failures=mock.ANY),
mock.call.request.transition_and_log_error(pr.PENDING,
logger=mock.ANY),
mock.call.proxy.publish(self.request_inst_mock,
@@ -216,7 +217,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'revert',
self.task_args, self.timeout,
self.task_args, timeout=self.timeout,
failures=self.task_failures,
result=self.task_result),
mock.call.request.transition_and_log_error(pr.PENDING,
@@ -234,7 +235,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, self.timeout),
self.task_args, timeout=self.timeout,
result=mock.ANY, failures=mock.ANY),
]
self.assertEqual(expected_calls, self.master_mock.mock_calls)
@@ -246,7 +248,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, self.timeout),
self.task_args, timeout=self.timeout,
result=mock.ANY, failures=mock.ANY),
mock.call.request.transition_and_log_error(pr.PENDING,
logger=mock.ANY),
mock.call.proxy.publish(self.request_inst_mock,

View File

@@ -14,7 +14,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import futurist
from oslo_utils import uuidutils
from taskflow.engines.action_engine import executor
@@ -51,9 +50,10 @@ class TestProtocolValidation(test.TestCase):
pr.Notify.validate, msg, True)
def test_request(self):
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
pr.Request.validate(msg.to_dict())
request = pr.Request(utils.DummyTask("hi"),
uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
pr.Request.validate(request.to_dict())
def test_request_invalid(self):
msg = {
@@ -64,11 +64,12 @@ class TestProtocolValidation(test.TestCase):
self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg)
def test_request_invalid_action(self):
msg = pr.Request(utils.DummyTask("hi"), uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
msg = msg.to_dict()
msg['action'] = 'NOTHING'
self.assertRaises(excp.InvalidFormat, pr.Request.validate, msg)
request = pr.Request(utils.DummyTask("hi"),
uuidutils.generate_uuid(),
pr.EXECUTE, {}, 1.0)
request = request.to_dict()
request['action'] = 'NOTHING'
self.assertRaises(excp.InvalidFormat, pr.Request.validate, request)
def test_response_progress(self):
msg = pr.Response(pr.EVENT, details={'progress': 0.5},
@@ -105,7 +106,6 @@ class TestProtocol(test.TestCase):
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=self.timeout)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs)
@@ -135,25 +135,28 @@ class TestProtocol(test.TestCase):
request = self.request()
self.assertEqual(self.task_uuid, request.uuid)
self.assertEqual(self.task, request.task)
self.assertIsInstance(request.result, futurist.Future)
self.assertFalse(request.result.done())
self.assertFalse(request.future.done())
def test_to_dict_default(self):
self.assertEqual(self.request_to_dict(), self.request().to_dict())
request = self.request()
self.assertEqual(self.request_to_dict(), request.to_dict())
def test_to_dict_with_result(self):
request = self.request(result=333)
self.assertEqual(self.request_to_dict(result=('success', 333)),
self.request(result=333).to_dict())
request.to_dict())
def test_to_dict_with_result_none(self):
request = self.request(result=None)
self.assertEqual(self.request_to_dict(result=('success', None)),
self.request(result=None).to_dict())
request.to_dict())
def test_to_dict_with_result_failure(self):
a_failure = failure.Failure.from_exception(RuntimeError('Woot!'))
expected = self.request_to_dict(result=('failure',
a_failure.to_dict()))
self.assertEqual(expected, self.request(result=a_failure).to_dict())
request = self.request(result=a_failure)
self.assertEqual(expected, request.to_dict())
def test_to_dict_with_failures(self):
a_failure = failure.Failure.from_exception(RuntimeError('Woot!'))
@@ -173,16 +176,16 @@ class TestProtocol(test.TestCase):
@mock.patch('oslo_utils.timeutils.now')
def test_pending_not_expired(self, now):
now.return_value = 0
req = self.request()
request = self.request()
now.return_value = self.timeout - 1
self.assertFalse(req.expired)
self.assertFalse(request.expired)
@mock.patch('oslo_utils.timeutils.now')
def test_pending_expired(self, now):
now.return_value = 0
req = self.request()
request = self.request()
now.return_value = self.timeout + 1
self.assertTrue(req.expired)
self.assertTrue(request.expired)
@mock.patch('oslo_utils.timeutils.now')
def test_running_not_expired(self, now):
@@ -196,5 +199,5 @@ class TestProtocol(test.TestCase):
def test_set_result(self):
request = self.request()
request.set_result(111)
result = request.result.result()
result = request.future.result()
self.assertEqual((executor.EXECUTED, 111), result)

View File

@@ -75,10 +75,10 @@ class TestServer(test.MockTestCase):
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=60)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs).to_dict()
request = pr.Request(**request_kwargs)
return request.to_dict()
def test_creation(self):
s = self.server()

View File

@@ -38,7 +38,7 @@ class TestProxyFinder(test.TestCase):
w, emit = finder._add('dummy-topic', [utils.DummyTask])
self.assertIsNotNone(w)
self.assertTrue(emit)
self.assertEqual(1, finder._total_workers())
self.assertEqual(1, finder.total_workers())
w2 = finder.get_worker_for_task(utils.DummyTask)
self.assertEqual(w.identity, w2.identity)
@@ -60,7 +60,7 @@ class TestProxyFinder(test.TestCase):
added.append(finder._add('dummy-topic', [utils.DummyTask]))
added.append(finder._add('dummy-topic-2', [utils.DummyTask]))
added.append(finder._add('dummy-topic-3', [utils.NastyTask]))
self.assertEqual(3, finder._total_workers())
self.assertEqual(3, finder.total_workers())
w = finder.get_worker_for_task(utils.NastyTask)
self.assertEqual(added[-1][0].identity, w.identity)
w = finder.get_worker_for_task(utils.DummyTask)