[zmq] Refactoring consumer side

* Introduce ZmqNotificationServer
* Implement single listener per target

Change-Id: I874c3fa6a86d3110a2145bea8ad06ca0bbd522c7
This commit is contained in:
Oleksii Zamiatin 2016-02-22 13:55:01 +02:00
parent f0d251d19d
commit 1385df6181
14 changed files with 92 additions and 258 deletions

View File

@ -254,9 +254,7 @@ class ZmqDriver(base.BaseDriver):
:param target: Message destination target
:type target: oslo_messaging.Target
"""
server = zmq_server.ZmqServer(self, self.conf, self.matchmaker)
server.listen(target)
return server
return zmq_server.ZmqServer(self, self.conf, self.matchmaker, target)
def listen_for_notifications(self, targets_and_priorities, pool):
"""Listen to a specified list of targets on a server side
@ -266,9 +264,8 @@ class ZmqDriver(base.BaseDriver):
:param pool: Not used for zmq implementation
:type pool: object
"""
server = zmq_server.ZmqServer(self, self.conf, self.matchmaker)
server.listen_notification(targets_and_priorities)
return server
return zmq_server.ZmqNotificationServer(
self, self.conf, self.matchmaker, targets_and_priorities)
def cleanup(self):
"""Cleanup all driver's connections finally

View File

@ -110,14 +110,12 @@ class ReplyWaiter(object):
self._lock = threading.Lock()
def track_reply(self, reply_future, message_id):
self._lock.acquire()
self.replies[message_id] = reply_future
self._lock.release()
with self._lock:
self.replies[message_id] = reply_future
def untrack_id(self, message_id):
self._lock.acquire()
self.replies.pop(message_id)
self._lock.release()
with self._lock:
self.replies.pop(message_id)
def poll_socket(self, socket):

View File

@ -54,9 +54,6 @@ class DealerPublisher(zmq_publisher_base.QueuedSender):
raise zmq_publisher_base.UnsupportedSendPattern(request.msg_type)
super(DealerPublisher, self).send_request(request)
def cleanup(self):
super(DealerPublisher, self).cleanup()
class DealerPublisherLight(zmq_publisher_base.QueuedSender):
"""Used when publishing to proxy. """
@ -91,3 +88,4 @@ class DealerPublisherLight(zmq_publisher_base.QueuedSender):
def cleanup(self):
self.socket.close()
super(DealerPublisherLight, self).cleanup()

View File

@ -44,7 +44,7 @@ class GreenPoller(zmq_poller.ZmqPoller):
try:
return self.incoming_queue.get(timeout=timeout)
except eventlet.queue.Empty:
return (None, None)
return None, None
def close(self):
for thread in self.thread_by_socket.values():
@ -53,41 +53,6 @@ class GreenPoller(zmq_poller.ZmqPoller):
self.thread_by_socket = {}
class HoldReplyPoller(GreenPoller):
def __init__(self):
super(HoldReplyPoller, self).__init__()
self.event_by_socket = {}
self._is_running = threading.Event()
def register(self, socket, recv_method=None):
super(HoldReplyPoller, self).register(socket, recv_method)
self.event_by_socket[socket] = threading.Event()
def resume_polling(self, socket):
pause = self.event_by_socket[socket]
pause.set()
def _socket_receive(self, socket, recv_method=None):
pause = self.event_by_socket[socket]
while not self._is_running.is_set():
pause.clear()
if recv_method:
incoming = recv_method(socket)
else:
incoming = socket.recv_multipart()
self.incoming_queue.put((incoming, socket))
pause.wait()
def close(self):
self._is_running.set()
for pause in self.event_by_socket.values():
pause.set()
eventlet.sleep()
super(HoldReplyPoller, self).close()
class GreenExecutor(zmq_poller.Executor):
def __init__(self, method):

View File

@ -14,12 +14,12 @@
import abc
import logging
import threading
import time
import six
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver import zmq_address
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._drivers.zmq_driver import zmq_socket
@ -40,10 +40,6 @@ class ConsumerBase(object):
self.sockets = []
self.context = zmq.Context()
@abc.abstractmethod
def listen(self, target):
"""Associate new sockets with targets here"""
@abc.abstractmethod
def receive_message(self, target):
"""Method for poller - receiving message routine"""
@ -59,18 +55,26 @@ class SingleSocketConsumer(ConsumerBase):
def __init__(self, conf, poller, server, socket_type):
super(SingleSocketConsumer, self).__init__(conf, poller, server)
self.matchmaker = server.matchmaker
self.target = server.target
self.socket_type = socket_type
self.host = None
self.socket = self.subscribe_socket(socket_type)
self.target_updater = TargetUpdater(conf, self.matchmaker, self.target,
self.host, socket_type)
def subscribe_socket(self, socket_type):
try:
socket = zmq_socket.ZmqRandomPortSocket(
self.conf, self.context, socket_type)
self.sockets.append(socket)
self.poller.register(socket, self.receive_message)
LOG.debug("Run %(stype)s consumer on %(addr)s:%(port)d",
{"stype": zmq_names.socket_type_str(socket_type),
"addr": socket.bind_address,
"port": socket.port})
self.host = zmq_address.combine_address(self.conf.rpc_zmq_host,
socket.port)
self.poller.register(socket, self.receive_message)
return socket
except zmq.ZMQError as e:
errmsg = _LE("Failed binding to port %(port)d: %(e)s")\
@ -87,42 +91,30 @@ class SingleSocketConsumer(ConsumerBase):
def port(self):
return self.socket.port
def cleanup(self):
self.target_updater.cleanup()
super(SingleSocketConsumer, self).cleanup()
class TargetsManager(object):
def __init__(self, conf, matchmaker, host, socket_type):
self.targets = []
class TargetUpdater(object):
"""This entity performs periodic async updates
to the matchmaker.
"""
def __init__(self, conf, matchmaker, target, host, socket_type):
self.conf = conf
self.matchmaker = matchmaker
self.target = target
self.host = host
self.socket_type = socket_type
self.targets_lock = threading.Lock()
self.updater = zmq_async.get_executor(method=self._update_targets) \
if conf.zmq_target_expire > 0 else None
if self.updater:
self.updater.execute()
self.executor = zmq_async.get_executor(method=self._update_target)
self.executor.execute()
def _update_targets(self):
with self.targets_lock:
for target in self.targets:
self.matchmaker.register(
target, self.host,
zmq_names.socket_type_str(self.socket_type))
# Update target-records once per half expiration time
def _update_target(self):
self.matchmaker.register(
self.target, self.host,
zmq_names.socket_type_str(self.socket_type))
time.sleep(self.conf.zmq_target_expire / 2)
def listen(self, target):
with self.targets_lock:
self.targets.append(target)
self.matchmaker.register(
target, self.host,
zmq_names.socket_type_str(self.socket_type))
def cleanup(self):
if self.updater:
self.updater.stop()
for target in self.targets:
self.matchmaker.unregister(
target, self.host,
zmq_names.socket_type_str(self.socket_type))
self.executor.stop()

View File

@ -17,7 +17,6 @@ import logging
from oslo_messaging._drivers import base
from oslo_messaging._drivers.zmq_driver.server.consumers\
import zmq_consumer_base
from oslo_messaging._drivers.zmq_driver import zmq_address
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._i18n import _LE, _LI
@ -46,21 +45,8 @@ class PullConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server):
super(PullConsumer, self).__init__(conf, poller, server, zmq.PULL)
self.matchmaker = server.matchmaker
self.host = zmq_address.combine_address(self.conf.rpc_zmq_host,
self.port)
self.targets = zmq_consumer_base.TargetsManager(
conf, self.matchmaker, self.host, zmq.PULL)
LOG.info(_LI("[%s] Run PULL consumer"), self.host)
def listen(self, target):
LOG.info(_LI("Listen to target %s"), str(target))
self.targets.listen(target)
def cleanup(self):
super(PullConsumer, self).cleanup()
self.targets.cleanup()
def receive_message(self, socket):
try:
request = socket.recv_pyobj()

View File

@ -18,7 +18,6 @@ from oslo_messaging._drivers import base
from oslo_messaging._drivers.zmq_driver.server.consumers\
import zmq_consumer_base
from oslo_messaging._drivers.zmq_driver.server import zmq_incoming_message
from oslo_messaging._drivers.zmq_driver import zmq_address
from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._i18n import _LE, _LI
@ -53,22 +52,8 @@ class RouterConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server):
super(RouterConsumer, self).__init__(conf, poller, server, zmq.ROUTER)
self.matchmaker = server.matchmaker
self.host = zmq_address.combine_address(self.conf.rpc_zmq_host,
self.port)
self.targets = zmq_consumer_base.TargetsManager(
conf, self.matchmaker, self.host, zmq.ROUTER)
LOG.info(_LI("[%s] Run ROUTER consumer"), self.host)
def listen(self, target):
LOG.info(_LI("[%(host)s] Listen to target %(target)s"),
{'host': self.host, 'target': target})
self.targets.listen(target)
def cleanup(self):
super(RouterConsumer, self).cleanup()
self.targets.cleanup()
def _receive_request(self, socket):
reply_id = socket.recv()
empty = socket.recv()

View File

@ -13,7 +13,6 @@
# under the License.
import logging
import threading
import uuid
import six
@ -34,12 +33,11 @@ zmq = zmq_async.import_zmq()
class SubIncomingMessage(base.RpcIncomingMessage):
def __init__(self, request, socket, poller):
def __init__(self, request, socket):
super(SubIncomingMessage, self).__init__(
request.context, request.message)
self.socket = socket
self.msg_id = request.message_id
poller.resume_polling(socket)
def reply(self, reply=None, failure=None, log_failure=True):
"""Reply is not needed for non-call messages."""
@ -56,16 +54,24 @@ class SubConsumer(zmq_consumer_base.ConsumerBase):
def __init__(self, conf, poller, server):
super(SubConsumer, self).__init__(conf, poller, server)
self.matchmaker = server.matchmaker
self.target = server.target
self.subscriptions = set()
self.targets = []
self._socket_lock = threading.Lock()
self.socket = zmq_socket.ZmqSocket(self.conf, self.context, zmq.SUB)
self.sockets.append(self.socket)
self.id = uuid.uuid4()
self.publishers_poller = MatchmakerPoller(
self.matchmaker, on_result=self.on_publishers)
self._subscribe_on_target(self.target)
self.on_publishers(self.matchmaker.get_publishers())
self.poller.register(self.socket, self.receive_message)
def on_publishers(self, publishers):
for host, sync in publishers:
self.socket.connect(zmq_address.get_tcp_direct_address(host))
LOG.debug("[%s] SUB consumer connected to publishers %s",
self.id, publishers)
def _subscribe_on_target(self, target):
# NOTE(ozamiatin): No locks needed here, because this is called
# before the async updater loop started
topic_filter = zmq_address.target_to_subscribe_filter(target)
if target.topic:
self.socket.setsockopt(zmq.SUBSCRIBE, six.b(target.topic))
@ -80,20 +86,6 @@ class SubConsumer(zmq_consumer_base.ConsumerBase):
LOG.debug("[%(host)s] Subscribing to topic %(filter)s",
{"host": self.id, "filter": topic_filter})
def on_publishers(self, publishers):
with self._socket_lock:
for host, sync in publishers:
self.socket.connect(zmq_address.get_tcp_direct_address(host))
self.poller.register(self.socket, self.receive_message)
LOG.debug("[%s] SUB consumer connected to publishers %s",
self.id, publishers)
def listen(self, target):
LOG.debug("Listen to target %s", target)
with self._socket_lock:
self._subscribe_on_target(target)
def _receive_request(self, socket):
topic_filter = socket.recv()
LOG.debug("[%(id)s] Received %(topic_filter)s topic",
@ -115,42 +107,9 @@ class SubConsumer(zmq_consumer_base.ConsumerBase):
if request.msg_type not in zmq_names.MULTISEND_TYPES:
LOG.error(_LE("Unknown message type: %s"), request.msg_type)
else:
return SubIncomingMessage(request, socket, self.poller)
return SubIncomingMessage(request, socket)
except zmq.ZMQError as e:
LOG.error(_LE("Receiving message failed: %s"), str(e))
class MatchmakerPoller(object):
"""This entity performs periodical async polling
to the matchmaker if no hosts were registered for
specified target before.
"""
def __init__(self, matchmaker, on_result):
self.matchmaker = matchmaker
self.executor = zmq_async.get_executor(
method=self._poll_for_publishers)
self.on_result = on_result
self.executor.execute()
def _poll_for_publishers(self):
publishers = self.matchmaker.get_publishers()
if publishers:
self.on_result(publishers)
self.executor.done()
class BackChatter(object):
def __init__(self, conf, context):
self.socket = zmq_socket.ZmqSocket(conf, context, zmq.PUSH)
def connect(self, address):
self.socket.connect(address)
def send_ready(self):
for i in range(self.socket.connections_count()):
self.socket.send(zmq_names.ACK_TYPE)
def close(self):
self.socket.close()
def cleanup(self):
super(SubConsumer, self).cleanup()

View File

@ -32,27 +32,26 @@ zmq = zmq_async.import_zmq()
class ZmqServer(base.Listener):
def __init__(self, driver, conf, matchmaker=None):
def __init__(self, driver, conf, matchmaker, target, poller=None):
super(ZmqServer, self).__init__()
self.driver = driver
self.conf = conf
self.matchmaker = matchmaker
self.poller = zmq_async.get_poller()
self.target = target
self.poller = poller or zmq_async.get_poller()
self.router_consumer = zmq_router_consumer.RouterConsumer(
conf, self.poller, self)
self.pull_consumer = zmq_pull_consumer.PullConsumer(
conf, self.poller, self)
self.sub_consumer = zmq_sub_consumer.SubConsumer(
conf, self.poller, self) if conf.use_pub_sub else None
self.notify_consumer = self.sub_consumer if conf.use_pub_sub \
else self.router_consumer
self.consumers = [self.router_consumer, self.pull_consumer]
if self.sub_consumer:
self.consumers.append(self.sub_consumer)
@base.batch_poll_helper
def poll(self, timeout=None):
def poll(self, timeout=None, prefetch_size=1):
message, socket = self.poller.poll(
timeout or self.conf.rpc_poll_timeout)
return message
@ -67,15 +66,35 @@ class ZmqServer(base.Listener):
for consumer in self.consumers:
consumer.cleanup()
def listen(self, target):
self.router_consumer.listen(target)
self.pull_consumer.listen(target)
if self.sub_consumer:
self.sub_consumer.listen(target)
def listen_notification(self, targets_and_priorities):
consumer = self.notify_consumer
class ZmqNotificationServer(base.Listener):
def __init__(self, driver, conf, matchmaker, targets_and_priorities):
super(ZmqNotificationServer, self).__init__()
self.driver = driver
self.conf = conf
self.matchmaker = matchmaker
self.servers = []
self.poller = zmq_async.get_poller()
self._listen(targets_and_priorities)
def _listen(self, targets_and_priorities):
for target, priority in targets_and_priorities:
t = copy.deepcopy(target)
t.topic = target.topic + '.' + priority
consumer.listen(t)
self.servers.append(ZmqServer(
self.driver, self.conf, self.matchmaker, t, self.poller))
@base.batch_poll_helper
def poll(self, timeout=None, prefetch_size=1):
message, socket = self.poller.poll(
timeout or self.conf.rpc_poll_timeout)
return message
def stop(self):
for server in self.servers:
server.stop()
def cleanup(self):
for server in self.servers:
server.cleanup()

View File

@ -43,17 +43,6 @@ def get_poller(zmq_concurrency='eventlet'):
return threading_poller.ThreadingPoller()
def get_reply_poller(zmq_concurrency='eventlet'):
_raise_error_if_invalid_config_value(zmq_concurrency)
if zmq_concurrency == 'eventlet' and _is_eventlet_zmq_available():
from oslo_messaging._drivers.zmq_driver.poller import green_poller
return green_poller.HoldReplyPoller()
from oslo_messaging._drivers.zmq_driver.poller import threading_poller
return threading_poller.ThreadingPoller()
def get_executor(method, zmq_concurrency='eventlet'):
_raise_error_if_invalid_config_value(zmq_concurrency)

View File

@ -134,6 +134,7 @@ class ZmqRandomPortSocket(ZmqSocket):
min_port=conf.rpc_zmq_min_port,
max_port=conf.rpc_zmq_max_port,
max_tries=conf.rpc_zmq_bind_port_retries)
self.connected = True
except zmq.ZMQBindError:
LOG.error(_LE("Random ports range exceeded!"))
raise ZmqPortRangeExceededException()

View File

@ -12,7 +12,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import fixtures
import testtools
import oslo_messaging
@ -150,57 +149,3 @@ class TestZmqBasics(zmq_common.ZmqBaseTestCase):
self.driver.send_notification(target, context, message, '3.0')
self.listener._received.wait(5)
self.assertTrue(self.listener._received.isSet())
class TestPoller(test_utils.BaseTestCase):
@testtools.skipIf(zmq is None, "zmq not available")
def setUp(self):
super(TestPoller, self).setUp()
self.poller = zmq_async.get_poller()
self.ctx = zmq.Context()
self.internal_ipc_dir = self.useFixture(fixtures.TempDir()).path
self.ADDR_REQ = "ipc://%s/request1" % self.internal_ipc_dir
def test_poll_blocking(self):
rep = self.ctx.socket(zmq.REP)
rep.bind(self.ADDR_REQ)
reply_poller = zmq_async.get_reply_poller()
reply_poller.register(rep)
def listener():
incoming, socket = reply_poller.poll()
self.assertEqual(b'Hello', incoming[0])
socket.send_string('Reply')
reply_poller.resume_polling(socket)
executor = zmq_async.get_executor(listener)
executor.execute()
req1 = self.ctx.socket(zmq.REQ)
req1.connect(self.ADDR_REQ)
req2 = self.ctx.socket(zmq.REQ)
req2.connect(self.ADDR_REQ)
req1.send_string('Hello')
req2.send_string('Hello')
reply = req1.recv_string()
self.assertEqual('Reply', reply)
reply = req2.recv_string()
self.assertEqual('Reply', reply)
def test_poll_timeout(self):
rep = self.ctx.socket(zmq.REP)
rep.bind(self.ADDR_REQ)
reply_poller = zmq_async.get_reply_poller()
reply_poller.register(rep)
incoming, socket = reply_poller.poll(1)
self.assertIsNone(incoming)
self.assertIsNone(socket)

View File

@ -46,7 +46,7 @@ class TestPubSub(zmq_common.ZmqBaseTestCase):
self.listeners.append(zmq_common.TestServerListener(self.driver))
def _send_request(self, target):
# Needed only in test env to get listener a chance to connect
# Needed only in test env to give listener a chance to connect
# before request fires
time.sleep(1)
with contextlib.closing(zmq_request.FanoutRequest(

View File

@ -105,21 +105,21 @@ class TestGetReplyPoller(test_utils.BaseTestCase):
def test_default_reply_poller_is_HoldReplyPoller(self):
zmq_async._is_eventlet_zmq_available = lambda: True
actual = zmq_async.get_reply_poller()
actual = zmq_async.get_poller()
self.assertTrue(isinstance(actual, green_poller.HoldReplyPoller))
self.assertTrue(isinstance(actual, green_poller.GreenPoller))
def test_when_eventlet_is_available_then_return_HoldReplyPoller(self):
zmq_async._is_eventlet_zmq_available = lambda: True
actual = zmq_async.get_reply_poller('eventlet')
actual = zmq_async.get_poller('eventlet')
self.assertTrue(isinstance(actual, green_poller.HoldReplyPoller))
self.assertTrue(isinstance(actual, green_poller.GreenPoller))
def test_when_eventlet_is_unavailable_then_return_ThreadingPoller(self):
zmq_async._is_eventlet_zmq_available = lambda: False
actual = zmq_async.get_reply_poller('eventlet')
actual = zmq_async.get_poller('eventlet')
self.assertTrue(isinstance(actual, threading_poller.ThreadingPoller))
@ -128,7 +128,7 @@ class TestGetReplyPoller(test_utils.BaseTestCase):
errmsg = 'Invalid zmq_concurrency value: x'
with self.assertRaisesRegexp(ValueError, errmsg):
zmq_async.get_reply_poller(invalid_opt)
zmq_async.get_poller(invalid_opt)
class TestGetExecutor(test_utils.BaseTestCase):