diff --git a/taskflow/engines/worker_based/proxy.py b/taskflow/engines/worker_based/proxy.py index aaa75c86c..d2991ca38 100644 --- a/taskflow/engines/worker_based/proxy.py +++ b/taskflow/engines/worker_based/proxy.py @@ -22,7 +22,7 @@ import kombu import six from taskflow.engines.worker_based import dispatcher - +from taskflow.utils import misc LOG = logging.getLogger(__name__) @@ -40,29 +40,45 @@ class Proxy(object): self._exchange_name = exchange_name self._on_wait = on_wait self._running = threading.Event() - self._url = kwargs.get('url') - self._transport = kwargs.get('transport') - self._transport_opts = kwargs.get('transport_options') self._dispatcher = dispatcher.TypeDispatcher(type_handlers) self._dispatcher.add_requeue_filter( # NOTE(skudriashev): Process all incoming messages only if proxy is # running, otherwise requeue them. lambda data, message: not self.is_running) + + url = kwargs.get('url') + transport = kwargs.get('transport') + transport_opts = kwargs.get('transport_options') + self._drain_events_timeout = DRAIN_EVENTS_PERIOD - if self._transport == 'memory' and self._transport_opts: - polling_interval = self._transport_opts.get('polling_interval') - if polling_interval: + if transport == 'memory' and transport_opts: + polling_interval = transport_opts.get('polling_interval') + if polling_interval is not None: self._drain_events_timeout = polling_interval # create connection - self._conn = kombu.Connection(self._url, transport=self._transport, - transport_options=self._transport_opts) + self._conn = kombu.Connection(url, transport=transport, + transport_options=transport_opts) # create exchange self._exchange = kombu.Exchange(name=self._exchange_name, durable=False, auto_delete=True) + @property + def connection_details(self): + # The kombu drivers seem to use 'N/A' when they don't have a version... + driver_version = self._conn.transport.driver_version() + if driver_version and driver_version.lower() == 'n/a': + driver_version = None + return misc.AttrDict( + uri=self._conn.as_uri(include_password=False), + transport=misc.AttrDict( + options=dict(self._conn.transport_options), + driver_type=self._conn.transport.driver_type, + driver_name=self._conn.transport.driver_name, + driver_version=driver_version)) + @property def is_running(self): """Return whether the proxy is running.""" diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index 8b175783f..73625865b 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -61,6 +61,10 @@ class Server(object): self._endpoints = dict([(endpoint.name, endpoint) for endpoint in endpoints]) + @property + def connection_details(self): + return self._proxy.connection_details + @staticmethod def _parse_request(task_cls, task_name, action, arguments, result=None, failures=None, **kwargs): diff --git a/taskflow/engines/worker_based/worker.py b/taskflow/engines/worker_based/worker.py index 490117889..49816eabd 100644 --- a/taskflow/engines/worker_based/worker.py +++ b/taskflow/engines/worker_based/worker.py @@ -15,6 +15,11 @@ # under the License. import logging +import os +import platform +import socket +import string +import sys from concurrent import futures @@ -23,6 +28,37 @@ from taskflow.engines.worker_based import server from taskflow import task as t_task from taskflow.utils import reflection from taskflow.utils import threading_utils as tu +from taskflow import version + +BANNER_TEMPLATE = string.Template(""" +TaskFlow v${version} WBE worker. +Connection details: + Driver = $transport_driver + Exchange = $exchange + Topic = $topic + Transport = $transport_type + Uri = $connection_uri +Powered by: + Executor = $executor_type + Thread count = $executor_thread_count +Supported endpoints:$endpoints +System details: + Hostname = $hostname + Pid = $pid + Platform = $platform + Python = $python + Thread id = $thread_id +""".strip()) +BANNER_TEMPLATE.defaults = { + # These values may not be possible to fetch/known, default to unknown... + 'pid': '???', + 'hostname': '???', + 'executor_thread_count': '???', + 'endpoints': ' %s' % ([]), + # These are static (avoid refetching...) + 'version': version.version_string(), + 'python': sys.version.split("\n", 1)[0].strip(), +} LOG = logging.getLogger(__name__) @@ -78,6 +114,7 @@ class Worker(object): self._executor = futures.ThreadPoolExecutor(self._threads_count) self._owns_executor = True self._endpoints = self._derive_endpoints(tasks) + self._exchange = exchange self._server = server.Server(topic, exchange, self._executor, self._endpoints, **kwargs) @@ -87,17 +124,48 @@ class Worker(object): derived_tasks = reflection.find_subclasses(tasks, t_task.BaseTask) return [endpoint.Endpoint(task) for task in derived_tasks] - def run(self): - """Run worker.""" - if self._threads_count != -1: - LOG.info("Starting the '%s' topic worker in %s threads.", - self._topic, self._threads_count) + def _generate_banner(self): + """Generates a banner that can be useful to display before running.""" + tpl_params = {} + connection_details = self._server.connection_details + transport = connection_details.transport + if transport.driver_version: + transport_driver = "%s v%s" % (transport.driver_name, + transport.driver_version) else: - LOG.info("Starting the '%s' topic worker using a %s.", self._topic, - self._executor) - LOG.info("Tasks list:") - for e in self._endpoints: - LOG.info("|-- %s", e) + transport_driver = transport.driver_name + tpl_params['transport_driver'] = transport_driver + tpl_params['exchange'] = self._exchange + tpl_params['topic'] = self._topic + tpl_params['transport_type'] = transport.driver_type + tpl_params['connection_uri'] = connection_details.uri + tpl_params['executor_type'] = reflection.get_class_name(self._executor) + if self._threads_count != -1: + tpl_params['executor_thread_count'] = self._threads_count + if self._endpoints: + pretty_endpoints = [] + for ep in self._endpoints: + pretty_endpoints.append(" - %s" % ep) + # This ensures there is a newline before the list... + tpl_params['endpoints'] = "\n" + "\n".join(pretty_endpoints) + try: + tpl_params['hostname'] = socket.getfqdn() + except socket.error: + pass + try: + tpl_params['pid'] = os.getpid() + except OSError: + pass + tpl_params['platform'] = platform.platform() + tpl_params['thread_id'] = tu.get_ident() + return BANNER_TEMPLATE.substitute(BANNER_TEMPLATE.defaults, + **tpl_params) + + def run(self, display_banner=True): + """Runs the worker.""" + if display_banner: + for line in self._generate_banner().splitlines(): + LOG.info(line) self._server.start() def wait(self):