Files
python-ganttclient/nova/rpc.py
2010-12-14 16:05:39 -08:00

394 lines
14 KiB
Python

# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
AMQP-based RPC. Queues have consumers and publishers.
No fan-out support yet.
"""
import json
import logging
import sys
import time
import traceback
import uuid
from carrot import connection as carrot_connection
from carrot import messaging
from eventlet import greenthread
from nova import context
from nova import exception
from nova import fakerabbit
from nova import flags
from nova import utils
# Global flag registry; the rabbit_*, fake_rabbit and control_exchange
# options read throughout this module come from here.
FLAGS = flags.FLAGS
# This module logs through the named 'amqplib' logger and forces it to
# DEBUG at import time. logging.getLogger returns one process-wide logger
# per name, so this also affects any other code logging under 'amqplib'.
LOG = logging.getLogger('amqplib')
LOG.setLevel(logging.DEBUG)
class Connection(carrot_connection.BrokerConnection):
    """Broker connection configured from the rabbit_* flags.

    Most callers share one cached instance obtained via ``instance()``.
    """

    @classmethod
    def instance(cls, new=False):
        """Return the shared connection, creating it on first use.

        :param new: when true, build and return a fresh connection that is
                    NOT cached; the shared instance is left untouched.

        When FLAGS.fake_rabbit is set, the connection is built with
        fakerabbit.Backend as its backend class.
        """
        if new or not hasattr(cls, '_instance'):
            params = dict(hostname=FLAGS.rabbit_host,
                          port=FLAGS.rabbit_port,
                          userid=FLAGS.rabbit_userid,
                          password=FLAGS.rabbit_password,
                          virtual_host=FLAGS.rabbit_virtual_host)
            if FLAGS.fake_rabbit:
                params['backend_cls'] = fakerabbit.Backend
            # NOTE(vish): magic is fun!
            # pylint: disable-msg=W0142
            if new:
                return cls(**params)
            else:
                cls._instance = cls(**params)
        return cls._instance

    @classmethod
    def recreate(cls):
        """Recreate the shared connection instance.

        This is necessary to recover from some network errors/disconnects.
        It is safe to call even when no shared instance exists yet.
        """
        try:
            del cls._instance
        except AttributeError:
            # Nothing cached yet; simply build a new shared instance below.
            pass
        return cls.instance()
class Consumer(messaging.Consumer):
    """Consumer base class.

    Contains methods for connecting the fetch method to async loops, and
    retries the initial broker connection according to the rabbit_* flags.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the carrot consumer, retrying on failure.

        Makes up to FLAGS.rabbit_max_retries attempts, sleeping
        FLAGS.rabbit_retry_interval seconds between them, and exits the
        process with status 1 if none succeeds.
        """
        # Start pessimistic. If rabbit_max_retries is 0 (or negative) the
        # loop body never runs, and the check below must still find the
        # attribute instead of raising AttributeError.
        self.failed_connection = True
        for i in xrange(FLAGS.rabbit_max_retries):
            if i > 0:
                time.sleep(FLAGS.rabbit_retry_interval)
            try:
                super(Consumer, self).__init__(*args, **kwargs)
                self.failed_connection = False
                break
            except Exception:  # pylint: disable-msg=W0703
                # Catching (nearly) all because carrot sucks, but not
                # BaseException, so KeyboardInterrupt/SystemExit propagate.
                logging.exception("AMQP server on %s:%d is unreachable."
                                  " Trying again in %d seconds." % (
                                      FLAGS.rabbit_host,
                                      FLAGS.rabbit_port,
                                      FLAGS.rabbit_retry_interval))
                self.failed_connection = True
        if self.failed_connection:
            # Not inside an except block here, so log with error():
            # exception() would try to attach a nonexistent traceback.
            logging.error("Unable to connect to AMQP server"
                          " after %d tries. Shutting down."
                          % FLAGS.rabbit_max_retries)
            sys.exit(1)

    def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False):
        """Wraps the parent fetch with some logic for failed connections.

        After a failure, the next call rebuilds the connection and backend
        and redeclares before fetching. Only the first failure in a run of
        failures is logged.
        """
        # TODO(vish): the logic for failed connections and logging should be
        #             refactored into some sort of connection manager object
        try:
            if self.failed_connection:
                # NOTE(vish): connection is defined in the parent class, we can
                #             recreate it as long as we create the backend too
                # pylint: disable-msg=W0201
                self.connection = Connection.recreate()
                self.backend = self.connection.create_backend()
                self.declare()
            super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks)
            if self.failed_connection:
                logging.error("Reconnected to queue")
                self.failed_connection = False
        # NOTE(vish): This is catching all errors because we really don't
        #             want exceptions to be logged 10 times a second if some
        #             persistent failure occurs.
        except Exception:  # pylint: disable-msg=W0703
            if not self.failed_connection:
                logging.exception("Failed to fetch message from queue")
                self.failed_connection = True

    def attach_to_eventlet(self):
        """Only needed for unit tests!

        Polls fetch (with callbacks enabled) every 0.1 seconds and returns
        the running LoopingCall.
        """
        timer = utils.LoopingCall(self.fetch, enable_callbacks=True)
        timer.start(0.1)
        return timer
class Publisher(messaging.Publisher):
    """Common base for this module's publishers; adds nothing to carrot's."""
class TopicConsumer(Consumer):
    """Consumer bound to the control exchange for a single topic.

    The topic string is used as both the queue name and the routing key,
    and the queue is declared non-durable.
    """
    exchange_type = "topic"

    def __init__(self, connection=None, topic="broadcast"):
        # These attributes are set before delegating to the base
        # initializer, which presumably reads them when declaring.
        self.exchange = FLAGS.control_exchange
        self.queue = self.routing_key = topic
        self.durable = False
        super(TopicConsumer, self).__init__(connection=connection)
class AdapterConsumer(TopicConsumer):
    """Calls methods on a proxy object based on method and args."""

    def __init__(self, connection=None, topic="broadcast", proxy=None):
        LOG.debug('Initing the Adapter Consumer for %s' % (topic))
        self.proxy = proxy
        super(AdapterConsumer, self).__init__(connection=connection,
                                              topic=topic)

    @exception.wrap_exception
    def receive(self, message_data, message):
        """Magically looks for a method on the proxy object and calls it.

        Message data should be a dictionary with two keys:
            method: string representing the method to call
            args: dictionary of arg: value

        Example: {'method': 'echo', 'args': {'value': 42}}

        The optional ``_msg_id`` key and any ``_context_*`` keys are popped
        off ``message_data`` first. The proxy method is called with a
        ``context`` keyword plus ``args``.

        When ``_msg_id`` was present (a call), the return value is sent
        back with msg_reply. So is the exception info if the method lookup,
        the argument conversion or the call itself fails. Without
        ``_msg_id`` (a cast), failures are logged instead.
        """
        LOG.debug('received %s' % (message_data))
        msg_id = message_data.pop('_msg_id', None)
        ctxt = _unpack_context(message_data)
        method = message_data.get('method')
        args = message_data.get('args', {})
        # Acked up front, before the method is even looked up (see NOTE).
        message.ack()
        if not method:
            # NOTE(vish): we may not want to ack here, but that means that bad
            #             messages stay in the queue indefinitely, so for now
            #             we just log the message and send an error string
            #             back to the caller
            LOG.warn('no method for message: %s' % (message_data))
            # A cast has no _msg_id, so there is no reply channel to use.
            if msg_id:
                msg_reply(msg_id, 'No method for message: %s' % message_data)
            return
        try:
            # The lookup and argument conversion are inside the try, so an
            # unknown method or malformed args is reported back to the
            # caller rather than leaving call() waiting for a reply.
            node_func = getattr(self.proxy, str(method))
            node_args = dict((str(k), v) for k, v in args.iteritems())
            # NOTE(vish): magic is fun!
            rval = node_func(context=ctxt, **node_args)
            if msg_id:
                msg_reply(msg_id, rval, None)
        except Exception:  # pylint: disable-msg=W0703
            if msg_id:
                msg_reply(msg_id, None, sys.exc_info())
            else:
                # Previously cast failures vanished without a trace.
                LOG.exception('Exception during cast of method %s' % method)
class TopicPublisher(Publisher):
    """Publisher that sends to the control exchange under a topic key."""
    exchange_type = "topic"

    def __init__(self, connection=None, topic="broadcast"):
        # Configure routing before handing off to the base initializer.
        self.durable = False
        self.exchange = FLAGS.control_exchange
        self.routing_key = topic
        super(TopicPublisher, self).__init__(connection=connection)
class DirectConsumer(Consumer):
    """Consumer for the private channel named by a single ``msg_id``.

    The queue, exchange and routing key all share the ``msg_id`` name, and
    the exclusive and auto_delete flags are both switched on.
    """
    exchange_type = "direct"

    def __init__(self, connection=None, msg_id=None):
        # Set before delegating, matching the order the base class expects.
        self.exclusive = True
        self.auto_delete = True
        self.queue = self.routing_key = self.exchange = msg_id
        super(DirectConsumer, self).__init__(connection=connection)
class DirectPublisher(Publisher):
    """Publisher aimed at the private channel named by ``msg_id``.

    Both the exchange and the routing key are ``msg_id``; auto_delete is on.
    """
    exchange_type = "direct"

    def __init__(self, connection=None, msg_id=None):
        self.auto_delete = True
        self.exchange = self.routing_key = msg_id
        super(DirectPublisher, self).__init__(connection=connection)
def msg_reply(msg_id, reply=None, failure=None):
    """Sends a reply or an error on the channel signified by msg_id.

    :param msg_id: id of the caller's direct reply channel.
    :param reply: result value to return to the caller.
    :param failure: a sys.exc_info() tuple, or None. When given, it is sent
                    as (exception type name, str(value), list of formatted
                    traceback lines).

    If send() raises TypeError for ``reply``, the reply is resent as a dict
    mapping each of ``reply.__dict__``'s attributes to its repr(). The
    publisher is always closed, even if sending fails.
    """
    if failure:
        message = str(failure[1])
        tb = traceback.format_exception(*failure)
        logging.error("Returning exception %s to caller", message)
        # format_exception returns a list of lines; join them so the log
        # shows a readable traceback instead of a list repr.
        logging.error(''.join(tb))
        failure = (failure[0].__name__, message, tb)
    conn = Connection.instance()
    publisher = DirectPublisher(connection=conn, msg_id=msg_id)
    try:
        publisher.send({'result': reply, 'failure': failure})
    except TypeError:
        publisher.send(
            {'result': dict((k, repr(v))
                            for k, v in reply.__dict__.iteritems()),
             'failure': failure})
    finally:
        publisher.close()
class RemoteError(exception.Error):
    """Signifies that a remote class has raised an exception.

    Contains a string representation of the type of the original exception,
    the value of the original exception, and the traceback. All three are
    joined into the exception message, so printing the exception shows all
    of the relevant info.
    """

    def __init__(self, exc_type, value, traceback):
        # NOTE: ``traceback`` shadows the traceback module inside this
        # method; the parameter name is kept for keyword callers.
        self.exc_type, self.value, self.traceback = exc_type, value, traceback
        summary = "%s %s\n%s" % (exc_type, value, traceback)
        super(RemoteError, self).__init__(summary)
def _unpack_context(msg):
    """Pop every ``_context_*`` key out of ``msg`` into a RequestContext.

    ``msg`` is mutated: the context keys are removed from it. The prefix is
    stripped from each key, and the remaining names/values are passed to
    ``context.RequestContext.from_dict``.
    """
    prefix = '_context_'
    unpacked = {}
    for raw_key in list(msg.keys()):
        # NOTE(vish): Some versions of python don't like unicode keys
        #             in kwargs.
        name = str(raw_key)
        if not name.startswith(prefix):
            continue
        unpacked[name[len(prefix):]] = msg.pop(name)
    LOG.debug('unpacked context: %s', unpacked)
    return context.RequestContext.from_dict(unpacked)
def _pack_context(msg, context):
"""Pack context into msg.
Values for message keys need to be less than 255 chars, so we pull
context out into a bunch of separate keys. If we want to support
more arguments in rabbit messages, we may want to do the same
for args at some point.
"""
context = dict([('_context_%s' % key, value)
for (key, value) in context.to_dict().iteritems()])
msg.update(context)
def call(context, topic, msg):
    """Sends a message on a topic and waits for a response.

    A fresh ``_msg_id`` and the packed request context are added to ``msg``
    (mutating it). A DirectConsumer on a dedicated, uncached connection
    listens on that id while the message is published on ``topic``.

    :returns: the ``result`` of the first reply, or None if the wait ends
              without any reply.
    :raises RemoteError: if the reply carries a failure.

    NOTE(review): consumer.wait(limit=1) is called without a timeout, so
    this appears to block until a reply arrives. Confirm against carrot.
    """
    LOG.debug("Making asynchronous call...")
    msg_id = uuid.uuid4().hex
    msg.update({'_msg_id': msg_id})
    LOG.debug("MSG_ID is %s" % (msg_id))
    _pack_context(msg, context)

    class WaitMessage(object):
        """Reply callback: acks the message and records its payload."""
        # Class-level default: if the wait ends without a reply, reading
        # .result yields None instead of raising AttributeError.
        result = None

        def __call__(self, data, message):
            """Acks message and sets result."""
            message.ack()
            if data['failure']:
                self.result = RemoteError(*data['failure'])
            else:
                self.result = data['result']

    wait_msg = WaitMessage()
    # The reply consumer gets its own connection (instance(True) is not
    # cached), so it can be closed here without touching the shared one.
    consumer_conn = Connection.instance(True)
    consumer = DirectConsumer(connection=consumer_conn, msg_id=msg_id)
    try:
        consumer.register_callback(wait_msg)
        publisher = TopicPublisher(connection=Connection.instance(),
                                   topic=topic)
        try:
            publisher.send(msg)
        finally:
            publisher.close()
        try:
            consumer.wait(limit=1)
        except StopIteration:
            pass
    finally:
        consumer.close()
        # Presumably consumer.close() releases only the consumer's backend,
        # not the connection object, so close the per-call connection
        # explicitly rather than leaking one per call(). TODO confirm.
        consumer_conn.close()
    # NOTE(termie): this is a little bit of a change from the original
    #               non-eventlet code where returning a Failure
    #               instance from a deferred call is very similar to
    #               raising an exception
    if isinstance(wait_msg.result, Exception):
        raise wait_msg.result
    return wait_msg.result
def cast(context, topic, msg):
    """Sends a message on a topic without waiting for a response.

    The request context is packed into ``msg`` (mutating it), which is then
    published on the control exchange under ``topic``. No ``_msg_id`` is
    added, so the receiver has no reply channel.
    """
    LOG.debug("Making asynchronous cast...")
    _pack_context(msg, context)
    conn = Connection.instance()
    publisher = TopicPublisher(connection=conn, topic=topic)
    try:
        publisher.send(msg)
    finally:
        # Close the publisher even if send() raises.
        publisher.close()
def generic_response(message_data, message):
    """Reply callback for send_message: log, ack, then end the process.

    Always terminates the interpreter with exit status 0.
    """
    LOG.debug('response %s', message_data)
    message.ack()
    raise SystemExit(0)
def send_message(topic, message, wait=True):
    """Publish ``message`` on ``topic`` from a script, for manual testing.

    A fresh ``_msg_id`` is stamped onto ``message`` (mutating it). When
    ``wait`` is true, a direct consumer on that id is declared *before*
    publishing, and the function then blocks in ``consumer.wait()``. Its
    callback, generic_response, exits the process on the first reply.
    """
    reply_id = uuid.uuid4().hex
    message.update({'_msg_id': reply_id})
    LOG.debug('topic is %s', topic)
    LOG.debug('message %s', message)

    listener = None
    if wait:
        listener = messaging.Consumer(connection=Connection.instance(),
                                      queue=reply_id,
                                      exchange=reply_id,
                                      auto_delete=True,
                                      exchange_type="direct",
                                      routing_key=reply_id)
        listener.register_callback(generic_response)

    sender = messaging.Publisher(connection=Connection.instance(),
                                 exchange=FLAGS.control_exchange,
                                 durable=False,
                                 exchange_type="topic",
                                 routing_key=topic)
    sender.send(message)
    sender.close()

    if listener is not None:
        listener.wait()
if __name__ == "__main__":
    # NOTE(vish): you can send messages from the command line using
    #             topic and a json string representing a dictionary
    #             for the method
    if len(sys.argv) != 3:
        # Print a usage hint instead of failing with an IndexError traceback.
        sys.stderr.write('usage: %s <topic> <json-message>\n' % sys.argv[0])
        sys.exit(1)
    send_message(sys.argv[1], json.loads(sys.argv[2]))