107 lines
3.7 KiB
Python
107 lines
3.7 KiB
Python
# Copyright 2016 Mirantis, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import abc
|
|
from concurrent import futures
|
|
import logging
|
|
|
|
import oslo_messaging
|
|
from oslo_messaging._drivers import common as rpc_common
|
|
from oslo_messaging._drivers.zmq_driver.client.publishers \
|
|
import zmq_publisher_base
|
|
from oslo_messaging._drivers.zmq_driver.client import zmq_response
|
|
from oslo_messaging._drivers.zmq_driver.client import zmq_sockets_manager
|
|
from oslo_messaging._drivers.zmq_driver import zmq_async
|
|
from oslo_messaging._drivers.zmq_driver import zmq_names
|
|
from oslo_messaging._i18n import _LE
|
|
|
|
# Module-level logger; used by DealerPublisherBase._recv_reply to report
# malformed replies.
LOG = logging.getLogger(__name__)

# zmq module handle obtained through zmq_async; DealerPublisherBase reads the
# socket-type constants zmq.ROUTER and zmq.DEALER from it.
# NOTE(review): which concrete zmq implementation import_zmq() returns is
# decided inside zmq_async (not visible here) -- confirm there.
zmq = zmq_async.import_zmq()
|
|
|
|
|
|
class DealerPublisherBase(zmq_publisher_base.PublisherBase):
    """Abstract DEALER-publisher.

    Creates a ``SocketsManager`` configured with ``zmq.ROUTER`` and
    ``zmq.DEALER`` and implements the blocking CALL pattern on top of it.
    Subclasses must implement :meth:`_send_non_blocking`, which is used to
    deliver CAST, fanout CAST and NOTIFY messages.
    """

    def __init__(self, conf, matchmaker, sender, receiver):
        # NOTE(review): the meaning/order of the two socket-type arguments is
        # defined by SocketsManager (not visible here) -- presumably the
        # peer's socket type followed by this side's type; confirm.
        sockets_manager = zmq_sockets_manager.SocketsManager(
            conf, matchmaker, zmq.ROUTER, zmq.DEALER
        )
        super(DealerPublisherBase, self).__init__(sockets_manager, sender,
                                                  receiver)

    @staticmethod
    def _check_pattern(request, supported_pattern):
        """Ensure ``request.msg_type`` equals ``supported_pattern``.

        :param request: outgoing request; only ``msg_type`` is inspected.
        :param supported_pattern: the message type the caller accepts.
        :raises zmq_publisher_base.UnsupportedSendPattern: on mismatch; the
            message is the human-readable name of ``request.msg_type``.
        """
        if request.msg_type != supported_pattern:
            raise zmq_publisher_base.UnsupportedSendPattern(
                zmq_names.message_type_str(request.msg_type)
            )

    @staticmethod
    def _raise_timeout(request):
        """Unconditionally raise a timeout error for ``request``.

        :raises oslo_messaging.MessagingTimeout: always, mentioning
            ``request.timeout`` and ``request.message_id``.
        """
        raise oslo_messaging.MessagingTimeout(
            "Timeout %(tout)s seconds was reached for message %(msg_id)s" %
            {"tout": request.timeout, "msg_id": request.message_id}
        )

    def _recv_reply(self, request):
        """Wait for and unpack the reply to a CALL ``request``.

        The request is registered with the receiver before waiting and is
        always unregistered afterwards, whatever the outcome.

        :param request: the CALL request that was just sent.
        :returns: ``reply.reply_body`` on success, or ``None`` if the object
            received is not a :class:`zmq_response.Reply` (an error is logged).
        :raises oslo_messaging.MessagingTimeout: if no reply arrives within
            ``request.timeout`` seconds.
        :raises Exception: the exception rebuilt by
            ``rpc_common.deserialize_remote_exception`` from ``reply.failure``
            (restricted to ``request.allowed_remote_exmods``) when the remote
            side reported a failure.
        """
        reply_future = \
            self.receiver.track_request(request)[zmq_names.REPLY_TYPE]

        try:
            _, reply = reply_future.result(timeout=request.timeout)
            # Explicit check instead of ``assert``: assertions are stripped
            # under ``python -O``, which would let a malformed reply through
            # and fail later with AttributeError on ``reply.failure``.
            if not isinstance(reply, zmq_response.Reply):
                LOG.error(_LE("Message format error in reply for %s"),
                          request.message_id)
                return None
        except futures.TimeoutError:
            # _raise_timeout always raises, so ``reply`` is bound below.
            self._raise_timeout(request)
        finally:
            self.receiver.untrack_request(request)

        if reply.failure:
            raise rpc_common.deserialize_remote_exception(
                reply.failure, request.allowed_remote_exmods
            )
        return reply.reply_body

    def send_call(self, request):
        """Send a CALL ``request`` and block until its reply is received.

        :returns: the reply body (see :meth:`_recv_reply`).
        :raises zmq_publisher_base.UnsupportedSendPattern: if ``request`` is
            not a CALL.
        :raises oslo_messaging.MessagingTimeout: if no socket could be
            obtained or no reply arrived in time.
        """
        self._check_pattern(request, zmq_names.CALL_TYPE)

        socket = self.connect_socket(request)
        # A falsy socket means no connection could be obtained; this is
        # reported to the caller as a timeout.
        if not socket:
            self._raise_timeout(request)

        self.sender.send(socket, request)
        # Let the receiver listen on this socket so the reply can arrive.
        self.receiver.register_socket(socket)
        return self._recv_reply(request)

    @abc.abstractmethod
    def _send_non_blocking(self, request):
        """Deliver ``request`` without waiting for a reply (subclass hook)."""
        pass

    def send_cast(self, request):
        """Send a CAST ``request`` without waiting for a reply.

        :raises zmq_publisher_base.UnsupportedSendPattern: if ``request`` is
            not a CAST.
        """
        self._check_pattern(request, zmq_names.CAST_TYPE)
        self._send_non_blocking(request)

    def send_fanout(self, request):
        """Send a fanout CAST ``request`` without waiting for a reply.

        :raises zmq_publisher_base.UnsupportedSendPattern: if ``request`` is
            not a fanout CAST.
        """
        self._check_pattern(request, zmq_names.CAST_FANOUT_TYPE)
        self._send_non_blocking(request)

    def send_notify(self, request):
        """Send a NOTIFY ``request`` without waiting for a reply.

        :raises zmq_publisher_base.UnsupportedSendPattern: if ``request`` is
            not a NOTIFY.
        """
        self._check_pattern(request, zmq_names.NOTIFY_TYPE)
        self._send_non_blocking(request)
|