From 656bc21d3490e659193ea3a86c63f70b61a8178f Mon Sep 17 00:00:00 2001 From: Gevorg Davoian Date: Mon, 7 Nov 2016 17:14:48 +0200 Subject: [PATCH] [zmq] Support message versions for rolling upgrades Change-Id: I62b001ca396e3a87c061725db55184b2f1f869e3 Closes-Bug: #1524248 --- .../client/publishers/zmq_publisher_base.py | 6 +- .../zmq_driver/client/zmq_receivers.py | 25 +++++ .../_drivers/zmq_driver/client/zmq_request.py | 7 +- .../zmq_driver/client/zmq_response.py | 16 +++- .../_drivers/zmq_driver/client/zmq_senders.py | 91 +++++++++++++------ .../zmq_driver/proxy/zmq_base_proxy.py | 21 +++-- .../_drivers/zmq_driver/proxy/zmq_sender.py | 73 +++++++++++---- .../server/consumers/zmq_consumer_base.py | 8 +- .../server/consumers/zmq_dealer_consumer.py | 74 ++++++++++----- .../server/consumers/zmq_router_consumer.py | 51 ++++++++--- .../server/consumers/zmq_sub_consumer.py | 35 +++++-- .../_drivers/zmq_driver/zmq_names.py | 8 +- .../_drivers/zmq_driver/zmq_version.py | 60 ++++++++++++ .../tests/drivers/zmq/test_pub_sub.py | 20 ++-- .../tests/drivers/zmq/test_zmq_version.py | 63 +++++++++++++ 15 files changed, 432 insertions(+), 126 deletions(-) create mode 100644 oslo_messaging/_drivers/zmq_driver/zmq_version.py create mode 100644 oslo_messaging/tests/drivers/zmq/test_zmq_version.py diff --git a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py index 5de32bbe5..09bec6f42 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py +++ b/oslo_messaging/_drivers/zmq_driver/client/publishers/zmq_publisher_base.py @@ -35,7 +35,7 @@ class PublisherBase(object): def __init__(self, sockets_manager, sender, receiver): - """Construct publisher + """Construct publisher. Accept sockets manager, sender and receiver objects. @@ -54,7 +54,7 @@ class PublisherBase(object): @abc.abstractmethod def acquire_connection(self, request): - """Get socket to publish request on it + """Get socket to publish request on it. :param request: request object :type senders: zmq_request.Request @@ -62,7 +62,7 @@ class PublisherBase(object): @abc.abstractmethod def send_request(self, socket, request): - """Publish request on a socket + """Publish request on a socket. :param socket: socket object to publish request on :type socket: zmq_socket.ZmqSocket diff --git a/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py b/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py index 5bdfb0fa9..7754c2413 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py +++ b/oslo_messaging/_drivers/zmq_driver/client/zmq_receivers.py @@ -22,6 +22,7 @@ import six from oslo_messaging._drivers.zmq_driver.client import zmq_response from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names +from oslo_messaging._drivers.zmq_driver import zmq_version from oslo_messaging._i18n import _LE LOG = logging.getLogger(__name__) @@ -52,6 +53,8 @@ class ReceiverBase(object): self._lock = threading.Lock() self._requests = {} self._poller = zmq_async.get_poller() + self._receive_response_versions = \ + zmq_version.get_method_versions(self, 'receive_response') self._executor = zmq_async.get_executor(self._run_loop) self._executor.execute() @@ -121,6 +124,12 @@ class ReceiverBase(object): {"msg_type": zmq_names.message_type_str(message_type), "msg_id": message_id}) + def _get_receive_response_version(self, version): + receive_response_version = self._receive_response_versions.get(version) + if receive_response_version is None: + raise zmq_version.UnsupportedMessageVersionError(version) + return receive_response_version + class ReceiverProxy(ReceiverBase): @@ -128,6 +137,14 @@ class ReceiverProxy(ReceiverBase): def receive_response(self, socket): empty = socket.recv() assert empty == b'', "Empty delimiter expected!" + message_version = socket.recv_string() + assert message_version != b'', "Valid message version expected!" + + receive_response_version = \ + self._get_receive_response_version(message_version) + return receive_response_version(socket) + + def _receive_response_v_1_0(self, socket): reply_id = socket.recv() assert reply_id != b'', "Valid reply id expected!" message_type = int(socket.recv()) @@ -153,6 +170,14 @@ class ReceiverDirect(ReceiverBase): def receive_response(self, socket): empty = socket.recv() assert empty == b'', "Empty delimiter expected!" + message_version = socket.recv_string() + assert message_version != b'', "Valid message version expected!" + + receive_response_version = \ + self._get_receive_response_version(message_version) + return receive_response_version(socket) + + def _receive_response_v_1_0(self, socket): message_type = int(socket.recv()) assert message_type in zmq_names.RESPONSE_TYPES, "Response expected!" message_id = socket.recv_string() diff --git a/oslo_messaging/_drivers/zmq_driver/client/zmq_request.py b/oslo_messaging/_drivers/zmq_driver/client/zmq_request.py index 5d1a04495..d2b3110de 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/zmq_request.py +++ b/oslo_messaging/_drivers/zmq_driver/client/zmq_request.py @@ -1,4 +1,4 @@ -# Copyright 2015 Mirantis, Inc. +# Copyright 2015-2016 Mirantis, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -20,6 +20,7 @@ import six from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names +from oslo_messaging._drivers.zmq_driver import zmq_version from oslo_messaging._i18n import _LE LOG = logging.getLogger(__name__) @@ -73,6 +74,10 @@ class Request(object): def msg_type(self): """ZMQ request type""" + @property + def message_version(self): + return zmq_version.MESSAGE_VERSION + class RpcRequest(Request): diff --git a/oslo_messaging/_drivers/zmq_driver/client/zmq_response.py b/oslo_messaging/_drivers/zmq_driver/client/zmq_response.py index 3da30b670..140feed46 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/zmq_response.py +++ b/oslo_messaging/_drivers/zmq_driver/client/zmq_response.py @@ -22,12 +22,13 @@ from oslo_messaging._drivers.zmq_driver import zmq_names @six.add_metaclass(abc.ABCMeta) class Response(object): - def __init__(self, message_id=None, reply_id=None): + def __init__(self, message_id=None, reply_id=None, message_version=None): if self.msg_type not in zmq_names.RESPONSE_TYPES: raise RuntimeError("Unknown response type!") self._message_id = message_id self._reply_id = reply_id + self._message_version = message_version @abc.abstractproperty def msg_type(self): @@ -41,9 +42,14 @@ class Response(object): def reply_id(self): return self._reply_id + @property + def message_version(self): + return self._message_version + def to_dict(self): return {zmq_names.FIELD_MSG_ID: self._message_id, - zmq_names.FIELD_REPLY_ID: self._reply_id} + zmq_names.FIELD_REPLY_ID: self._reply_id, + zmq_names.FIELD_MSG_VERSION: self._message_version} def __str__(self): return str(self.to_dict()) @@ -58,9 +64,9 @@ class Reply(Response): msg_type = zmq_names.REPLY_TYPE - def __init__(self, message_id=None, reply_id=None, reply_body=None, - failure=None): - super(Reply, self).__init__(message_id, reply_id) + def __init__(self, message_id=None, reply_id=None, message_version=None, + reply_body=None, failure=None): + super(Reply, self).__init__(message_id, reply_id, message_version) self._reply_body = reply_body self._failure = failure diff --git a/oslo_messaging/_drivers/zmq_driver/client/zmq_senders.py b/oslo_messaging/_drivers/zmq_driver/client/zmq_senders.py index 909c8689b..f63e1d716 100644 --- a/oslo_messaging/_drivers/zmq_driver/client/zmq_senders.py +++ b/oslo_messaging/_drivers/zmq_driver/client/zmq_senders.py @@ -20,6 +20,7 @@ import six from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names +from oslo_messaging._drivers.zmq_driver import zmq_version LOG = logging.getLogger(__name__) @@ -28,11 +29,18 @@ zmq = zmq_async.import_zmq() @six.add_metaclass(abc.ABCMeta) class SenderBase(object): - """Base request/ack/reply sending interface.""" + """Base request/response sending interface.""" def __init__(self, conf): self.conf = conf self._lock = threading.Lock() + self._send_versions = zmq_version.get_method_versions(self, 'send') + + def _get_send_version(self, version): + send_version = self._send_versions.get(version) + if send_version is None: + raise zmq_version.UnsupportedMessageVersionError(version) + return send_version @abc.abstractmethod def send(self, socket, message): @@ -54,18 +62,24 @@ class ReplySenderBase(SenderBase): class RequestSenderProxy(RequestSenderBase): def send(self, socket, request): + assert request.msg_type in zmq_names.REQUEST_TYPES, "Request expected!" + + send_version = self._get_send_version(request.message_version) + with self._lock: - self._send(socket, request) + send_version(socket, request) LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s message " - "%(msg_id)s to target %(target)s", + "%(msg_id)s to target %(target)s (v%(msg_version)s)", {"addr": list(socket.connections), "msg_type": zmq_names.message_type_str(request.msg_type), "msg_id": request.message_id, - "target": request.target}) + "target": request.target, + "msg_version": request.message_version}) - def _send(self, socket, request): + def _send_v_1_0(self, socket, request): socket.send(b'', zmq.SNDMORE) + socket.send_string('1.0', zmq.SNDMORE) socket.send(six.b(str(request.msg_type)), zmq.SNDMORE) socket.send(request.routing_key, zmq.SNDMORE) socket.send_string(request.message_id, zmq.SNDMORE) @@ -77,16 +91,21 @@ class AckSenderProxy(AckSenderBase): def send(self, socket, ack): assert ack.msg_type == zmq_names.ACK_TYPE, "Ack expected!" - with self._lock: - self._send(socket, ack) + send_version = self._get_send_version(ack.message_version) - LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s for %(msg_id)s", + with self._lock: + send_version(socket, ack) + + LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s for %(msg_id)s " + "(v%(msg_version)s)", {"addr": list(socket.connections), "msg_type": zmq_names.message_type_str(ack.msg_type), - "msg_id": ack.message_id}) + "msg_id": ack.message_id, + "msg_version": ack.message_version}) - def _send(self, socket, ack): + def _send_v_1_0(self, socket, ack): socket.send(b'', zmq.SNDMORE) + socket.send_string('1.0', zmq.SNDMORE) socket.send(six.b(str(ack.msg_type)), zmq.SNDMORE) socket.send(ack.reply_id, zmq.SNDMORE) socket.send_string(ack.message_id) @@ -97,16 +116,21 @@ class ReplySenderProxy(ReplySenderBase): def send(self, socket, reply): assert reply.msg_type == zmq_names.REPLY_TYPE, "Reply expected!" - with self._lock: - self._send(socket, reply) + send_version = self._get_send_version(reply.message_version) - LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s for %(msg_id)s", + with self._lock: + send_version(socket, reply) + + LOG.debug("->[proxy:%(addr)s] Sending %(msg_type)s for %(msg_id)s " + "(v%(msg_version)s)", {"addr": list(socket.connections), "msg_type": zmq_names.message_type_str(reply.msg_type), - "msg_id": reply.message_id}) + "msg_id": reply.message_id, + "msg_version": reply.message_version}) - def _send(self, socket, reply): + def _send_v_1_0(self, socket, reply): socket.send(b'', zmq.SNDMORE) + socket.send_string('1.0', zmq.SNDMORE) socket.send(six.b(str(reply.msg_type)), zmq.SNDMORE) socket.send(reply.reply_id, zmq.SNDMORE) socket.send_string(reply.message_id, zmq.SNDMORE) @@ -116,17 +140,23 @@ class ReplySenderProxy(ReplySenderBase): class RequestSenderDirect(RequestSenderBase): def send(self, socket, request): + assert request.msg_type in zmq_names.REQUEST_TYPES, "Request expected!" + + send_version = self._get_send_version(request.message_version) + with self._lock: - self._send(socket, request) + send_version(socket, request) LOG.debug("Sending %(msg_type)s message %(msg_id)s to " - "target %(target)s", + "target %(target)s (v%(msg_version)s)", {"msg_type": zmq_names.message_type_str(request.msg_type), "msg_id": request.message_id, - "target": request.target}) + "target": request.target, + "msg_version": request.message_version}) - def _send(self, socket, request): + def _send_v_1_0(self, socket, request): socket.send(b'', zmq.SNDMORE) + socket.send_string('1.0', zmq.SNDMORE) socket.send(six.b(str(request.msg_type)), zmq.SNDMORE) socket.send_string(request.message_id, zmq.SNDMORE) socket.send_dumped([request.context, request.message]) @@ -137,14 +167,17 @@ class AckSenderDirect(AckSenderBase): def send(self, socket, ack): assert ack.msg_type == zmq_names.ACK_TYPE, "Ack expected!" + send_version = self._get_send_version(ack.message_version) + with self._lock: - self._send(socket, ack) + send_version(socket, ack) - LOG.debug("Sending %(msg_type)s for %(msg_id)s", + LOG.debug("Sending %(msg_type)s for %(msg_id)s (v%(msg_version)s)", {"msg_type": zmq_names.message_type_str(ack.msg_type), - "msg_id": ack.message_id}) + "msg_id": ack.message_id, + "msg_version": ack.message_version}) - def _send(self, socket, ack): + def _send_v_1_0(self, socket, ack): raise NotImplementedError() @@ -153,16 +186,20 @@ class ReplySenderDirect(ReplySenderBase): def send(self, socket, reply): assert reply.msg_type == zmq_names.REPLY_TYPE, "Reply expected!" + send_version = self._get_send_version(reply.message_version) + with self._lock: - self._send(socket, reply) + send_version(socket, reply) - LOG.debug("Sending %(msg_type)s for %(msg_id)s", + LOG.debug("Sending %(msg_type)s for %(msg_id)s (v%(msg_version)s)", {"msg_type": zmq_names.message_type_str(reply.msg_type), - "msg_id": reply.message_id}) + "msg_id": reply.message_id, + "msg_version": reply.message_version}) - def _send(self, socket, reply): + def _send_v_1_0(self, socket, reply): socket.send(reply.reply_id, zmq.SNDMORE) socket.send(b'', zmq.SNDMORE) + socket.send_string('1.0', zmq.SNDMORE) socket.send(six.b(str(reply.msg_type)), zmq.SNDMORE) socket.send_string(reply.message_id, zmq.SNDMORE) socket.send_dumped([reply.reply_body, reply.failure]) diff --git a/oslo_messaging/_drivers/zmq_driver/proxy/zmq_base_proxy.py b/oslo_messaging/_drivers/zmq_driver/proxy/zmq_base_proxy.py index 4bfe521a5..a1600b5d8 100644 --- a/oslo_messaging/_drivers/zmq_driver/proxy/zmq_base_proxy.py +++ b/oslo_messaging/_drivers/zmq_driver/proxy/zmq_base_proxy.py @@ -28,12 +28,16 @@ zmq = zmq_async.import_zmq() def check_message_format(func): - def _check_message_format(*args, **kwargs): + def _check_message_format(socket): try: - return func(*args, **kwargs) + return func(socket) except Exception as e: - LOG.error(_LE("Received message with wrong format")) - LOG.exception(e) + LOG.error(_LE("Received message with wrong format: %r. " + "Dropping invalid message"), e) + # NOTE(gdavoian): drop the left parts of a broken message, since + # they most likely will break the order of next messages' parts + if socket.getsockopt(zmq.RCVMORE): + socket.recv_multipart() return _check_message_format @@ -65,11 +69,12 @@ class ProxyBase(object): @check_message_format def _receive_message(socket): message = socket.recv_multipart() - assert len(message) > zmq_names.MESSAGE_ID_IDX, "Not enough parts" - assert message[zmq_names.REPLY_ID_IDX] != b'', "Valid id expected" + assert message[zmq_names.EMPTY_IDX] == b'', "Empty delimiter expected!" message_type = int(message[zmq_names.MESSAGE_TYPE_IDX]) - assert message_type in zmq_names.MESSAGE_TYPES, "Known type expected!" - assert message[zmq_names.EMPTY_IDX] == b'', "Empty delimiter expected" + assert message_type in zmq_names.MESSAGE_TYPES, \ + "Known message type expected!" + assert len(message) > zmq_names.MESSAGE_ID_IDX, \ + "At least %d parts expected!" % (zmq_names.MESSAGE_ID_IDX + 1) return message def cleanup(self): diff --git a/oslo_messaging/_drivers/zmq_driver/proxy/zmq_sender.py b/oslo_messaging/_drivers/zmq_driver/proxy/zmq_sender.py index 8318d3b70..0d1952ac3 100644 --- a/oslo_messaging/_drivers/zmq_driver/proxy/zmq_sender.py +++ b/oslo_messaging/_drivers/zmq_driver/proxy/zmq_sender.py @@ -19,6 +19,8 @@ import six from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names +from oslo_messaging._drivers.zmq_driver import zmq_version +from oslo_messaging._i18n import _LW LOG = logging.getLogger(__name__) @@ -30,79 +32,116 @@ class Sender(object): @abc.abstractmethod def send_message(self, socket, multipart_message): - """Send message to a socket from multipart list""" + """Send message to a socket from a multipart list.""" -class CentralRouterSender(Sender): +class CentralSender(Sender): + + def __init__(self): + self._send_message_versions = \ + zmq_version.get_method_versions(self, 'send_message') def send_message(self, socket, multipart_message): + message_version = multipart_message[zmq_names.MESSAGE_VERSION_IDX] + if six.PY3: + message_version = message_version.decode('utf-8') + + send_message_version = self._send_message_versions.get(message_version) + if send_message_version is None: + LOG.warning(_LW("Dropping message with unsupported version %s"), + message_version) + return + + send_message_version(socket, multipart_message) + + +class LocalSender(Sender): + pass + + +class CentralRouterSender(CentralSender): + + def _send_message_v_1_0(self, socket, multipart_message): message_type = int(multipart_message[zmq_names.MESSAGE_TYPE_IDX]) routing_key = multipart_message[zmq_names.ROUTING_KEY_IDX] reply_id = multipart_message[zmq_names.REPLY_ID_IDX] message_id = multipart_message[zmq_names.MESSAGE_ID_IDX] + message_version = multipart_message[zmq_names.MESSAGE_VERSION_IDX] socket.send(routing_key, zmq.SNDMORE) socket.send(b'', zmq.SNDMORE) + socket.send(message_version, zmq.SNDMORE) socket.send(reply_id, zmq.SNDMORE) socket.send(multipart_message[zmq_names.MESSAGE_TYPE_IDX], zmq.SNDMORE) socket.send_multipart(multipart_message[zmq_names.MESSAGE_ID_IDX:]) LOG.debug("Dispatching %(msg_type)s message %(msg_id)s - from %(rid)s " - "-> to %(rkey)s", + "-> to %(rkey)s (v%(msg_version)s)", {"msg_type": zmq_names.message_type_str(message_type), "msg_id": message_id, "rkey": routing_key, - "rid": reply_id}) + "rid": reply_id, + "msg_version": message_version}) -class CentralAckSender(Sender): +class CentralAckSender(CentralSender): - def send_message(self, socket, multipart_message): + def _send_message_v_1_0(self, socket, multipart_message): message_type = zmq_names.ACK_TYPE message_id = multipart_message[zmq_names.MESSAGE_ID_IDX] routing_key = socket.handle.identity reply_id = multipart_message[zmq_names.REPLY_ID_IDX] + message_version = multipart_message[zmq_names.MESSAGE_VERSION_IDX] socket.send(reply_id, zmq.SNDMORE) socket.send(b'', zmq.SNDMORE) + socket.send(message_version, zmq.SNDMORE) socket.send(routing_key, zmq.SNDMORE) socket.send(six.b(str(message_type)), zmq.SNDMORE) socket.send_string(message_id) LOG.debug("Sending %(msg_type)s for %(msg_id)s to %(rid)s " - "[from %(rkey)s]", + "[from %(rkey)s] (v%(msg_version)s)", {"msg_type": zmq_names.message_type_str(message_type), "msg_id": message_id, "rid": reply_id, - "rkey": routing_key}) + "rkey": routing_key, + "msg_version": message_version}) -class CentralPublisherSender(Sender): +class CentralPublisherSender(CentralSender): - def send_message(self, socket, multipart_message): + def _send_message_v_1_0(self, socket, multipart_message): message_type = int(multipart_message[zmq_names.MESSAGE_TYPE_IDX]) assert message_type in zmq_names.MULTISEND_TYPES, "Fanout expected!" topic_filter = multipart_message[zmq_names.ROUTING_KEY_IDX] message_id = multipart_message[zmq_names.MESSAGE_ID_IDX] + message_version = multipart_message[zmq_names.MESSAGE_VERSION_IDX] socket.send(topic_filter, zmq.SNDMORE) + socket.send(message_version, zmq.SNDMORE) socket.send(six.b(str(message_type)), zmq.SNDMORE) socket.send_multipart(multipart_message[zmq_names.MESSAGE_ID_IDX:]) - LOG.debug("Publishing message %(message_id)s on [%(topic)s]", + LOG.debug("Publishing message %(msg_id)s on [%(topic)s] " + "(v%(msg_version)s)", {"topic": topic_filter, - "message_id": message_id}) + "msg_id": message_id, + "msg_version": message_version}) -class LocalPublisherSender(Sender): +class LocalPublisherSender(LocalSender): TOPIC_IDX = 0 - MSG_TYPE_IDX = 1 - MSG_ID_IDX = 2 + MSG_VERSION_IDX = 1 + MSG_TYPE_IDX = 2 + MSG_ID_IDX = 3 def send_message(self, socket, multipart_message): socket.send_multipart(multipart_message) - LOG.debug("Publishing message %(message_id)s on [%(topic)s]", + LOG.debug("Publishing message %(msg_id)s on [%(topic)s] " + "(v%(msg_version)s)", {"topic": multipart_message[self.TOPIC_IDX], - "message_id": multipart_message[self.MSG_ID_IDX]}) + "msg_id": multipart_message[self.MSG_ID_IDX], + "msg_version": multipart_message[self.MSG_VERSION_IDX]}) diff --git a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_consumer_base.py b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_consumer_base.py index 0b9fe6200..b2e69fca4 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_consumer_base.py +++ b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_consumer_base.py @@ -42,11 +42,11 @@ class ConsumerBase(object): self.context = zmq.Context() def stop(self): - """Stop consumer polling/updates""" + """Stop consumer polling/updating.""" @abc.abstractmethod - def receive_message(self, target): - """Method for poller - receiving message routine""" + def receive_request(self, socket): + """Receive a request via a socket.""" def cleanup(self): for socket in self.sockets: @@ -81,7 +81,7 @@ class SingleSocketConsumer(ConsumerBase): "port": socket.port}) self.host = zmq_address.combine_address( self.conf.oslo_messaging_zmq.rpc_zmq_host, socket.port) - self.poller.register(socket, self.receive_message) + self.poller.register(socket, self.receive_request) return socket except zmq.ZMQError as e: errmsg = _LE("Failed binding to port %(port)d: %(e)s")\ diff --git a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py index 57c51ca02..52fea538e 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py +++ b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_dealer_consumer.py @@ -29,6 +29,7 @@ from oslo_messaging._drivers.zmq_driver import zmq_address from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names from oslo_messaging._drivers.zmq_driver import zmq_updater +from oslo_messaging._drivers.zmq_driver import zmq_version from oslo_messaging._i18n import _LE, _LI, _LW LOG = logging.getLogger(__name__) @@ -44,6 +45,8 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): conf, server.matchmaker, zmq.DEALER) self.host = None super(DealerConsumer, self).__init__(conf, poller, server, zmq.DEALER) + self._receive_request_versions = \ + zmq_version.get_method_versions(self, 'receive_request') self.connection_updater = ConsumerConnectionUpdater( conf, self.matchmaker, self.socket) LOG.info(_LI("[%s] Run DEALER consumer"), self.host) @@ -59,7 +62,7 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): self._generate_identity()) self.sockets.append(socket) self.host = socket.handle.identity - self.poller.register(socket, self.receive_message) + self.poller.register(socket, self.receive_request) return socket except zmq.ZMQError as e: LOG.error(_LE("Failed connecting to ROUTER socket %(e)s") % e) @@ -70,48 +73,66 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): failure = rpc_common.serialize_remote_exception(failure) reply = zmq_response.Reply(message_id=rpc_message.message_id, reply_id=rpc_message.reply_id, + message_version=rpc_message.message_version, reply_body=reply, failure=failure) self.reply_sender.send(rpc_message.socket, reply) return reply - def _create_message(self, context, message, reply_id, message_id, socket, - message_type): + def _create_message(self, context, message, message_version, reply_id, + message_id, socket, message_type): if message_type == zmq_names.CALL_TYPE: message = zmq_incoming_message.ZmqIncomingMessage( - context, message, reply_id=reply_id, message_id=message_id, + context, message, message_version=message_version, + reply_id=reply_id, message_id=message_id, socket=socket, reply_method=self._reply ) else: message = zmq_incoming_message.ZmqIncomingMessage(context, message) - LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s", + LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s " + "(v%(msg_version)s)", {"host": self.host, "msg_type": zmq_names.message_type_str(message_type), - "msg_id": message_id}) + "msg_id": message_id, + "msg_version": message_version}) return message - def receive_message(self, socket): + def _get_receive_request_version(self, version): + receive_request_version = self._receive_request_versions.get(version) + if receive_request_version is None: + raise zmq_version.UnsupportedMessageVersionError(version) + return receive_request_version + + def receive_request(self, socket): try: empty = socket.recv() assert empty == b'', "Empty delimiter expected!" - reply_id = socket.recv() - assert reply_id != b'', "Valid reply id expected!" - message_type = int(socket.recv()) - assert message_type in zmq_names.REQUEST_TYPES, \ - "Request message type expected!" - message_id = socket.recv_string() - assert message_id != '', "Valid message id expected!" - context, message = socket.recv_loaded() + message_version = socket.recv_string() + assert message_version != b'', "Valid message version expected!" - return self._create_message(context, message, reply_id, - message_id, socket, message_type) - except (zmq.ZMQError, AssertionError, ValueError) as e: + receive_request_version = \ + self._get_receive_request_version(message_version) + return receive_request_version(socket) + except (zmq.ZMQError, AssertionError, ValueError, + zmq_version.UnsupportedMessageVersionError) as e: LOG.error(_LE("Receiving message failure: %s"), str(e)) # NOTE(gdavoian): drop the left parts of a broken message if socket.getsockopt(zmq.RCVMORE): socket.recv_multipart() + def _receive_request_v_1_0(self, socket): + reply_id = socket.recv() + assert reply_id != b'', "Valid reply id expected!" + message_type = int(socket.recv()) + assert message_type in zmq_names.REQUEST_TYPES, "Request expected!" + message_id = socket.recv_string() + assert message_id != '', "Valid message id expected!" + context, message = socket.recv_loaded() + + return self._create_message(context, message, '1.0', reply_id, + message_id, socket, message_type) + def cleanup(self): LOG.info(_LI("[%s] Destroy DEALER consumer"), self.host) self.connection_updater.cleanup() @@ -127,9 +148,10 @@ class DealerConsumerWithAcks(DealerConsumer): ttl=conf.oslo_messaging_zmq.rpc_message_ttl ) - def _acknowledge(self, reply_id, message_id, socket): + def _acknowledge(self, message_version, reply_id, message_id, socket): ack = zmq_response.Ack(message_id=message_id, - reply_id=reply_id) + reply_id=reply_id, + message_version=message_version) self.ack_sender.send(socket, ack) def _reply(self, rpc_message, reply, failure): @@ -143,8 +165,8 @@ class DealerConsumerWithAcks(DealerConsumer): if reply is not None: self.reply_sender.send(socket, reply) - def _create_message(self, context, message, reply_id, message_id, socket, - message_type): + def _create_message(self, context, message, message_version, reply_id, + message_id, socket, message_type): # drop a duplicate message if message_id in self.messages_cache: LOG.warning( @@ -159,7 +181,8 @@ class DealerConsumerWithAcks(DealerConsumer): # for the CALL message also try to resend its reply # (of course, if it was already obtained and cached). if message_type in zmq_names.DIRECT_TYPES: - self._acknowledge(reply_id, message_id, socket) + self._acknowledge(message_version, reply_id, message_id, + socket) if message_type == zmq_names.CALL_TYPE: self._reply_from_cache(message_id, socket) return None @@ -170,10 +193,11 @@ class DealerConsumerWithAcks(DealerConsumer): # be too late to wait until the message will be # dispatched and processed by a RPC server if message_type in zmq_names.DIRECT_TYPES: - self._acknowledge(reply_id, message_id, socket) + self._acknowledge(message_version, reply_id, message_id, socket) return super(DealerConsumerWithAcks, self)._create_message( - context, message, reply_id, message_id, socket, message_type + context, message, message_version, reply_id, + message_id, socket, message_type ) def cleanup(self): diff --git a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_router_consumer.py b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_router_consumer.py index fe09bc99a..3395f04dd 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_router_consumer.py +++ b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_router_consumer.py @@ -22,6 +22,7 @@ from oslo_messaging._drivers.zmq_driver.server.consumers \ from oslo_messaging._drivers.zmq_driver.server import zmq_incoming_message from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names +from oslo_messaging._drivers.zmq_driver import zmq_version from oslo_messaging._i18n import _LE, _LI LOG = logging.getLogger(__name__) @@ -34,6 +35,8 @@ class RouterConsumer(zmq_consumer_base.SingleSocketConsumer): def __init__(self, conf, poller, server): self.reply_sender = zmq_senders.ReplySenderDirect(conf) super(RouterConsumer, self).__init__(conf, poller, server, zmq.ROUTER) + self._receive_request_versions = \ + zmq_version.get_method_versions(self, 'receive_request') LOG.info(_LI("[%s] Run ROUTER consumer"), self.host) def _reply(self, rpc_message, reply, failure): @@ -41,48 +44,66 @@ class RouterConsumer(zmq_consumer_base.SingleSocketConsumer): failure = rpc_common.serialize_remote_exception(failure) reply = zmq_response.Reply(message_id=rpc_message.message_id, reply_id=rpc_message.reply_id, + message_version=rpc_message.message_version, reply_body=reply, failure=failure) self.reply_sender.send(rpc_message.socket, reply) return reply - def _create_message(self, context, message, reply_id, message_id, socket, - message_type): + def _create_message(self, context, message, message_version, reply_id, + message_id, socket, message_type): if message_type == zmq_names.CALL_TYPE: message = zmq_incoming_message.ZmqIncomingMessage( - context, message, reply_id=reply_id, message_id=message_id, + context, message, message_version=message_version, + reply_id=reply_id, message_id=message_id, socket=socket, reply_method=self._reply ) else: message = zmq_incoming_message.ZmqIncomingMessage(context, message) - LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s", + LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s " + "(v%(msg_version)s)", {"host": self.host, "msg_type": zmq_names.message_type_str(message_type), - "msg_id": message_id}) + "msg_id": message_id, + "msg_version": message_version}) return message - def receive_message(self, socket): + def _get_receive_request_version(self, version): + receive_request_version = self._receive_request_versions.get(version) + if receive_request_version is None: + raise zmq_version.UnsupportedMessageVersionError(version) + return receive_request_version + + def receive_request(self, socket): try: reply_id = socket.recv() assert reply_id != b'', "Valid reply id expected!" empty = socket.recv() assert empty == b'', "Empty delimiter expected!" - message_type = int(socket.recv()) - assert message_type in zmq_names.REQUEST_TYPES, \ - "Request message type expected!" - message_id = socket.recv_string() - assert message_id != '', "Valid message id expected!" - context, message = socket.recv_loaded() + message_version = socket.recv_string() + assert message_version != b'', "Valid message version expected!" - return self._create_message(context, message, reply_id, - message_id, socket, message_type) - except (zmq.ZMQError, AssertionError, ValueError) as e: + receive_request_version = \ + self._get_receive_request_version(message_version) + return receive_request_version(reply_id, socket) + except (zmq.ZMQError, AssertionError, ValueError, + zmq_version.UnsupportedMessageVersionError) as e: LOG.error(_LE("Receiving message failed: %s"), str(e)) # NOTE(gdavoian): drop the left parts of a broken message if socket.getsockopt(zmq.RCVMORE): socket.recv_multipart() + def _receive_request_v_1_0(self, reply_id, socket): + message_type = int(socket.recv()) + assert message_type in zmq_names.REQUEST_TYPES, "Request expected!" + message_id = socket.recv_string() + assert message_id != '', "Valid message id expected!" + context, message = socket.recv_loaded() + + return self._create_message(context, message, '1.0', reply_id, + message_id, socket, message_type) + def cleanup(self): LOG.info(_LI("[%s] Destroy ROUTER consumer"), self.host) super(RouterConsumer, self).cleanup() diff --git a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py index c843d3c44..2d53489f7 100644 --- a/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py +++ b/oslo_messaging/_drivers/zmq_driver/server/consumers/zmq_sub_consumer.py @@ -25,6 +25,7 @@ from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names from oslo_messaging._drivers.zmq_driver import zmq_socket from oslo_messaging._drivers.zmq_driver import zmq_updater +from oslo_messaging._drivers.zmq_driver import zmq_version from oslo_messaging._i18n import _LE, _LI LOG = logging.getLogger(__name__) @@ -45,9 +46,11 @@ class SubConsumer(zmq_consumer_base.ConsumerBase): self.sockets.append(self.socket) self.host = self.socket.handle.identity self._subscribe_to_topic() + self._receive_request_versions = \ + zmq_version.get_method_versions(self, 'receive_request') self.connection_updater = SubscriberConnectionUpdater( conf, self.matchmaker, self.socket) - self.poller.register(self.socket, self.receive_message) + self.poller.register(self.socket, self.receive_request) LOG.info(_LI("[%s] Run SUB consumer"), self.host) def _generate_identity(self): @@ -61,27 +64,39 @@ class SubConsumer(zmq_consumer_base.ConsumerBase): LOG.debug("[%(host)s] Subscribing to topic %(filter)s", {"host": self.host, "filter": topic_filter}) - def _receive_request(self, socket): - topic_filter = socket.recv() + def _get_receive_request_version(self, version): + receive_request_version = self._receive_request_versions.get(version) + if receive_request_version is None: + raise zmq_version.UnsupportedMessageVersionError(version) + return receive_request_version + + def _receive_request_v_1_0(self, topic_filter, socket): message_type = int(socket.recv()) + assert message_type in zmq_names.MULTISEND_TYPES, "Fanout expected!" message_id = socket.recv() context, message = socket.recv_loaded() LOG.debug("[%(host)s] Received on topic %(filter)s message %(msg_id)s " - "%(msg_type)s", + "(v%(msg_version)s)", {'host': self.host, 'filter': topic_filter, 'msg_id': message_id, - 'msg_type': zmq_names.message_type_str(message_type)}) + 'msg_version': '1.0'}) return context, message - def receive_message(self, socket): + def receive_request(self, socket): try: - context, message = self._receive_request(socket) - if not message: - return None + topic_filter = socket.recv() + message_version = socket.recv_string() + receive_request_version = \ + self._get_receive_request_version(message_version) + context, message = receive_request_version(topic_filter, socket) return zmq_incoming_message.ZmqIncomingMessage(context, message) - except (zmq.ZMQError, AssertionError) as e: + except (zmq.ZMQError, AssertionError, ValueError, + zmq_version.UnsupportedMessageVersionError) as e: LOG.error(_LE("Receiving message failed: %s"), str(e)) + # NOTE(gdavoian): drop the left parts of a broken message + if socket.getsockopt(zmq.RCVMORE): + socket.recv_multipart() def cleanup(self): LOG.info(_LI("[%s] Destroy SUB consumer"), self.host) diff --git a/oslo_messaging/_drivers/zmq_driver/zmq_names.py b/oslo_messaging/_drivers/zmq_driver/zmq_names.py index 31b208031..3e345df90 100644 --- a/oslo_messaging/_drivers/zmq_driver/zmq_names.py +++ b/oslo_messaging/_drivers/zmq_driver/zmq_names.py @@ -19,15 +19,17 @@ zmq = zmq_async.import_zmq() FIELD_MSG_ID = 'message_id' FIELD_REPLY_ID = 'reply_id' +FIELD_MSG_VERSION = 'message_version' FIELD_REPLY_BODY = 'reply_body' FIELD_FAILURE = 'failure' REPLY_ID_IDX = 0 EMPTY_IDX = 1 -MESSAGE_TYPE_IDX = 2 -ROUTING_KEY_IDX = 3 -MESSAGE_ID_IDX = 4 +MESSAGE_VERSION_IDX = 2 +MESSAGE_TYPE_IDX = 3 +ROUTING_KEY_IDX = 4 +MESSAGE_ID_IDX = 5 DEFAULT_TYPE = 0 diff --git a/oslo_messaging/_drivers/zmq_driver/zmq_version.py b/oslo_messaging/_drivers/zmq_driver/zmq_version.py new file mode 100644 index 000000000..92baf7174 --- /dev/null +++ b/oslo_messaging/_drivers/zmq_driver/zmq_version.py @@ -0,0 +1,60 @@ +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import re + +from oslo_messaging._drivers import common as rpc_common +from oslo_messaging._i18n import _ + + +# current driver's version for representing internal message format +MESSAGE_VERSION = '1.0' + + +class UnsupportedMessageVersionError(rpc_common.RPCException): + msg_fmt = _("Message version %(version)s is not supported.") + + def __init__(self, version): + super(UnsupportedMessageVersionError, self).__init__(version=version) + + +def get_method_versions(obj, method_name): + """Useful function for initializing versioned senders/receivers. + + Returns a dictionary of different internal versions of the given method. + + Assumes that the object has the particular versioned method and this method + is public. Thus versions are private implementations of the method. + + For example, for a method 'func' methods '_func_v_1_0', '_func_v_1_5', + '_func_v_2_0', etc. are assumed as its respective 1.0, 1.5, 2.0 versions. + """ + + assert callable(getattr(obj, method_name, None)), \ + "Object must have specified method!" + assert not method_name.startswith('_'), "Method must be public!" + + method_versions = {} + for attr_name in dir(obj): + if attr_name == method_name: + continue + attr = getattr(obj, attr_name, None) + if not callable(attr): + continue + match_obj = re.match(r'^_%s_v_(\d)_(\d)$' % method_name, attr_name) + if match_obj is not None: + version = '.'.join([match_obj.group(1), match_obj.group(2)]) + method_versions[version] = attr + + return method_versions diff --git a/oslo_messaging/tests/drivers/zmq/test_pub_sub.py b/oslo_messaging/tests/drivers/zmq/test_pub_sub.py index da21c604b..ec302c6c7 100644 --- a/oslo_messaging/tests/drivers/zmq/test_pub_sub.py +++ b/oslo_messaging/tests/drivers/zmq/test_pub_sub.py @@ -1,4 +1,4 @@ -# Copyright 2015 Mirantis, Inc. +# Copyright 2015-2016 Mirantis, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -28,6 +28,7 @@ from oslo_messaging._drivers.zmq_driver.proxy import zmq_proxy from oslo_messaging._drivers.zmq_driver import zmq_address from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_names +from oslo_messaging._drivers.zmq_driver import zmq_version from oslo_messaging.tests.drivers.zmq import zmq_common load_tests = testscenarios.load_tests_apply_scenarios @@ -58,15 +59,14 @@ class TestPubSub(zmq_common.ZmqBaseTestCase): self.config(group='oslo_messaging_zmq', **kwargs) self.config(host="127.0.0.1", group="zmq_proxy_opts") - self.config(publisher_port="0", group="zmq_proxy_opts") + self.config(publisher_port=0, group="zmq_proxy_opts") self.publisher = zmq_publisher_proxy.PublisherProxy( self.conf, self.driver.matchmaker) - self.driver.matchmaker.register_publisher( - (self.publisher.host, "")) + self.driver.matchmaker.register_publisher((self.publisher.host, '')) self.listeners = [] - for i in range(self.LISTENERS_COUNT): + for _ in range(self.LISTENERS_COUNT): self.listeners.append(zmq_common.TestServerListener(self.driver)) def tearDown(self): @@ -83,10 +83,14 @@ class TestPubSub(zmq_common.ZmqBaseTestCase): message = {'method': 'hello-world'} self.publisher.send_request( - [b'', b'', zmq_names.CAST_FANOUT_TYPE, + [b"reply_id", + b'', + six.b(zmq_version.MESSAGE_VERSION), + six.b(str(zmq_names.CAST_FANOUT_TYPE)), zmq_address.target_to_subscribe_filter(target), - b"0000-0000", - self.dumps([context, message])]) + b"message_id", + self.dumps([context, message])] + ) def _check_listener(self, listener): listener._received.wait(timeout=5) diff --git a/oslo_messaging/tests/drivers/zmq/test_zmq_version.py b/oslo_messaging/tests/drivers/zmq/test_zmq_version.py new file mode 100644 index 000000000..9b0189403 --- /dev/null +++ b/oslo_messaging/tests/drivers/zmq/test_zmq_version.py @@ -0,0 +1,63 @@ +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_messaging._drivers.zmq_driver import zmq_version +from oslo_messaging.tests import utils as test_utils + + +class Doer(object): + + def __init__(self): + self.x = 1 + self.y = 2 + self.z = 3 + + def _sudo(self): + pass + + def do(self): + pass + + def _do_v_1_1(self): + pass + + def _do_v_2_2(self): + pass + + def _do_v_3_3(self): + pass + + +class TestZmqVersion(test_utils.BaseTestCase): + + def setUp(self): + super(TestZmqVersion, self).setUp() + self.doer = Doer() + + def test_get_unknown_attr_versions(self): + self.assertRaises(AssertionError, zmq_version.get_method_versions, + self.doer, 'qwerty') + + def test_get_non_method_attr_versions(self): + for attr_name in vars(self.doer): + self.assertRaises(AssertionError, zmq_version.get_method_versions, + self.doer, attr_name) + + def test_get_private_method_versions(self): + self.assertRaises(AssertionError, zmq_version.get_method_versions, + self.doer, '_sudo') + + def test_get_public_method_versions(self): + do_versions = zmq_version.get_method_versions(self.doer, 'do') + self.assertEqual(['1.1', '2.2', '3.3'], sorted(do_versions.keys()))