# oslo.messaging/oslo_messaging/_drivers/amqpdriver.py
#
# 454 lines
# 16 KiB
# Python

# Copyright 2013 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
__all__ = ['AMQPDriverBase']
import logging
import threading
import uuid
from six import moves
import oslo_messaging
from oslo_messaging._drivers import amqp as rpc_amqp
from oslo_messaging._drivers import base
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._i18n import _
from oslo_messaging._i18n import _LI
LOG = logging.getLogger(__name__)
class AMQPIncomingMessage(base.IncomingMessage):
    """An incoming AMQP message plus what is needed to answer it.

    Keeps the identifiers used for replying (``msg_id``, ``reply_q``), the
    ``unique_id`` recorded on acknowledgement, and the raw message's
    ``acknowledge`` / ``requeue`` callables.
    """

    def __init__(self, listener, ctxt, message, unique_id, msg_id, reply_q):
        # The base class gets a plain dict copy of the payload; the
        # ack/requeue callables are taken from the raw message itself.
        super(AMQPIncomingMessage, self).__init__(listener, ctxt,
                                                  dict(message))

        self.acknowledge_callback = message.acknowledge
        self.requeue_callback = message.requeue
        self.reply_q = reply_q
        self.msg_id = msg_id
        self.unique_id = unique_id

    def _send_reply(self, conn, reply=None, failure=None,
                    ending=False, log_failure=True):
        """Serialize one reply frame and send it over ``conn``.

        :param conn: connection providing ``direct_send``
        :param reply: value stored under the ``'result'`` key
        :param failure: failure information; when truthy it is serialized
                        with ``rpc_common.serialize_remote_exception``
        :param ending: when true, the frame is flagged with ``'ending'``
        :param log_failure: forwarded to the failure serializer
        """
        serialized_failure = failure
        if failure:
            serialized_failure = rpc_common.serialize_remote_exception(
                failure, log_failure)

        payload = {'result': reply, 'failure': serialized_failure}
        if ending:
            payload['ending'] = True
        rpc_amqp._add_unique_id(payload)

        # When the caller supplied a reply queue, the reply carries the
        # msg_id and is sent to that queue.  Otherwise the msg_id itself is
        # the destination, for backward compatibility.
        if self.reply_q:
            payload['_msg_id'] = self.msg_id
            destination = self.reply_q
        else:
            destination = self.msg_id
        conn.direct_send(destination, rpc_common.serialize_msg(payload))

    def reply(self, reply=None, failure=None, log_failure=True):
        """Send the reply (or failure) followed by an ``ending`` frame.

        Nothing is sent when the message has no ``msg_id``.
        """
        # NOTE(Alexei_987): an empty msg_id means the caller does not expect
        # a reply, so none is sent.
        if self.msg_id:
            with self.listener.driver._get_connection() as conn:
                self._send_reply(conn, reply, failure,
                                 log_failure=log_failure)
                self._send_reply(conn, ending=True)

    def acknowledge(self):
        """Record the unique id in the listener's cache, then ack."""
        self.listener.msg_id_cache.add(self.unique_id)
        self.acknowledge_callback()

    def requeue(self):
        """Ask the transport to requeue this message."""
        # NOTE(sileht): if the connection is lost between receiving the
        # message and requeueing it, this call fails.  Because the message
        # has been neither acknowledged nor added to the msg_id_cache, it
        # will be consumed again anyway; the only difference is that it
        # stays at the head of the queue instead of moving to the end.
        self.requeue_callback()
class AMQPListener(base.Listener):
    """Listener that buffers messages delivered on one connection.

    An instance is used as the consumer callback: each delivered message
    is wrapped in an ``AMQPIncomingMessage`` and appended to ``incoming``,
    from which ``poll`` hands messages out in arrival order.
    """

    def __init__(self, driver, conn):
        super(AMQPListener, self).__init__(driver)
        self._stopped = threading.Event()
        self.incoming = []
        self.msg_id_cache = rpc_amqp._MsgIdCache()
        self.conn = conn

    def __call__(self, message):
        """Consumer callback: wrap ``message`` and buffer it."""
        ctxt = rpc_amqp.unpack_context(self.conf, message)

        # FIXME(sileht): the message body is deliberately not logged here
        # until strutils password masking is more efficient
        # (rpc_amqp.unpack_context already logs the context).
        unique_id = self.msg_id_cache.check_duplicate_message(message)

        wrapped = AMQPIncomingMessage(self,
                                      ctxt.to_dict(),
                                      message,
                                      unique_id,
                                      ctxt.msg_id,
                                      ctxt.reply_q)
        self.incoming.append(wrapped)

    def poll(self, timeout=None):
        """Return the next buffered message, consuming if none is buffered.

        :param timeout: passed through to ``conn.consume``
        :returns: the oldest buffered message, or ``None`` when the
                  listener has been stopped or the consume call timed out
        """
        while True:
            if self._stopped.is_set():
                return None
            if self.incoming:
                return self.incoming.pop(0)
            try:
                self.conn.consume(limit=1, timeout=timeout)
            except rpc_common.Timeout:
                return None

    def stop(self):
        """Flag the listener as stopped and stop consuming."""
        self._stopped.set()
        self.conn.stop_consuming()

    def cleanup(self):
        """Close the listener's connection."""
        self.conn.close()
class ReplyWaiters(object):
    """Registry of per-call reply queues, keyed by message ID.

    Each registered queue receives the replies addressed to its message
    ID (via ``put``) and may also receive the ``WAKE_UP`` sentinel (via
    ``wake_all``).
    """

    # Sentinel placed on waiters' queues by wake_all(); consumers compare
    # against it by identity.
    WAKE_UP = object()

    def __init__(self):
        self._queues = {}
        # Number of registered queues above which add() logs a warning;
        # doubled each time the warning fires.
        self._wrn_threshold = 10

    def get(self, msg_id, timeout):
        """Block until an item is available on the queue for ``msg_id``.

        :param msg_id: message ID whose queue is read
        :param timeout: maximum time to block, passed to ``Queue.get``
        :returns: the next item from the queue
        :raises oslo_messaging.MessagingTimeout: if nothing arrives in time
        :raises KeyError: if no queue is registered for ``msg_id``
        """
        try:
            return self._queues[msg_id].get(block=True, timeout=timeout)
        except moves.queue.Empty:
            raise oslo_messaging.MessagingTimeout(
                'Timed out waiting for a reply '
                'to message ID %s' % msg_id)

    def check(self, msg_id):
        """Return the next item queued for ``msg_id`` without blocking.

        :returns: the next item, or ``None`` when the queue is empty
        :raises KeyError: if no queue is registered for ``msg_id``
        """
        try:
            return self._queues[msg_id].get(block=False)
        except moves.queue.Empty:
            return None

    def put(self, msg_id, message_data):
        """Deliver ``message_data`` to the queue registered for ``msg_id``.

        When no queue is registered the item is dropped and the event is
        logged.
        """
        waiter_queue = self._queues.get(msg_id)
        # Compare against None explicitly: "is a queue registered?" must
        # not depend on the truthiness of the queue object itself.
        if waiter_queue is None:
            LOG.info(_LI('No calling threads waiting for msg_id : %s'), msg_id)
            LOG.debug(' queues: %(queues)s, message: %(message)s',
                      {'queues': len(self._queues), 'message': message_data})
        else:
            waiter_queue.put(message_data)

    def wake_all(self, except_id):
        """Put ``WAKE_UP`` on every registered queue except ``except_id``."""
        # Snapshot the keys first.  On Python 3 dict.keys() is a live view,
        # and iterating it while another thread registers or removes a
        # queue raises RuntimeError.
        for msg_id in list(self._queues):
            if msg_id != except_id:
                self.put(msg_id, self.WAKE_UP)

    def add(self, msg_id, queue):
        """Register ``queue`` as the reply queue for ``msg_id``.

        Logs a warning, and doubles the threshold, whenever the number of
        registered queues exceeds the current warning threshold.
        """
        self._queues[msg_id] = queue
        if len(self._queues) > self._wrn_threshold:
            # Logger.warn is a deprecated alias that was removed in
            # Python 3.13; Logger.warning is available everywhere.
            LOG.warning('Number of call queues is greater than warning '
                        'threshold: %d. There could be a leak. Increasing'
                        ' threshold to: %d', self._wrn_threshold,
                        self._wrn_threshold * 2)
            self._wrn_threshold *= 2

    def remove(self, msg_id):
        """Unregister the queue for ``msg_id``.

        :raises KeyError: if no queue is registered for ``msg_id``
        """
        del self._queues[msg_id]
class ReplyWaiter(object):
    """Collects replies from one reply queue for many calling threads.

    The instance is registered as the direct consumer of ``reply_q``.
    Callers register interest in a message ID with ``listen``, block in
    ``wait`` and finally call ``unlisten``.
    """

    def __init__(self, conf, reply_q, conn, allowed_remote_exmods):
        self.conf = conf
        self.conn = conn
        self.reply_q = reply_q
        self.allowed_remote_exmods = allowed_remote_exmods

        self.waiters = ReplyWaiters()
        self.msg_id_cache = rpc_amqp._MsgIdCache()
        self.incoming = []
        # Held by the single thread currently polling the connection.
        self.conn_lock = threading.Lock()

        # Register last, once every attribute the callback uses exists.
        conn.declare_direct_consumer(reply_q, self)

    def __call__(self, message):
        """Consumer callback: acknowledge the message and buffer it."""
        message.acknowledge()
        self.incoming.append(message)

    def listen(self, msg_id):
        """Register a fresh reply queue for ``msg_id``."""
        self.waiters.add(msg_id, moves.queue.Queue())

    def unlisten(self, msg_id):
        """Drop the reply queue registered for ``msg_id``."""
        self.waiters.remove(msg_id)

    @staticmethod
    def _raise_timeout_exception(msg_id):
        """Raise ``MessagingTimeout`` naming ``msg_id``."""
        raise oslo_messaging.MessagingTimeout(
            _('Timed out waiting for a reply to message ID %s.') % msg_id)

    def _process_reply(self, data):
        """Decode one reply frame into a ``(result, ending)`` pair.

        A frame carrying a failure yields the deserialized remote
        exception as the result.  An ``ending`` frame yields
        ``(None, True)``.  Any other frame yields its ``'result'`` value.
        """
        self.msg_id_cache.check_duplicate_message(data)

        failure = data['failure']
        if failure:
            remote_exc = rpc_common.deserialize_remote_exception(
                failure, self.allowed_remote_exmods)
            return remote_exc, False
        if data.get('ending', False):
            return None, True
        return data['result'], False

    def _poll_connection(self, msg_id, timer):
        """Consume until a reply for ``msg_id`` arrives, then decode it.

        Replies addressed to other message IDs are handed to their
        waiters.  Raises ``MessagingTimeout`` when the timer expires or
        the consume call times out.
        """
        while True:
            # Drain everything buffered so far before consuming again.
            while self.incoming:
                message_data = self.incoming.pop(0)

                incoming_msg_id = message_data.pop('_msg_id', None)
                if incoming_msg_id == msg_id:
                    return self._process_reply(message_data)

                self.waiters.put(incoming_msg_id, message_data)

            remaining = timer.check_return(self._raise_timeout_exception,
                                           msg_id)
            try:
                self.conn.consume(limit=1, timeout=remaining)
            except rpc_common.Timeout:
                self._raise_timeout_exception(msg_id)

    def _poll_queue(self, msg_id, timer):
        """Block on this caller's queue.

        :returns: ``(reply, ending, trylock)``; ``trylock`` is true when
                  the item received was the wake-up sentinel
        """
        remaining = timer.check_return(self._raise_timeout_exception, msg_id)
        message = self.waiters.get(msg_id, timeout=remaining)
        if message is self.waiters.WAKE_UP:
            # The polling thread released the lock.
            return None, None, True

        reply, ending = self._process_reply(message)
        return reply, ending, False

    def _check_queue(self, msg_id):
        """Fetch an already-queued reply without blocking.

        Wake-up sentinels are skipped.

        :returns: ``(reply, ending, empty)``; ``empty`` is true when the
                  queue has nothing left
        """
        while True:
            message = self.waiters.check(msg_id)
            if message is None:
                # Queue is empty.
                return None, None, True
            if message is not self.waiters.WAKE_UP:
                reply, ending = self._process_reply(message)
                return reply, ending, False

    def wait(self, msg_id, timeout):
        """Wait for the complete reply to ``msg_id``.

        :param msg_id: message ID previously registered with ``listen``
        :param timeout: duration given to the decaying timer
        :returns: the last non-ending reply value received (a deserialized
                  remote exception when the reply carried a failure), or
                  ``None`` if only the ending frame was seen
        :raises oslo_messaging.MessagingTimeout: when the timer expires
        """
        # NOTE(markmc): we are waiting for a reply to msg_id to arrive on
        # the reply_q, but other threads may be waiting for replies to
        # other msg_ids.
        #
        # Only one thread can consume from the queue using this connection
        # and we do not want to hold open a connection per thread.  So the
        # first thread takes responsibility for passing replies not meant
        # for itself to the appropriate thread.
        timer = rpc_common.DecayingTimer(duration=timeout)
        timer.start()
        final_reply = None

        while True:
            if not self.conn_lock.acquire(False):
                # Another thread is polling; wait for it to pass us our
                # reply.
                reply, ending, trylock = self._poll_queue(msg_id, timer)
                if trylock:
                    # The polling thread got its reply; try to take over
                    # the responsibility for polling.
                    continue
                if ending:
                    return final_reply
                final_reply = reply
                continue

            # We are the thread responsible for polling the connection.
            try:
                # First pick up anything a previous lock-holding thread
                # already queued for us.
                while True:
                    reply, ending, empty = self._check_queue(msg_id)
                    if empty:
                        break
                    if ending:
                        return final_reply
                    final_reply = reply

                # Now actually poll the connection.
                while True:
                    reply, ending = self._poll_connection(msg_id, timer)
                    if ending:
                        return final_reply
                    final_reply = reply
            finally:
                self.conn_lock.release()
                # We are done polling (reply received or error raised):
                # wake the other threads so one of them takes over.
                self.waiters.wake_all(msg_id)
class AMQPDriverBase(base.BaseDriver):
    """Base class for AMQP-style messaging drivers.

    Outgoing messages use connections obtained through
    ``rpc_amqp.ConnectionContext`` (pooled by default).  Listeners and the
    lazily created reply queue each get a dedicated non-pooled connection.
    """

    def __init__(self, conf, url, connection_pool,
                 default_exchange=None, allowed_remote_exmods=None):
        """Initialize the driver.

        :param conf: configuration object, forwarded to the base driver
        :param url: transport URL, forwarded to the base driver
        :param connection_pool: pool handed to ``ConnectionContext``
        :param default_exchange: exchange used when a target names none
        :param allowed_remote_exmods: forwarded to the base driver
        """
        super(AMQPDriverBase, self).__init__(conf, url, default_exchange,
                                             allowed_remote_exmods)

        self._default_exchange = default_exchange
        self._connection_pool = connection_pool

        # Reply-queue state, created on first use by _get_reply_q() and
        # guarded by _reply_q_lock.
        self._reply_q_lock = threading.Lock()
        self._reply_q = None
        self._reply_q_conn = None
        self._waiter = None

    def _get_exchange(self, target):
        """Return the target's exchange, or the default one if unset."""
        return target.exchange or self._default_exchange

    def _get_connection(self, pooled=True):
        """Return a ``ConnectionContext`` bound to the driver's pool."""
        return rpc_amqp.ConnectionContext(self._connection_pool,
                                          pooled=pooled)

    def _get_reply_q(self):
        """Return the reply queue name, creating queue and waiter once."""
        with self._reply_q_lock:
            if self._reply_q is not None:
                return self._reply_q

            reply_q = 'reply_' + uuid.uuid4().hex

            conn = self._get_connection(pooled=False)

            self._waiter = ReplyWaiter(self.conf, reply_q, conn,
                                       self._allowed_remote_exmods)

            self._reply_q = reply_q
            self._reply_q_conn = conn

            return self._reply_q

    def _send(self, target, ctxt, message,
              wait_for_reply=None, timeout=None,
              envelope=True, notify=False, retry=None):
        """Send ``message`` to ``target`` and optionally wait for a reply.

        Note that ``message`` is updated in place with the reply and
        context bookkeeping keys.

        :param target: destination; ``exchange``, ``topic``, ``server`` and
                       ``fanout`` are read from it
        :param ctxt: request context dict packed into the message
        :param message: message dict to send
        :param wait_for_reply: when true, block until the reply arrives
        :param timeout: passed to ``topic_send`` and to the reply wait
        :param envelope: when true, wrap with ``rpc_common.serialize_msg``
        :param notify: when true, send with ``notify_send``
        :param retry: passed through to the connection's send call
        :returns: the reply when ``wait_for_reply`` is true, else ``None``
        :raises Exception: the reply itself, when it is an exception
        """

        # FIXME(markmc): remove this temporary hack
        class Context(object):
            def __init__(self, d):
                self.d = d

            def to_dict(self):
                return self.d

        context = Context(ctxt)
        msg = message
        waiter = None

        if wait_for_reply:
            msg_id = uuid.uuid4().hex
            msg.update({'_msg_id': msg_id})
            LOG.debug('MSG_ID is %s', msg_id)
            msg.update({'_reply_q': self._get_reply_q()})
            # Keep a local reference so the unlisten() in the finally
            # clause still works if cleanup() resets self._waiter while
            # this call is in flight.
            waiter = self._waiter

        rpc_amqp._add_unique_id(msg)
        rpc_amqp.pack_context(msg, context)

        if envelope:
            msg = rpc_common.serialize_msg(msg)

        if wait_for_reply:
            waiter.listen(msg_id)

        try:
            with self._get_connection() as conn:
                if notify:
                    conn.notify_send(self._get_exchange(target),
                                     target.topic, msg, retry=retry)
                elif target.fanout:
                    conn.fanout_send(target.topic, msg, retry=retry)
                else:
                    topic = target.topic
                    if target.server:
                        topic = '%s.%s' % (target.topic, target.server)
                    conn.topic_send(exchange_name=self._get_exchange(target),
                                    topic=topic, msg=msg, timeout=timeout,
                                    retry=retry)

            if wait_for_reply:
                result = waiter.wait(msg_id, timeout)
                if isinstance(result, Exception):
                    raise result
                return result
        finally:
            if wait_for_reply:
                waiter.unlisten(msg_id)

    def send(self, target, ctxt, message, wait_for_reply=None, timeout=None,
             retry=None):
        """Send an RPC message; see ``_send`` for the parameters."""
        return self._send(target, ctxt, message, wait_for_reply, timeout,
                          retry=retry)

    def send_notification(self, target, ctxt, message, version, retry=None):
        """Send a notification; enveloped only when ``version == 2.0``."""
        return self._send(target, ctxt, message,
                          envelope=(version == 2.0), notify=True, retry=retry)

    def listen(self, target):
        """Create a listener for ``target`` on a dedicated connection.

        Consumers are declared for the topic, for ``topic.server`` and for
        the topic's fanout.
        """
        conn = self._get_connection(pooled=False)

        listener = AMQPListener(self, conn)

        conn.declare_topic_consumer(exchange_name=self._get_exchange(target),
                                    topic=target.topic,
                                    callback=listener)
        conn.declare_topic_consumer(exchange_name=self._get_exchange(target),
                                    topic='%s.%s' % (target.topic,
                                                     target.server),
                                    callback=listener)
        conn.declare_fanout_consumer(target.topic, listener)

        return listener

    def listen_for_notifications(self, targets_and_priorities, pool):
        """Create a listener consuming ``topic.priority`` for each pair.

        :param targets_and_priorities: iterable of ``(target, priority)``
        :param pool: passed as ``queue_name`` to every consumer
        """
        conn = self._get_connection(pooled=False)

        listener = AMQPListener(self, conn)
        for target, priority in targets_and_priorities:
            conn.declare_topic_consumer(
                exchange_name=self._get_exchange(target),
                topic='%s.%s' % (target.topic, priority),
                callback=listener, queue_name=pool)
        return listener

    def cleanup(self):
        """Release the connection pool and the reply-queue connection."""
        if self._connection_pool:
            self._connection_pool.empty()
        self._connection_pool = None

        # The reply connection is created with pooled=False in
        # _get_reply_q(), so it has to be closed explicitly here, just as
        # AMQPListener.cleanup() closes its own dedicated connection.
        with self._reply_q_lock:
            if self._reply_q_conn is not None:
                self._reply_q_conn.close()
            self._reply_q_conn = None
            self._reply_q = None
            self._waiter = None