# -*- coding: utf-8 -*-

#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import functools
import threading

from oslo_utils import timeutils
import six

from taskflow.engines.action_engine import executor
from taskflow.engines.worker_based import dispatcher
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
from taskflow.engines.worker_based import types as wt
from taskflow import exceptions as exc
from taskflow import logging
from taskflow.task import EVENT_UPDATE_PROGRESS  # noqa
from taskflow.utils import kombu_utils as ku
from taskflow.utils import misc
from taskflow.utils import threading_utils as tu

LOG = logging.getLogger(__name__)


class WorkerTaskExecutor(executor.TaskExecutor):
    """Executes tasks on remote workers."""

    def __init__(self, uuid, exchange, topics,
                 transition_timeout=pr.REQUEST_TIMEOUT,
                 url=None, transport=None, transport_options=None,
                 retry_options=None, worker_expiry=pr.EXPIRES_AFTER):
        self._uuid = uuid
        self._ongoing_requests = {}
        self._ongoing_requests_lock = threading.RLock()
        self._transition_timeout = transition_timeout
        self._proxy = proxy.Proxy(uuid, exchange,
                                  on_wait=self._on_wait, url=url,
                                  transport=transport,
                                  transport_options=transport_options,
                                  retry_options=retry_options)
        # NOTE(harlowja): This is the simplest finder impl. that doesn't
        # have external dependencies (outside of what this engine already
        # requires); it does, however, create periodic 'polling' traffic
        # to workers to 'learn' of the tasks they can perform (and requires
        # pre-existing knowledge of the topics those workers are on to gather
        # and update this information).
        self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics,
                                            worker_expiry=worker_expiry)
        self._proxy.dispatcher.type_handlers.update({
            pr.RESPONSE: dispatcher.Handler(self._process_response,
                                            validator=pr.Response.validate),
            pr.NOTIFY: dispatcher.Handler(
                self._finder.process_response,
                validator=functools.partial(pr.Notify.validate,
                                            response=True)),
        })
        # Thread that will run the message dispatching (and periodically
        # call the on_wait callback to do various things) loop...
        self._helper = None
        self._messages_processed = {
            'finder': self._finder.messages_processed,
        }
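
    # NOTE: descriptive summary of the wiring above (inferred from this
    # module, not from a separate spec): RESPONSE messages are routed to
    # _process_response below, NOTIFY replies feed the worker finder so it
    # can learn which topics/tasks workers support, and the proxy's on_wait
    # hook drives _on_wait, which periodically republishes finder queries,
    # expires stale workers and cleans up expired or still-waiting requests.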

    def _process_response(self, response, message):
        """Process response from remote side."""
        LOG.debug("Started processing response message '%s'",
                  ku.DelayedPretty(message))
        try:
            request_uuid = message.properties['correlation_id']
        except KeyError:
            LOG.warning("The 'correlation_id' message property is"
                        " missing in message '%s'",
                        ku.DelayedPretty(message))
        else:
            request = self._ongoing_requests.get(request_uuid)
            if request is not None:
                response = pr.Response.from_dict(response)
                LOG.debug("Extracted response '%s' and matched it to"
                          " request '%s'", response, request)
                if response.state == pr.RUNNING:
                    request.transition_and_log_error(pr.RUNNING, logger=LOG)
                elif response.state == pr.EVENT:
                    # Proxy the event + details to the task notifier so
                    # that it shows up in the local process (and activates
                    # any local callbacks...); thus making it look like
                    # the task is running locally (in some regards).
                    event_type = response.data['event_type']
                    details = response.data['details']
                    request.task.notifier.notify(event_type, details)
                elif response.state in (pr.FAILURE, pr.SUCCESS):
                    if request.transition_and_log_error(response.state,
                                                        logger=LOG):
                        with self._ongoing_requests_lock:
                            del self._ongoing_requests[request.uuid]
                        request.set_result(result=response.data['result'])
                else:
                    LOG.warning("Unexpected response status '%s'",
                                response.state)
            else:
                LOG.debug("Request with id='%s' not found", request_uuid)

    @staticmethod
    def _handle_expired_request(request):
        """Handle an expired request.

        When a request has expired it is removed from the ongoing requests
        dictionary and a ``RequestTimeout`` exception is set as the
        request result.
        """
        if request.transition_and_log_error(pr.FAILURE, logger=LOG):
            # Raise an exception (and then catch it) so we get a nice
            # traceback that the request will get instead of it getting
            # just an exception with no traceback...
            try:
                request_age = timeutils.now() - request.created_on
                raise exc.RequestTimeout(
                    "Request '%s' has expired after waiting for %0.2f"
                    " seconds for it to transition out of (%s) states"
                    % (request, request_age, ", ".join(pr.WAITING_STATES)))
            except exc.RequestTimeout:
                with misc.capture_failure() as failure:
                    LOG.debug(failure.exception_str)
                    request.set_result(failure)
            return True
        return False

    def _clean(self):
        if not self._ongoing_requests:
            return
        with self._ongoing_requests_lock:
            ongoing_requests_uuids = set(six.iterkeys(self._ongoing_requests))
        waiting_requests = {}
        expired_requests = {}
        for request_uuid in ongoing_requests_uuids:
            try:
                request = self._ongoing_requests[request_uuid]
            except KeyError:
                # Guess it got removed before we got to it...
                pass
            else:
                if request.expired:
                    expired_requests[request_uuid] = request
                elif request.current_state == pr.WAITING:
                    waiting_requests[request_uuid] = request
        if expired_requests:
            with self._ongoing_requests_lock:
                while expired_requests:
                    request_uuid, request = expired_requests.popitem()
                    if self._handle_expired_request(request):
                        del self._ongoing_requests[request_uuid]
        if waiting_requests:
            finder = self._finder
            new_messages_processed = finder.messages_processed
            last_messages_processed = self._messages_processed['finder']
            if new_messages_processed > last_messages_processed:
                # Some new message got to the finder, so we can see
                # if any new workers match (if no new messages have been
                # processed we might as well not do anything).
                while waiting_requests:
                    _request_uuid, request = waiting_requests.popitem()
                    worker = finder.get_worker_for_task(request.task)
                    if (worker is not None and
                            request.transition_and_log_error(pr.PENDING,
                                                             logger=LOG)):
                        self._publish_request(request, worker)
                self._messages_processed['finder'] = new_messages_processed

    def _on_wait(self):
        """This function is called cyclically between draining events."""
        # Publish any finding messages (used to locate workers).
        self._finder.maybe_publish()
        # If the finder hasn't heard from workers in a given amount
        # of time, then those workers are likely dead, so clean them out...
        self._finder.clean()
        # Process any expired requests or requests that have no current
        # worker located (publish messages for those if we now do have
        # a worker located).
        self._clean()

    def _submit_task(self, task, task_uuid, action, arguments,
                     progress_callback=None, result=pr.NO_RESULT,
                     failures=None):
        """Submit task request to a worker."""
        request = pr.Request(task, task_uuid, action, arguments,
                             timeout=self._transition_timeout,
                             result=result, failures=failures)
        # Register the callback, so that we can proxy the progress correctly.
        if (progress_callback is not None and
                task.notifier.can_be_registered(EVENT_UPDATE_PROGRESS)):
            task.notifier.register(EVENT_UPDATE_PROGRESS, progress_callback)
            request.future.add_done_callback(
                lambda _fut: task.notifier.deregister(EVENT_UPDATE_PROGRESS,
                                                      progress_callback))
        # Get task's worker and publish request if worker was found.
        worker = self._finder.get_worker_for_task(task)
        if worker is not None:
            if request.transition_and_log_error(pr.PENDING, logger=LOG):
                with self._ongoing_requests_lock:
                    self._ongoing_requests[request.uuid] = request
                self._publish_request(request, worker)
        else:
            LOG.debug("Delaying submission of '%s', no currently known"
                      " worker/s available to process it", request)
            with self._ongoing_requests_lock:
                self._ongoing_requests[request.uuid] = request
        return request.future

    def _publish_request(self, request, worker):
        """Publish request to a given topic."""
        LOG.debug("Submitting execution of '%s' to worker '%s' (expecting"
                  " response identified by reply_to=%s and"
                  " correlation_id=%s) - waited %0.3f seconds to"
                  " get published", request, worker, self._uuid,
                  request.uuid, timeutils.now() - request.created_on)
        try:
            self._proxy.publish(request, worker.topic,
                                reply_to=self._uuid,
                                correlation_id=request.uuid)
        except Exception:
            with misc.capture_failure() as failure:
                LOG.critical("Failed to submit '%s' (transitioning it to"
                             " %s)", request, pr.FAILURE, exc_info=True)
                if request.transition_and_log_error(pr.FAILURE, logger=LOG):
                    with self._ongoing_requests_lock:
                        del self._ongoing_requests[request.uuid]
                    request.set_result(failure)

    def execute_task(self, task, task_uuid, arguments,
                     progress_callback=None):
        return self._submit_task(task, task_uuid, pr.EXECUTE, arguments,
                                 progress_callback=progress_callback)

    def revert_task(self, task, task_uuid, arguments, result, failures,
                    progress_callback=None):
        return self._submit_task(task, task_uuid, pr.REVERT, arguments,
                                 result=result, failures=failures,
                                 progress_callback=progress_callback)

    def wait_for_workers(self, workers=1, timeout=None):
        """Waits for at least ``workers`` workers to notify they are ready.

        NOTE(harlowja): if a timeout is provided this function will wait
        until that timeout expires; if the number of ready workers does not
        reach the desired amount before the timeout expires then this will
        return how many workers are still needed, otherwise it will
        return zero.
        """
        return self._finder.wait_for_workers(workers=workers,
                                             timeout=timeout)

    def start(self):
        """Starts message processing thread."""
        if self._helper is not None:
            raise RuntimeError("Worker executor must be stopped before"
                               " it can be started")
        self._helper = tu.daemon_thread(self._proxy.start)
        self._helper.start()
        self._proxy.wait()

    def stop(self):
        """Stops message processing thread."""
        if self._helper is not None:
            self._proxy.stop()
            self._helper.join()
            self._helper = None
        with self._ongoing_requests_lock:
            while self._ongoing_requests:
                _request_uuid, request = self._ongoing_requests.popitem()
                self._handle_expired_request(request)
        self._finder.reset()
        self._messages_processed['finder'] = self._finder.messages_processed
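

# A minimal usage sketch (illustrative only, not part of the library API
# surface): it assumes a reachable broker at the example URL below and at
# least one worker subscribed to the 'test-topic' topic that knows how to run
# the submitted task; the task class, topic, exchange and URL are made-up
# example values.
if __name__ == '__main__':
    import uuid

    from taskflow import task

    class EchoTask(task.Task):
        def execute(self, message):
            return message

    ex = WorkerTaskExecutor(uuid.uuid4().hex, 'test-exchange',
                            ['test-topic'],
                            url='amqp://guest:guest@localhost:5672//')
    ex.start()
    try:
        # Block until at least one worker announces itself (give up after 10
        # seconds; a non-zero return means that many workers are still
        # missing).
        if ex.wait_for_workers(workers=1, timeout=10) == 0:
            fut = ex.execute_task(EchoTask(), uuid.uuid4().hex,
                                  {'message': 'hello'})
            # The future resolves once a SUCCESS/FAILURE response arrives
            # (or the request times out and is failed locally).
            print(fut.result())
    finally:
        ex.stop()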