diff --git a/oslo_messaging/_drivers/impl_zmq.py b/oslo_messaging/_drivers/impl_zmq.py index dbc287d04..5eb3fdc35 100644 --- a/oslo_messaging/_drivers/impl_zmq.py +++ b/oslo_messaging/_drivers/impl_zmq.py @@ -72,7 +72,7 @@ zmq_opts = [ help='Expiration timeout in seconds of a name service record ' 'about existing target ( < 0 means no timeout).'), - cfg.BoolOpt('direct_over_proxy', default=True, + cfg.BoolOpt('direct_over_proxy', default=False, help='Configures zmq-messaging to use proxy with ' 'non PUB/SUB patterns.'), @@ -117,11 +117,10 @@ class LazyDriverItem(object): if self.item is not None and os.getpid() == self.process_id: return self.item - self._lock.acquire() - if self.item is None or os.getpid() != self.process_id: - self.process_id = os.getpid() - self.item = self.item_class(*self.args, **self.kwargs) - self._lock.release() + with self._lock: + if self.item is None or os.getpid() != self.process_id: + self.process_id = os.getpid() + self.item = self.item_class(*self.args, **self.kwargs) return self.item def cleanup(self): diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_call_publisher.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_call_publisher.py index db3fc0280..c1115588e 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_call_publisher.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_call_publisher.py @@ -20,12 +20,10 @@ import futurist import oslo_messaging from oslo_messaging._drivers import common as rpc_common -from oslo_messaging._drivers.zmq_driver.client.publishers\ +from oslo_messaging._drivers.zmq_driver.client.publishers \ import zmq_publisher_base -from oslo_messaging._drivers.zmq_driver import zmq_address from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names -from oslo_messaging._drivers.zmq_driver import zmq_socket from oslo_messaging._i18n import _LW LOG = logging.getLogger(__name__) @@ -33,7 +31,7 @@ LOG = logging.getLogger(__name__) zmq = zmq_async.import_zmq() -class DealerCallPublisher(zmq_publisher_base.PublisherBase): +class DealerCallPublisher(object): """Thread-safe CALL publisher Used as faster and thread-safe publisher for CALL @@ -41,7 +39,8 @@ class DealerCallPublisher(zmq_publisher_base.PublisherBase): """ def __init__(self, conf, matchmaker): - super(DealerCallPublisher, self).__init__(conf) + super(DealerCallPublisher, self).__init__() + self.conf = conf self.matchmaker = matchmaker self.reply_waiter = ReplyWaiter(conf) self.sender = RequestSender(conf, matchmaker, self.reply_waiter) \ @@ -66,11 +65,17 @@ class DealerCallPublisher(zmq_publisher_base.PublisherBase): else: return reply[zmq_names.FIELD_REPLY] + def cleanup(self): + self.reply_waiter.cleanup() + self.sender.cleanup() -class RequestSender(zmq_publisher_base.PublisherMultisend): + +class RequestSender(zmq_publisher_base.PublisherBase): def __init__(self, conf, matchmaker, reply_waiter): - super(RequestSender, self).__init__(conf, matchmaker, zmq.DEALER) + sockets_manager = zmq_publisher_base.SocketsManager( + conf, matchmaker, zmq.ROUTER, zmq.DEALER) + super(RequestSender, self).__init__(sockets_manager) self.reply_waiter = reply_waiter self.queue, self.empty_except = zmq_async.get_queue() self.executor = zmq_async.get_executor(self.run_loop) @@ -89,19 +94,8 @@ class RequestSender(zmq_publisher_base.PublisherMultisend): LOG.debug("Sending message_id %(message)s to a target %(target)s", {"message": request.message_id, "target": request.target}) - def _check_hosts_connections(self, target, listener_type): - if str(target) in self.outbound_sockets: - socket = self.outbound_sockets[str(target)] - else: - hosts = self.matchmaker.get_hosts( - target, listener_type) - socket = zmq_socket.ZmqSocket(self.zmq_context, self.socket_type) - self.outbound_sockets[str(target)] = socket - - for host in hosts: - self._connect_to_host(socket, host, target) - - return socket + def _connect_socket(self, target): + return self.outbound_sockets.get_socket(target) def run_loop(self): try: @@ -109,12 +103,15 @@ class RequestSender(zmq_publisher_base.PublisherMultisend): except self.empty_except: return - socket = self._check_hosts_connections( - request.target, zmq_names.socket_type_str(zmq.ROUTER)) + socket = self._connect_socket(request.target) self._do_send_request(socket, request) self.reply_waiter.poll_socket(socket) + def cleanup(self): + self.executor.stop() + super(RequestSender, self).cleanup() + class RequestSenderLight(RequestSender): """This class used with proxy. @@ -132,14 +129,8 @@ class RequestSenderLight(RequestSender): self.socket = None - def _check_hosts_connections(self, target, listener_type): - if self.socket is None: - self.socket = zmq_socket.ZmqSocket(self.zmq_context, - self.socket_type) - self.outbound_sockets[str(target)] = self.socket - address = zmq_address.get_broker_address(self.conf) - self._connect_to_address(self.socket, address, target) - return self.socket + def _connect_socket(self, target): + return self.outbound_sockets.get_socket_to_broker(target) def _do_send_request(self, socket, request): LOG.debug("Sending %(type)s message_id %(message)s" @@ -196,3 +187,6 @@ class ReplyWaiter(object): call_future.set_result(reply) else: LOG.warning(_LW("Received timed out reply: %s"), reply_id) + + def cleanup(self): + self.poller.close() diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher.py index 07606a0eb..85dc459ae 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/dealer/zmq_dealer_publisher.py @@ -25,17 +25,18 @@ LOG = logging.getLogger(__name__) zmq = zmq_async.import_zmq() -class DealerPublisher(zmq_publisher_base.PublisherMultisend): +class DealerPublisher(zmq_publisher_base.PublisherBase): def __init__(self, conf, matchmaker): - super(DealerPublisher, self).__init__(conf, matchmaker, zmq.DEALER) + sockets_manager = zmq_publisher_base.SocketsManager( + conf, matchmaker, zmq.ROUTER, zmq.DEALER) + super(DealerPublisher, self).__init__(sockets_manager) def send_request(self, request): self._check_request_pattern(request) - dealer_socket = self._check_hosts_connections( - request.target, zmq_names.socket_type_str(zmq.ROUTER)) + dealer_socket = self.outbound_sockets.get_socket(request.target) if not dealer_socket.connections: # NOTE(ozamiatin): Here we can provide @@ -68,11 +69,13 @@ class DealerPublisher(zmq_publisher_base.PublisherMultisend): super(DealerPublisher, self).cleanup() -class DealerPublisherLight(zmq_publisher_base.PublisherBase): +class DealerPublisherLight(object): """Used when publishing to proxy. """ def __init__(self, conf, address): - super(DealerPublisherLight, self).__init__(conf) + super(DealerPublisherLight, self).__init__() + self.conf = conf + self.zmq_context = zmq.Context() self.socket = self.zmq_context.socket(zmq.DEALER) self.address = address self.socket.connect(address) diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_pub_publisher.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_pub_publisher.py index f228f2592..7dc4b239d 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_pub_publisher.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_pub_publisher.py @@ -27,7 +27,7 @@ LOG = logging.getLogger(__name__) zmq = zmq_async.import_zmq() -class PubPublisherProxy(zmq_publisher_base.PublisherBase): +class PubPublisherProxy(object): """PUB/SUB based request publisher The publisher intended to be used for Fanout and Notify @@ -42,7 +42,9 @@ class PubPublisherProxy(zmq_publisher_base.PublisherBase): """ def __init__(self, conf, matchmaker): - super(PubPublisherProxy, self).__init__(conf) + super(PubPublisherProxy, self).__init__() + self.conf = conf + self.zmq_context = zmq.Context() self.matchmaker = matchmaker self.socket = zmq_socket.ZmqRandomPortSocket( diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py index bcd3a9fa3..e08905a6e 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py @@ -14,7 +14,7 @@ import abc import logging -import uuid +import time import six @@ -23,7 +23,7 @@ from oslo_messaging._drivers.zmq_driver import zmq_address from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names from oslo_messaging._drivers.zmq_driver import zmq_socket -from oslo_messaging._i18n import _LE, _LI +from oslo_messaging._i18n import _LE LOG = logging.getLogger(__name__) @@ -57,7 +57,7 @@ class PublisherBase(object): Publisher can send request objects from zmq_request. """ - def __init__(self, conf): + def __init__(self, sockets_manager): """Construct publisher @@ -67,10 +67,9 @@ class PublisherBase(object): :param conf: configuration object :type conf: oslo_config.CONF """ - - self.conf = conf - self.zmq_context = zmq.Context() - self.outbound_sockets = {} + self.outbound_sockets = sockets_manager + self.conf = sockets_manager.conf + self.matchmaker = sockets_manager.matchmaker super(PublisherBase, self).__init__() @abc.abstractmethod @@ -99,60 +98,51 @@ class PublisherBase(object): def cleanup(self): """Cleanup publisher. Close allocated connections.""" - for socket in self.outbound_sockets.values(): - socket.setsockopt(zmq.LINGER, 0) - socket.close() + self.outbound_sockets.cleanup() -class PublisherMultisend(PublisherBase): +class SocketsManager(object): - def __init__(self, conf, matchmaker, socket_type): - - """Construct publisher multi-send - - Base class for fanout-sending publishers. - - :param conf: configuration object - :type conf: oslo_config.CONF - :param matchmaker: Name Service interface object - :type matchmaker: matchmaker.MatchMakerBase - """ - super(PublisherMultisend, self).__init__(conf) - self.socket_type = socket_type + def __init__(self, conf, matchmaker, listener_type, socket_type): + self.conf = conf self.matchmaker = matchmaker + self.listener_type = listener_type + self.socket_type = socket_type + self.zmq_context = zmq.Context() + self.outbound_sockets = {} - def _check_hosts_connections(self, target, listener_type): - # TODO(ozamiatin): Place for significant optimization - # Matchmaker cache should be implemented - if str(target) in self.outbound_sockets: - socket = self.outbound_sockets[str(target)] - else: - hosts = self.matchmaker.get_hosts(target, listener_type) - socket = zmq_socket.ZmqSocket(self.zmq_context, self.socket_type) - self.outbound_sockets[str(target)] = socket - for host in hosts: - self._connect_to_host(socket, host, target) + def _track_socket(self, socket, target): + self.outbound_sockets[str(target)] = (socket, time.time()) + + def _get_hosts_and_connect(self, socket, target): + hosts = self.matchmaker.get_hosts( + target, zmq_names.socket_type_str(self.listener_type)) + for host in hosts: + socket.connect_to_host(host) + self._track_socket(socket, target) + + def _check_for_new_hosts(self, target): + socket, tm = self.outbound_sockets[str(target)] + if 0 <= self.conf.zmq_target_expire <= time.time() - tm: + self._get_hosts_and_connect(socket, target) return socket - def _connect_to_address(self, socket, address, target): - stype = zmq_names.socket_type_str(self.socket_type) - try: - LOG.info(_LI("Connecting %(stype)s to %(address)s for %(target)s"), - {"stype": stype, "address": address, "target": target}) + def get_socket(self, target): + if str(target) in self.outbound_sockets: + socket = self._check_for_new_hosts(target) + else: + socket = zmq_socket.ZmqSocket(self.zmq_context, self.socket_type) + self._get_hosts_and_connect(socket, target) + return socket - if six.PY3: - socket.setsockopt_string(zmq.IDENTITY, str(uuid.uuid1())) - else: - socket.handle.identity = str(uuid.uuid1()) + def get_socket_to_broker(self, target): + socket = zmq_socket.ZmqSocket(self.zmq_context, self.socket_type) + self._track_socket(socket, target) + address = zmq_address.get_broker_address(self.conf) + socket.connect_to_address(address) + return socket - socket.connect(address) - except zmq.ZMQError as e: - errmsg = _LE("Failed connecting %(stype) to %(address)s: %(e)s")\ - % (stype, address, e) - LOG.error(_LE("Failed connecting %(stype) to %(address)s: %(e)s"), - (stype, address, e)) - raise rpc_common.RPCException(errmsg) - - def _connect_to_host(self, socket, host, target): - address = zmq_address.get_tcp_direct_address(host) - self._connect_to_address(socket, address, target) + def cleanup(self): + for socket, tm in self.outbound_sockets.values(): + socket.setsockopt(zmq.LINGER, 0) + socket.close() diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_push_publisher.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_push_publisher.py index c7854aeb4..549d3dced 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_push_publisher.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_push_publisher.py @@ -25,7 +25,7 @@ LOG = logging.getLogger(__name__) zmq = zmq_async.import_zmq() -class PushPublisher(zmq_publisher_base.PublisherMultisend): +class PushPublisher(zmq_publisher_base.PublisherBase): def __init__(self, conf, matchmaker): super(PushPublisher, self).__init__(conf, matchmaker, zmq.PUSH) diff --git a/oslo_messaging/_drivers/zmq_driver/matchmaker/base.py b/oslo_messaging/_drivers/zmq_driver/matchmaker/base.py index 5a0d2789c..75f6a7315 100644 --- a/oslo_messaging/_drivers/zmq_driver/matchmaker/base.py +++ b/oslo_messaging/_drivers/zmq_driver/matchmaker/base.py @@ -14,7 +14,6 @@ import abc import collections import logging -import retrying import six @@ -53,20 +52,6 @@ class MatchMakerBase(object): :type hostname: tuple """ - def get_publishers_retrying(self): - """Retry until at least one publisher appears""" - - def retry_if_empty(publishers): - return not publishers - - _retry = retrying.retry(retry_on_result=retry_if_empty) - - @_retry - def _get_publishers(): - return self.get_publishers() - - return _get_publishers() - @abc.abstractmethod def get_publishers(self): """Get all publisher-hosts from nameserver. diff --git a/oslo_messaging/_drivers/zmq_driver/matchmaker/matchmaker_redis.py b/oslo_messaging/_drivers/zmq_driver/matchmaker/matchmaker_redis.py index 0550deab4..72dfa65e4 100644 --- a/oslo_messaging/_drivers/zmq_driver/matchmaker/matchmaker_redis.py +++ b/oslo_messaging/_drivers/zmq_driver/matchmaker/matchmaker_redis.py @@ -56,26 +56,31 @@ matchmaker_redis_opts = [ ] _PUBLISHERS_KEY = "PUBLISHERS" +_RETRY_METHODS = ("get_hosts", "get_publishers") def retry_if_connection_error(ex): return isinstance(ex, redis.ConnectionError) +def retry_if_empty(hosts): + return not hosts + + def apply_retrying(obj, cfg): for attr_name, attr in inspect.getmembers(obj): if not (inspect.ismethod(attr) or inspect.isfunction(attr)): continue - if attr_name.startswith("_"): - continue - setattr( - obj, - attr_name, - retry( - wait_fixed=cfg.matchmaker_redis.wait_timeout, - stop_max_delay=cfg.matchmaker_redis.check_timeout, - retry_on_exception=retry_if_connection_error - )(attr)) + if attr_name in _RETRY_METHODS: + setattr( + obj, + attr_name, + retry( + wait_fixed=cfg.matchmaker_redis.wait_timeout, + stop_max_delay=cfg.matchmaker_redis.check_timeout, + retry_on_exception=retry_if_connection_error, + retry_on_result=retry_if_empty + )(attr)) class RedisMatchMaker(base.MatchMakerBase): @@ -150,6 +155,7 @@ class RedisMatchMaker(base.MatchMakerBase): self._redis.lrem(key, 0, hostname) def get_hosts(self, target, listener_type): + LOG.debug("[Redis] get_hosts for target %s", target) hosts = [] key = zmq_address.target_to_key(target, listener_type) hosts.extend(self._get_hosts_by_key(key)) diff --git a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py index 4d1e03585..36bb99cc8 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py +++ b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py @@ -87,7 +87,7 @@ class SubConsumer(zmq_consumer_base.ConsumerBase): self.poller.register(self.socket, self.receive_message) LOG.debug("[%s] SUB consumer connected to publishers %s", - (self.id, publishers)) + self.id, publishers) def listen(self, target): LOG.debug("Listen to target %s", target) @@ -96,7 +96,7 @@ class SubConsumer(zmq_consumer_base.ConsumerBase): def _receive_request(self, socket): topic_filter = socket.recv() - LOG.debug("[%(id)s] Received %(topict_filter)s topic", + LOG.debug("[%(id)s] Received %(topic_filter)s topic", {'id': self.id, 'topic_filter': topic_filter}) assert topic_filter in self.subscriptions request = socket.recv_pyobj() @@ -135,7 +135,7 @@ class MatchmakerPoller(object): self.executor.execute() def _poll_for_publishers(self): - publishers = self.matchmaker.get_publishers_retrying() + publishers = self.matchmaker.get_publishers() if publishers: self.on_result(publishers) self.executor.done() diff --git a/oslo_messaging/_drivers/zmq_driver/zmq_socket.py b/oslo_messaging/_drivers/zmq_driver/zmq_socket.py index 4119e5735..c16657262 100644 --- a/oslo_messaging/_drivers/zmq_driver/zmq_socket.py +++ b/oslo_messaging/_drivers/zmq_driver/zmq_socket.py @@ -13,11 +13,15 @@ # under the License. import logging +import uuid +import six + +from oslo_messaging._drivers import common as rpc_common from oslo_messaging._drivers.zmq_driver import zmq_address from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names -from oslo_messaging._i18n import _LE +from oslo_messaging._i18n import _LE, _LI from oslo_messaging import exceptions LOG = logging.getLogger(__name__) @@ -83,6 +87,29 @@ class ZmqSocket(object): def close(self, *args, **kwargs): self.handle.close(*args, **kwargs) + def connect_to_address(self, address): + stype = zmq_names.socket_type_str(self.socket_type) + try: + LOG.info(_LI("Connecting %(stype)s to %(address)s"), + {"stype": stype, "address": address}) + + if six.PY3: + self.setsockopt_string(zmq.IDENTITY, str(uuid.uuid1())) + else: + self.handle.identity = str(uuid.uuid1()) + + self.connect(address) + except zmq.ZMQError as e: + errmsg = _LE("Failed connecting %(stype) to %(address)s: %(e)s")\ + % (stype, address, e) + LOG.error(_LE("Failed connecting %(stype) to %(address)s: %(e)s"), + (stype, address, e)) + raise rpc_common.RPCException(errmsg) + + def connect_to_host(self, host): + address = zmq_address.get_tcp_direct_address(host) + self.connect_to_address(address) + class ZmqPortRangeExceededException(exceptions.MessagingException): """Raised by ZmqRandomPortSocket - wrapping zmq.ZMQBindError""" diff --git a/oslo_messaging/tests/drivers/zmq/matchmaker/test_impl_matchmaker.py b/oslo_messaging/tests/drivers/zmq/matchmaker/test_impl_matchmaker.py index ac58b205e..7237fe25f 100644 --- a/oslo_messaging/tests/drivers/zmq/matchmaker/test_impl_matchmaker.py +++ b/oslo_messaging/tests/drivers/zmq/matchmaker/test_impl_matchmaker.py @@ -16,6 +16,8 @@ from stevedore import driver import testscenarios import testtools +import retrying + import oslo_messaging from oslo_messaging.tests import utils as test_utils from oslo_utils import importutils @@ -92,4 +94,9 @@ class TestImplMatchmaker(test_utils.BaseTestCase): def test_get_hosts_wrong_topic(self): target = oslo_messaging.Target(topic="no_such_topic") - self.assertEqual(self.test_matcher.get_hosts(target, "test"), []) + hosts = [] + try: + hosts = self.test_matcher.get_hosts(target, "test") + except retrying.RetryError: + pass + self.assertEqual(hosts, []) diff --git a/oslo_messaging/tests/functional/zmq/__init__.py b/oslo_messaging/tests/functional/zmq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/oslo_messaging/tests/functional/zmq/multiproc_utils.py b/oslo_messaging/tests/functional/zmq/multiproc_utils.py new file mode 100644 index 000000000..fb66615b6 --- /dev/null +++ b/oslo_messaging/tests/functional/zmq/multiproc_utils.py @@ -0,0 +1,236 @@ +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging +import logging.handlers +import multiprocessing +import os +import sys +import time +import uuid +import threading + +from oslo_config import cfg + +import oslo_messaging +from oslo_messaging._drivers.zmq_driver import zmq_async +from oslo_messaging.tests.functional import utils + + +zmq = zmq_async.import_zmq() + +LOG = logging.getLogger(__name__) + + +class QueueHandler(logging.Handler): + """ + This is a logging handler which sends events to a multiprocessing queue. + + The plan is to add it to Python 3.2, but this can be copy pasted into + user code for use with earlier Python versions. + """ + + def __init__(self, queue): + """ + Initialise an instance, using the passed queue. + """ + logging.Handler.__init__(self) + self.queue = queue + + def emit(self, record): + """ + Emit a record. + + Writes the LogRecord to the queue. + """ + try: + ei = record.exc_info + if ei: + dummy = self.format(record) # just to get traceback text into record.exc_text + record.exc_info = None # not needed any more + self.queue.put_nowait(record) + except (KeyboardInterrupt, SystemExit): + raise + except: + self.handleError(record) + + +def listener_configurer(conf): + root = logging.getLogger() + h = logging.StreamHandler(sys.stdout) + f = logging.Formatter('%(asctime)s %(processName)-10s %(name)s ' + '%(levelname)-8s %(message)s') + h.setFormatter(f) + root.addHandler(h) + log_path = conf.rpc_zmq_ipc_dir + "/" + "zmq_multiproc.log" + file_handler = logging.StreamHandler(open(log_path, 'w')) + file_handler.setFormatter(f) + root.addHandler(file_handler) + + +def server_configurer(queue): + h = QueueHandler(queue) + root = logging.getLogger() + root.addHandler(h) + root.setLevel(logging.DEBUG) + + +def listener_thread(queue, configurer, conf): + configurer(conf) + while True: + time.sleep(0.3) + try: + record = queue.get() + if record is None: + break + logger = logging.getLogger(record.name) + logger.handle(record) + except (KeyboardInterrupt, SystemExit): + raise + + +class Client(oslo_messaging.RPCClient): + + def __init__(self, transport, topic): + super(Client, self).__init__( + transport=transport, target=oslo_messaging.Target(topic=topic)) + self.replies = [] + + def call_a(self): + LOG.warning("call_a - client side") + rep = self.call({}, 'call_a') + LOG.warning("after call_a - client side") + self.replies.append(rep) + return rep + + +class ReplyServerEndpoint(object): + + def call_a(self, *args, **kwargs): + LOG.warning("call_a - Server endpoint reached!") + return "OK" + + +class Server(object): + + def __init__(self, conf, log_queue, transport_url, name, topic=None): + self.conf = conf + self.log_queue = log_queue + self.transport_url = transport_url + self.name = name + self.topic = topic or str(uuid.uuid4()) + self.ready = multiprocessing.Value('b', False) + self._stop = multiprocessing.Event() + + def start(self): + self.process = multiprocessing.Process(target=self._run_server, + name=self.name, + args=(self.conf, + self.transport_url, + self.log_queue, + self.ready)) + self.process.start() + LOG.debug("Server process started: pid: %d" % self.process.pid) + + def _run_server(self, conf, url, log_queue, ready): + server_configurer(log_queue) + LOG.debug("Starting RPC server") + + transport = oslo_messaging.get_transport(conf, url=url) + target = oslo_messaging.Target(topic=self.topic, + server=self.name) + self.rpc_server = oslo_messaging.get_rpc_server( + transport=transport, target=target, + endpoints=[ReplyServerEndpoint()], + executor='eventlet') + self.rpc_server.start() + ready.value = True + LOG.debug("RPC server being started") + while not self._stop.is_set(): + LOG.debug("Waiting for the stop signal ...") + time.sleep(1) + self.rpc_server.stop() + LOG.debug("Leaving process T:%s Pid:%d" % (str(target), + os.getpid())) + + def cleanup(self): + LOG.debug("Stopping server") + self.shutdown() + + def shutdown(self): + self._stop.set() + + def restart(self, time_for_restart=1): + pass + + def hang(self): + pass + + def crash(self): + pass + + def ping(self): + pass + + +class MutliprocTestCase(utils.SkipIfNoTransportURL): + + def setUp(self): + super(MutliprocTestCase, self).setUp(conf=cfg.ConfigOpts()) + + if not self.url.startswith("zmq:"): + self.skipTest("ZeroMQ specific skipped ...") + + self.transport = oslo_messaging.get_transport(self.conf, url=self.url) + + LOG.debug("Start log queue") + + self.log_queue = multiprocessing.Queue() + self.log_listener = threading.Thread(target=listener_thread, + args=(self.log_queue, + listener_configurer, + self.conf)) + self.log_listener.start() + self.spawned = [] + + self.conf.prog = "test_prog" + self.conf.project = "test_project" + + def tearDown(self): + super(MutliprocTestCase, self).tearDown() + for process in self.spawned: + process.cleanup() + + def get_client(self, topic): + return Client(self.transport, topic) + + def spawn_server(self, name, wait_for_server=False, topic=None): + srv = Server(self.conf, self.log_queue, self.url, name, topic) + LOG.debug("[SPAWN] %s (starting)..." % srv.name) + srv.start() + if wait_for_server: + while not srv.ready.value: + LOG.debug("[SPAWN] %s (waiting for server ready)..." % srv.name) + time.sleep(1) + LOG.debug("[SPAWN] Server %s:%d started.", srv.name, srv.process.pid) + self.spawned.append(srv) + return srv + + def spawn_servers(self, number, wait_for_server=False, random_topic=True): + common_topic = str(uuid.uuid4()) if random_topic else None + names = ["server_%i_%s" % (i, str(uuid.uuid4())[:8]) + for i in range(number)] + for name in names: + server = self.spawn_server(name, wait_for_server, common_topic) + self.spawned.append(server) diff --git a/oslo_messaging/tests/functional/zmq/test_startup.py b/oslo_messaging/tests/functional/zmq/test_startup.py new file mode 100644 index 000000000..c3258131f --- /dev/null +++ b/oslo_messaging/tests/functional/zmq/test_startup.py @@ -0,0 +1,57 @@ +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging +import os +import sys + +from oslo_messaging.tests.functional.zmq import multiproc_utils + + +LOG = logging.getLogger(__name__) + + +class StartupOrderTestCase(multiproc_utils.MutliprocTestCase): + + def setUp(self): + super(StartupOrderTestCase, self).setUp() + + self.conf.prog = "test_prog" + self.conf.project = "test_project" + + kwargs = {'rpc_response_timeout': 30, + 'use_pub_sub': False, + 'direct_over_proxy': False} + self.config(**kwargs) + + log_path = self.conf.rpc_zmq_ipc_dir + "/" + str(os.getpid()) + ".log" + sys.stdout = open(log_path, "w", buffering=0) + + def test_call_server_before_client(self): + self.spawn_servers(3, wait_for_server=True, random_topic=False) + servers = self.spawned + client = self.get_client(servers[0].topic) + for i in range(3): + reply = client.call_a() + self.assertIsNotNone(reply) + self.assertEqual(3, len(client.replies)) + + def test_call_client_dont_wait_for_server(self): + self.spawn_servers(3, wait_for_server=False, random_topic=False) + servers = self.spawned + client = self.get_client(servers[0].topic) + for i in range(3): + reply = client.call_a() + self.assertIsNotNone(reply) + self.assertEqual(3, len(client.replies))