From 0b11668e64450039dc071a4a123abd02206f865f Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 9 Apr 2012 18:08:50 -0400 Subject: [PATCH] Refactor nova.rpc config handling. This patch does a couple of things: 1) Remove the dependency of nova.rpc on nova.flags. This is a step toward decoupling nova.rpc from the rest of nova so that it can be moved to openstack-common. 2) Refactor nova.rpc so that a configuration object is passed around as needed instead of depending on nova.flags.FLAGS. This was done by avoiding changing the nova.rpc API as much as possible so that existing usage of nova.rpc would not have to be touched. So, instead, a config object gets registered, cached, and then passed into the rpc implementations as needed. Getting rid of this global config reference in nova.rpc will require changing the public API and I wanted to avoid doing that until there was a better reason than this. Change-Id: I9a7fa67bd12ced877c83e48e31f5ef7263be6815 --- bin/clear_rabbit_queues | 1 + bin/nova-dhcpbridge | 2 + bin/nova-manage | 3 +- nova/rpc/__init__.py | 55 +++++++++---- nova/rpc/amqp.py | 93 ++++++++++++--------- nova/rpc/common.py | 27 +------ nova/rpc/impl_fake.py | 23 +++--- nova/rpc/impl_kombu.py | 132 ++++++++++++++++-------------- nova/rpc/impl_qpid.py | 135 +++++++++++++++++-------------- nova/service.py | 2 + nova/tests/__init__.py | 3 + nova/tests/rpc/common.py | 34 ++++---- nova/tests/rpc/test_common.py | 8 +- nova/tests/rpc/test_kombu.py | 39 ++++----- nova/tests/rpc/test_kombu_ssl.py | 6 +- nova/tests/rpc/test_qpid.py | 15 ++-- 16 files changed, 319 insertions(+), 259 deletions(-) diff --git a/bin/clear_rabbit_queues b/bin/clear_rabbit_queues index 503fa07d8156..29298f18f3b5 100755 --- a/bin/clear_rabbit_queues +++ b/bin/clear_rabbit_queues @@ -74,6 +74,7 @@ if __name__ == '__main__': utils.default_flagfile() args = flags.FLAGS(sys.argv) logging.setup() + rpc.register_opts(flags.FLAGS) delete_queues(args[1:]) if FLAGS.delete_exchange: delete_exchange(FLAGS.control_exchange) diff --git a/bin/nova-dhcpbridge b/bin/nova-dhcpbridge index e9b822c8a993..a4bb04337513 100755 --- a/bin/nova-dhcpbridge +++ b/bin/nova-dhcpbridge @@ -99,6 +99,8 @@ def main(): argv = FLAGS(sys.argv) logging.setup() + rpc.register_opts(FLAGS) + if int(os.environ.get('TESTING', '0')): from nova.tests import fake_flags diff --git a/bin/nova-manage b/bin/nova-manage index 53c4fd6f30a3..c2d97cb8ad74 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -1671,6 +1671,8 @@ def main(): except Exception: print 'sudo failed, continuing as if nothing happened' + rpc.register_opts(FLAGS) + try: argv = FLAGS(sys.argv) logging.setup() @@ -1679,7 +1681,6 @@ def main(): print _('Please re-run nova-manage as root.') sys.exit(2) raise - script_name = argv.pop(0) if len(argv) < 1: print _("\nOpenStack Nova version: %(version)s (%(vcs)s)\n") % \ diff --git a/nova/rpc/__init__.py b/nova/rpc/__init__.py index 4acc5634cb1c..45d8c00b2d2d 100644 --- a/nova/rpc/__init__.py +++ b/nova/rpc/__init__.py @@ -17,17 +17,37 @@ # License for the specific language governing permissions and limitations # under the License. -from nova import flags from nova.openstack.common import cfg from nova import utils -rpc_backend_opt = cfg.StrOpt('rpc_backend', - default='nova.rpc.impl_kombu', - help="The messaging module to use, defaults to kombu.") +rpc_opts = [ + cfg.StrOpt('rpc_backend', + default='nova.rpc.impl_kombu', + help="The messaging module to use, defaults to kombu."), + cfg.IntOpt('rpc_thread_pool_size', + default=64, + help='Size of RPC thread pool'), + cfg.IntOpt('rpc_conn_pool_size', + default=30, + help='Size of RPC connection pool'), + cfg.IntOpt('rpc_response_timeout', + default=60, + help='Seconds to wait for a response from call or multicall'), + cfg.IntOpt('allowed_rpc_exception_modules', + default=['nova.exception'], + help='Modules of exceptions that are permitted to be recreated' + 'upon receiving exception data from an rpc call.'), + ] -FLAGS = flags.FLAGS -FLAGS.register_opt(rpc_backend_opt) +_CONF = None + + +def register_opts(conf): + global _CONF + _CONF = conf + _CONF.register_opts(rpc_opts) + _get_impl().register_opts(_CONF) def create_connection(new=True): @@ -43,7 +63,7 @@ def create_connection(new=True): :returns: An instance of nova.rpc.common.Connection """ - return _get_impl().create_connection(new=new) + return _get_impl().create_connection(_CONF, new=new) def call(context, topic, msg, timeout=None): @@ -65,7 +85,7 @@ def call(context, topic, msg, timeout=None): :raises: nova.rpc.common.Timeout if a complete response is not received before the timeout is reached. """ - return _get_impl().call(context, topic, msg, timeout) + return _get_impl().call(_CONF, context, topic, msg, timeout) def cast(context, topic, msg): @@ -82,7 +102,7 @@ def cast(context, topic, msg): :returns: None """ - return _get_impl().cast(context, topic, msg) + return _get_impl().cast(_CONF, context, topic, msg) def fanout_cast(context, topic, msg): @@ -102,7 +122,7 @@ def fanout_cast(context, topic, msg): :returns: None """ - return _get_impl().fanout_cast(context, topic, msg) + return _get_impl().fanout_cast(_CONF, context, topic, msg) def multicall(context, topic, msg, timeout=None): @@ -131,7 +151,7 @@ def multicall(context, topic, msg, timeout=None): :raises: nova.rpc.common.Timeout if a complete response is not received before the timeout is reached. """ - return _get_impl().multicall(context, topic, msg, timeout) + return _get_impl().multicall(_CONF, context, topic, msg, timeout) def notify(context, topic, msg): @@ -144,7 +164,7 @@ def notify(context, topic, msg): :returns: None """ - return _get_impl().notify(context, topic, msg) + return _get_impl().notify(_CONF, context, topic, msg) def cleanup(): @@ -172,7 +192,8 @@ def cast_to_server(context, server_params, topic, msg): :returns: None """ - return _get_impl().cast_to_server(context, server_params, topic, msg) + return _get_impl().cast_to_server(_CONF, context, server_params, topic, + msg) def fanout_cast_to_server(context, server_params, topic, msg): @@ -187,16 +208,16 @@ def fanout_cast_to_server(context, server_params, topic, msg): :returns: None """ - return _get_impl().fanout_cast_to_server(context, server_params, topic, - msg) + return _get_impl().fanout_cast_to_server(_CONF, context, server_params, + topic, msg) _RPCIMPL = None def _get_impl(): - """Delay import of rpc_backend until FLAGS are loaded.""" + """Delay import of rpc_backend until configuration is loaded.""" global _RPCIMPL if _RPCIMPL is None: - _RPCIMPL = utils.import_object(FLAGS.rpc_backend) + _RPCIMPL = utils.import_object(_CONF.rpc_backend) return _RPCIMPL diff --git a/nova/rpc/amqp.py b/nova/rpc/amqp.py index ac29a625d948..d58806c9e644 100644 --- a/nova/rpc/amqp.py +++ b/nova/rpc/amqp.py @@ -31,10 +31,10 @@ import uuid from eventlet import greenpool from eventlet import pools +from eventlet import semaphore from nova import context from nova import exception -from nova import flags from nova import log as logging from nova.openstack.common import local import nova.rpc.common as rpc_common @@ -43,27 +43,36 @@ from nova import utils LOG = logging.getLogger(__name__) -FLAGS = flags.FLAGS - - class Pool(pools.Pool): """Class that implements a Pool of Connections.""" - def __init__(self, *args, **kwargs): - self.connection_cls = kwargs.pop("connection_cls", None) - kwargs.setdefault("max_size", FLAGS.rpc_conn_pool_size) + def __init__(self, conf, connection_cls, *args, **kwargs): + self.connection_cls = connection_cls + self.conf = conf + kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size) kwargs.setdefault("order_as_stack", True) super(Pool, self).__init__(*args, **kwargs) # TODO(comstud): Timeout connections not used in a while def create(self): LOG.debug('Pool creating new connection') - return self.connection_cls() + return self.connection_cls(self.conf) def empty(self): while self.free_items: self.get().close() +_pool_create_sem = semaphore.Semaphore() + + +def get_connection_pool(conf, connection_cls): + with _pool_create_sem: + # Make sure only one thread tries to create the connection pool. + if not connection_cls.pool: + connection_cls.pool = Pool(conf, connection_cls) + return connection_cls.pool + + class ConnectionContext(rpc_common.Connection): """The class that is actually returned to the caller of create_connection(). This is a essentially a wrapper around @@ -75,14 +84,15 @@ class ConnectionContext(rpc_common.Connection): the pool. """ - def __init__(self, connection_pool, pooled=True, server_params=None): + def __init__(self, conf, connection_pool, pooled=True, server_params=None): """Create a new connection, or get one from the pool""" self.connection = None + self.conf = conf self.connection_pool = connection_pool if pooled: self.connection = connection_pool.get() else: - self.connection = connection_pool.connection_cls( + self.connection = connection_pool.connection_cls(conf, server_params=server_params) self.pooled = pooled @@ -133,13 +143,14 @@ class ConnectionContext(rpc_common.Connection): raise exception.InvalidRPCConnectionReuse() -def msg_reply(msg_id, connection_pool, reply=None, failure=None, ending=False): +def msg_reply(conf, msg_id, connection_pool, reply=None, failure=None, + ending=False): """Sends a reply or an error on the channel signified by msg_id. Failure should be a sys.exc_info() tuple. """ - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: if failure: failure = rpc_common.serialize_remote_exception(failure) @@ -158,18 +169,19 @@ class RpcContext(context.RequestContext): """Context that supports replying to a rpc.call""" def __init__(self, *args, **kwargs): self.msg_id = kwargs.pop('msg_id', None) + self.conf = kwargs.pop('conf') super(RpcContext, self).__init__(*args, **kwargs) def reply(self, reply=None, failure=None, ending=False, connection_pool=None): if self.msg_id: - msg_reply(self.msg_id, connection_pool, reply, failure, + msg_reply(self.conf, self.msg_id, connection_pool, reply, failure, ending) if ending: self.msg_id = None -def unpack_context(msg): +def unpack_context(conf, msg): """Unpack context from msg.""" context_dict = {} for key in list(msg.keys()): @@ -180,6 +192,7 @@ def unpack_context(msg): value = msg.pop(key) context_dict[key[9:]] = value context_dict['msg_id'] = msg.pop('_msg_id', None) + context_dict['conf'] = conf ctx = RpcContext.from_dict(context_dict) rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict()) return ctx @@ -202,10 +215,11 @@ def pack_context(msg, context): class ProxyCallback(object): """Calls methods on a proxy object based on method and args.""" - def __init__(self, proxy, connection_pool): + def __init__(self, conf, proxy, connection_pool): self.proxy = proxy - self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) + self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size) self.connection_pool = connection_pool + self.conf = conf def __call__(self, message_data): """Consumer callback to call a method on a proxy object. @@ -225,7 +239,7 @@ class ProxyCallback(object): if hasattr(local.store, 'context'): del local.store.context rpc_common._safe_log(LOG.debug, _('received %s'), message_data) - ctxt = unpack_context(message_data) + ctxt = unpack_context(self.conf, message_data) method = message_data.get('method') args = message_data.get('args', {}) if not method: @@ -262,13 +276,14 @@ class ProxyCallback(object): class MulticallWaiter(object): - def __init__(self, connection, timeout): + def __init__(self, conf, connection, timeout): self._connection = connection self._iterator = connection.iterconsume( - timeout=timeout or FLAGS.rpc_response_timeout) + timeout=timeout or conf.rpc_response_timeout) self._result = None self._done = False self._got_ending = False + self._conf = conf def done(self): if self._done: @@ -282,7 +297,8 @@ class MulticallWaiter(object): """The consume() callback will call this. Store the result.""" if data['failure']: failure = data['failure'] - self._result = rpc_common.deserialize_remote_exception(failure) + self._result = rpc_common.deserialize_remote_exception(self._conf, + failure) elif data.get('ending', False): self._got_ending = True @@ -309,12 +325,12 @@ class MulticallWaiter(object): yield result -def create_connection(new, connection_pool): +def create_connection(conf, new, connection_pool): """Create a connection""" - return ConnectionContext(connection_pool, pooled=not new) + return ConnectionContext(conf, connection_pool, pooled=not new) -def multicall(context, topic, msg, timeout, connection_pool): +def multicall(conf, context, topic, msg, timeout, connection_pool): """Make a call that returns multiple times.""" # Can't use 'with' for multicall, as it returns an iterator # that will continue to use the connection. When it's done, @@ -326,16 +342,16 @@ def multicall(context, topic, msg, timeout, connection_pool): LOG.debug(_('MSG_ID is %s') % (msg_id)) pack_context(msg, context) - conn = ConnectionContext(connection_pool) - wait_msg = MulticallWaiter(conn, timeout) + conn = ConnectionContext(conf, connection_pool) + wait_msg = MulticallWaiter(conf, conn, timeout) conn.declare_direct_consumer(msg_id, wait_msg) conn.topic_send(topic, msg) return wait_msg -def call(context, topic, msg, timeout, connection_pool): +def call(conf, context, topic, msg, timeout, connection_pool): """Sends a message on a topic and wait for a response.""" - rv = multicall(context, topic, msg, timeout, connection_pool) + rv = multicall(conf, context, topic, msg, timeout, connection_pool) # NOTE(vish): return the last result from the multicall rv = list(rv) if not rv: @@ -343,47 +359,48 @@ def call(context, topic, msg, timeout, connection_pool): return rv[-1] -def cast(context, topic, msg, connection_pool): +def cast(conf, context, topic, msg, connection_pool): """Sends a message on a topic without waiting for a response.""" LOG.debug(_('Making asynchronous cast on %s...'), topic) pack_context(msg, context) - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: conn.topic_send(topic, msg) -def fanout_cast(context, topic, msg, connection_pool): +def fanout_cast(conf, context, topic, msg, connection_pool): """Sends a message on a fanout exchange without waiting for a response.""" LOG.debug(_('Making asynchronous fanout cast...')) pack_context(msg, context) - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: conn.fanout_send(topic, msg) -def cast_to_server(context, server_params, topic, msg, connection_pool): +def cast_to_server(conf, context, server_params, topic, msg, connection_pool): """Sends a message on a topic to a specific server.""" pack_context(msg, context) - with ConnectionContext(connection_pool, pooled=False, + with ConnectionContext(conf, connection_pool, pooled=False, server_params=server_params) as conn: conn.topic_send(topic, msg) -def fanout_cast_to_server(context, server_params, topic, msg, +def fanout_cast_to_server(conf, context, server_params, topic, msg, connection_pool): """Sends a message on a fanout exchange to a specific server.""" pack_context(msg, context) - with ConnectionContext(connection_pool, pooled=False, + with ConnectionContext(conf, connection_pool, pooled=False, server_params=server_params) as conn: conn.fanout_send(topic, msg) -def notify(context, topic, msg, connection_pool): +def notify(conf, context, topic, msg, connection_pool): """Sends a notification event on a topic.""" event_type = msg.get('event_type') LOG.debug(_('Sending %(event_type)s on %(topic)s'), locals()) pack_context(msg, context) - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: conn.notify_send(topic, msg) def cleanup(connection_pool): - connection_pool.empty() + if connection_pool: + connection_pool.empty() diff --git a/nova/rpc/common.py b/nova/rpc/common.py index a2975e9a57f5..0b9eebf0fa54 100644 --- a/nova/rpc/common.py +++ b/nova/rpc/common.py @@ -22,7 +22,6 @@ import sys import traceback from nova import exception -from nova import flags from nova import log as logging from nova.openstack.common import cfg from nova import utils @@ -30,25 +29,6 @@ from nova import utils LOG = logging.getLogger(__name__) -rpc_opts = [ - cfg.IntOpt('rpc_thread_pool_size', - default=64, - help='Size of RPC thread pool'), - cfg.IntOpt('rpc_conn_pool_size', - default=30, - help='Size of RPC connection pool'), - cfg.IntOpt('rpc_response_timeout', - default=60, - help='Seconds to wait for a response from call or multicall'), - cfg.IntOpt('allowed_rpc_exception_modules', - default=['nova.exception'], - help='Modules of exceptions that are permitted to be recreated' - 'upon receiving exception data from an rpc call.'), - ] - -flags.FLAGS.register_opts(rpc_opts) -FLAGS = flags.FLAGS - class RemoteError(exception.NovaException): """Signifies that a remote class has raised an exception. @@ -95,7 +75,7 @@ class Connection(object): """ raise NotImplementedError() - def create_consumer(self, topic, proxy, fanout=False): + def create_consumer(self, conf, topic, proxy, fanout=False): """Create a consumer on this connection. A consumer is associated with a message queue on the backend message @@ -104,6 +84,7 @@ class Connection(object): off of the queue will determine which method gets called on the proxy object. + :param conf: An openstack.common.cfg configuration object. :param topic: This is a name associated with what to consume from. Multiple instances of a service may consume from the same topic. For example, all instances of nova-compute consume @@ -197,7 +178,7 @@ def serialize_remote_exception(failure_info): return json_data -def deserialize_remote_exception(data): +def deserialize_remote_exception(conf, data): failure = utils.loads(str(data)) trace = failure.get('tb', []) @@ -207,7 +188,7 @@ def deserialize_remote_exception(data): # NOTE(ameade): We DO NOT want to allow just any module to be imported, in # order to prevent arbitrary code execution. - if not module in FLAGS.allowed_rpc_exception_modules: + if not module in conf.allowed_rpc_exception_modules: return RemoteError(name, failure.get('message'), trace) try: diff --git a/nova/rpc/impl_fake.py b/nova/rpc/impl_fake.py index 43aed15c2643..065cca699e7e 100644 --- a/nova/rpc/impl_fake.py +++ b/nova/rpc/impl_fake.py @@ -27,13 +27,10 @@ import traceback import eventlet from nova import context -from nova import flags from nova.rpc import common as rpc_common CONSUMERS = {} -FLAGS = flags.FLAGS - class RpcContext(context.RequestContext): def __init__(self, *args, **kwargs): @@ -116,7 +113,7 @@ class Connection(object): pass -def create_connection(new=True): +def create_connection(conf, new=True): """Create a connection""" return Connection() @@ -126,7 +123,7 @@ def check_serialize(msg): json.dumps(msg) -def multicall(context, topic, msg, timeout=None): +def multicall(conf, context, topic, msg, timeout=None): """Make a call that returns multiple times.""" check_serialize(msg) @@ -144,9 +141,9 @@ def multicall(context, topic, msg, timeout=None): return consumer.call(context, method, args, timeout) -def call(context, topic, msg, timeout=None): +def call(conf, context, topic, msg, timeout=None): """Sends a message on a topic and wait for a response.""" - rv = multicall(context, topic, msg, timeout) + rv = multicall(conf, context, topic, msg, timeout) # NOTE(vish): return the last result from the multicall rv = list(rv) if not rv: @@ -154,14 +151,14 @@ def call(context, topic, msg, timeout=None): return rv[-1] -def cast(context, topic, msg): +def cast(conf, context, topic, msg): try: - call(context, topic, msg) + call(conf, context, topic, msg) except Exception: pass -def notify(context, topic, msg): +def notify(conf, context, topic, msg): check_serialize(msg) @@ -169,7 +166,7 @@ def cleanup(): pass -def fanout_cast(context, topic, msg): +def fanout_cast(conf, context, topic, msg): """Cast to all consumers of a topic""" check_serialize(msg) method = msg.get('method') @@ -182,3 +179,7 @@ def fanout_cast(context, topic, msg): consumer.call(context, method, args, None) except Exception: pass + + +def register_opts(conf): + pass diff --git a/nova/rpc/impl_kombu.py b/nova/rpc/impl_kombu.py index 676aec57240e..6ff87646ca56 100644 --- a/nova/rpc/impl_kombu.py +++ b/nova/rpc/impl_kombu.py @@ -28,7 +28,6 @@ import kombu.entity import kombu.messaging import kombu.connection -from nova import flags from nova.openstack.common import cfg from nova.rpc import amqp as rpc_amqp from nova.rpc import common as rpc_common @@ -49,8 +48,6 @@ kombu_opts = [ '(valid only if SSL enabled)')), ] -FLAGS = flags.FLAGS -FLAGS.register_opts(kombu_opts) LOG = rpc_common.LOG @@ -126,7 +123,7 @@ class ConsumerBase(object): class DirectConsumer(ConsumerBase): """Queue/consumer class for 'direct'""" - def __init__(self, channel, msg_id, callback, tag, **kwargs): + def __init__(self, conf, channel, msg_id, callback, tag, **kwargs): """Init a 'direct' queue. 'channel' is the amqp channel to use @@ -159,7 +156,7 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'""" - def __init__(self, channel, topic, callback, tag, **kwargs): + def __init__(self, conf, channel, topic, callback, tag, **kwargs): """Init a 'topic' queue. 'channel' is the amqp channel to use @@ -170,12 +167,12 @@ class TopicConsumer(ConsumerBase): Other kombu options may be passed """ # Default options - options = {'durable': FLAGS.rabbit_durable_queues, + options = {'durable': conf.rabbit_durable_queues, 'auto_delete': False, 'exclusive': False} options.update(kwargs) exchange = kombu.entity.Exchange( - name=FLAGS.control_exchange, + name=conf.control_exchange, type='topic', durable=options['durable'], auto_delete=options['auto_delete']) @@ -192,7 +189,7 @@ class TopicConsumer(ConsumerBase): class FanoutConsumer(ConsumerBase): """Consumer class for 'fanout'""" - def __init__(self, channel, topic, callback, tag, **kwargs): + def __init__(self, conf, channel, topic, callback, tag, **kwargs): """Init a 'fanout' queue. 'channel' is the amqp channel to use @@ -252,7 +249,7 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'""" - def __init__(self, channel, msg_id, **kwargs): + def __init__(self, conf, channel, msg_id, **kwargs): """init a 'direct' publisher. Kombu options may be passed as keyword args to override defaults @@ -271,17 +268,17 @@ class DirectPublisher(Publisher): class TopicPublisher(Publisher): """Publisher class for 'topic'""" - def __init__(self, channel, topic, **kwargs): + def __init__(self, conf, channel, topic, **kwargs): """init a 'topic' publisher. Kombu options may be passed as keyword args to override defaults """ - options = {'durable': FLAGS.rabbit_durable_queues, + options = {'durable': conf.rabbit_durable_queues, 'auto_delete': False, 'exclusive': False} options.update(kwargs) super(TopicPublisher, self).__init__(channel, - FLAGS.control_exchange, + conf.control_exchange, topic, type='topic', **options) @@ -289,7 +286,7 @@ class TopicPublisher(Publisher): class FanoutPublisher(Publisher): """Publisher class for 'fanout'""" - def __init__(self, channel, topic, **kwargs): + def __init__(self, conf, channel, topic, **kwargs): """init a 'fanout' publisher. Kombu options may be passed as keyword args to override defaults @@ -308,9 +305,9 @@ class FanoutPublisher(Publisher): class NotifyPublisher(TopicPublisher): """Publisher class for 'notify'""" - def __init__(self, *args, **kwargs): - self.durable = kwargs.pop('durable', FLAGS.rabbit_durable_queues) - super(NotifyPublisher, self).__init__(*args, **kwargs) + def __init__(self, conf, channel, topic, **kwargs): + self.durable = kwargs.pop('durable', conf.rabbit_durable_queues) + super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) def reconnect(self, channel): super(NotifyPublisher, self).reconnect(channel) @@ -329,15 +326,18 @@ class NotifyPublisher(TopicPublisher): class Connection(object): """Connection object.""" - def __init__(self, server_params=None): + pool = None + + def __init__(self, conf, server_params=None): self.consumers = [] self.consumer_thread = None - self.max_retries = FLAGS.rabbit_max_retries + self.conf = conf + self.max_retries = self.conf.rabbit_max_retries # Try forever? if self.max_retries <= 0: self.max_retries = None - self.interval_start = FLAGS.rabbit_retry_interval - self.interval_stepping = FLAGS.rabbit_retry_backoff + self.interval_start = self.conf.rabbit_retry_interval + self.interval_stepping = self.conf.rabbit_retry_backoff # max retry-interval = 30 seconds self.interval_max = 30 self.memory_transport = False @@ -353,21 +353,21 @@ class Connection(object): p_key = server_params_to_kombu_params.get(sp_key, sp_key) params[p_key] = value - params.setdefault('hostname', FLAGS.rabbit_host) - params.setdefault('port', FLAGS.rabbit_port) - params.setdefault('userid', FLAGS.rabbit_userid) - params.setdefault('password', FLAGS.rabbit_password) - params.setdefault('virtual_host', FLAGS.rabbit_virtual_host) + params.setdefault('hostname', self.conf.rabbit_host) + params.setdefault('port', self.conf.rabbit_port) + params.setdefault('userid', self.conf.rabbit_userid) + params.setdefault('password', self.conf.rabbit_password) + params.setdefault('virtual_host', self.conf.rabbit_virtual_host) self.params = params - if FLAGS.fake_rabbit: + if self.conf.fake_rabbit: self.params['transport'] = 'memory' self.memory_transport = True else: self.memory_transport = False - if FLAGS.rabbit_use_ssl: + if self.conf.rabbit_use_ssl: self.params['ssl'] = self._fetch_ssl_params() self.connection = None @@ -379,14 +379,14 @@ class Connection(object): ssl_params = dict() # http://docs.python.org/library/ssl.html - ssl.wrap_socket - if FLAGS.kombu_ssl_version: - ssl_params['ssl_version'] = FLAGS.kombu_ssl_version - if FLAGS.kombu_ssl_keyfile: - ssl_params['keyfile'] = FLAGS.kombu_ssl_keyfile - if FLAGS.kombu_ssl_certfile: - ssl_params['certfile'] = FLAGS.kombu_ssl_certfile - if FLAGS.kombu_ssl_ca_certs: - ssl_params['ca_certs'] = FLAGS.kombu_ssl_ca_certs + if self.conf.kombu_ssl_version: + ssl_params['ssl_version'] = self.conf.kombu_ssl_version + if self.conf.kombu_ssl_keyfile: + ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile + if self.conf.kombu_ssl_certfile: + ssl_params['certfile'] = self.conf.kombu_ssl_certfile + if self.conf.kombu_ssl_ca_certs: + ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs # We might want to allow variations in the # future with this? ssl_params['cert_reqs'] = ssl.CERT_REQUIRED @@ -534,7 +534,7 @@ class Connection(object): "%(err_str)s") % log_info) def _declare_consumer(): - consumer = consumer_cls(self.channel, topic, callback, + consumer = consumer_cls(self.conf, self.channel, topic, callback, self.consumer_num.next()) self.consumers.append(consumer) return consumer @@ -590,7 +590,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publish(): - publisher = cls(self.channel, topic, **kwargs) + publisher = cls(self.conf, self.channel, topic, **kwargs) publisher.send(msg) self.ensure(_error_callback, _publish) @@ -648,58 +648,66 @@ class Connection(object): def create_consumer(self, topic, proxy, fanout=False): """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self, Connection)) + if fanout: - self.declare_fanout_consumer(topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + self.declare_fanout_consumer(topic, proxy_cb) else: - self.declare_topic_consumer(topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + self.declare_topic_consumer(topic, proxy_cb) -Connection.pool = rpc_amqp.Pool(connection_cls=Connection) - - -def create_connection(new=True): +def create_connection(conf, new=True): """Create a connection""" - return rpc_amqp.create_connection(new, Connection.pool) + return rpc_amqp.create_connection(conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) -def multicall(context, topic, msg, timeout=None): +def multicall(conf, context, topic, msg, timeout=None): """Make a call that returns multiple times.""" - return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.multicall(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def call(context, topic, msg, timeout=None): +def call(conf, context, topic, msg, timeout=None): """Sends a message on a topic and wait for a response.""" - return rpc_amqp.call(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.call(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast(context, topic, msg): +def cast(conf, context, topic, msg): """Sends a message on a topic without waiting for a response.""" - return rpc_amqp.cast(context, topic, msg, Connection.pool) + return rpc_amqp.cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast(context, topic, msg): +def fanout_cast(conf, context, topic, msg): """Sends a message on a fanout exchange without waiting for a response.""" - return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool) + return rpc_amqp.fanout_cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast_to_server(context, server_params, topic, msg): +def cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a topic to a specific server.""" - return rpc_amqp.cast_to_server(context, server_params, topic, msg, - Connection.pool) + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast_to_server(context, server_params, topic, msg): +def fanout_cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a fanout exchange to a specific server.""" - return rpc_amqp.cast_to_server(context, server_params, topic, msg, - Connection.pool) + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def notify(context, topic, msg): +def notify(conf, context, topic, msg): """Sends a notification event on a topic.""" - return rpc_amqp.notify(context, topic, msg, Connection.pool) + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) def cleanup(): return rpc_amqp.cleanup(Connection.pool) + + +def register_opts(conf): + conf.register_opts(kombu_opts) diff --git a/nova/rpc/impl_qpid.py b/nova/rpc/impl_qpid.py index c275246b0fd6..37bb62430264 100644 --- a/nova/rpc/impl_qpid.py +++ b/nova/rpc/impl_qpid.py @@ -25,7 +25,6 @@ import greenlet import qpid.messaging import qpid.messaging.exceptions -from nova import flags from nova import log as logging from nova.openstack.common import cfg from nova.rpc import amqp as rpc_amqp @@ -78,9 +77,6 @@ qpid_opts = [ help='Disable Nagle algorithm'), ] -FLAGS = flags.FLAGS -FLAGS.register_opts(qpid_opts) - class ConsumerBase(object): """Consumer base class.""" @@ -147,7 +143,7 @@ class ConsumerBase(object): class DirectConsumer(ConsumerBase): """Queue/consumer class for 'direct'""" - def __init__(self, session, msg_id, callback): + def __init__(self, conf, session, msg_id, callback): """Init a 'direct' queue. 'session' is the amqp session to use @@ -165,7 +161,7 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'""" - def __init__(self, session, topic, callback): + def __init__(self, conf, session, topic, callback): """Init a 'topic' queue. 'session' is the amqp session to use @@ -174,14 +170,14 @@ class TopicConsumer(ConsumerBase): """ super(TopicConsumer, self).__init__(session, callback, - "%s/%s" % (FLAGS.control_exchange, topic), {}, + "%s/%s" % (conf.control_exchange, topic), {}, topic, {}) class FanoutConsumer(ConsumerBase): """Consumer class for 'fanout'""" - def __init__(self, session, topic, callback): + def __init__(self, conf, session, topic, callback): """Init a 'fanout' queue. 'session' is the amqp session to use @@ -236,7 +232,7 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'""" - def __init__(self, session, msg_id): + def __init__(self, conf, session, msg_id): """Init a 'direct' publisher.""" super(DirectPublisher, self).__init__(session, msg_id, {"type": "Direct"}) @@ -244,16 +240,16 @@ class DirectPublisher(Publisher): class TopicPublisher(Publisher): """Publisher class for 'topic'""" - def __init__(self, session, topic): + def __init__(self, conf, session, topic): """init a 'topic' publisher. """ super(TopicPublisher, self).__init__(session, - "%s/%s" % (FLAGS.control_exchange, topic)) + "%s/%s" % (conf.control_exchange, topic)) class FanoutPublisher(Publisher): """Publisher class for 'fanout'""" - def __init__(self, session, topic): + def __init__(self, conf, session, topic): """init a 'fanout' publisher. """ super(FanoutPublisher, self).__init__(session, @@ -262,29 +258,32 @@ class FanoutPublisher(Publisher): class NotifyPublisher(Publisher): """Publisher class for notifications""" - def __init__(self, session, topic): + def __init__(self, conf, session, topic): """init a 'topic' publisher. """ super(NotifyPublisher, self).__init__(session, - "%s/%s" % (FLAGS.control_exchange, topic), + "%s/%s" % (conf.control_exchange, topic), {"durable": True}) class Connection(object): """Connection object.""" - def __init__(self, server_params=None): + pool = None + + def __init__(self, conf, server_params=None): self.session = None self.consumers = {} self.consumer_thread = None + self.conf = conf if server_params is None: server_params = {} - default_params = dict(hostname=FLAGS.qpid_hostname, - port=FLAGS.qpid_port, - username=FLAGS.qpid_username, - password=FLAGS.qpid_password) + default_params = dict(hostname=self.conf.qpid_hostname, + port=self.conf.qpid_port, + username=self.conf.qpid_username, + password=self.conf.qpid_password) params = server_params for key in default_params.keys(): @@ -298,23 +297,25 @@ class Connection(object): # before we call open self.connection.username = params['username'] self.connection.password = params['password'] - self.connection.sasl_mechanisms = FLAGS.qpid_sasl_mechanisms - self.connection.reconnect = FLAGS.qpid_reconnect - if FLAGS.qpid_reconnect_timeout: - self.connection.reconnect_timeout = FLAGS.qpid_reconnect_timeout - if FLAGS.qpid_reconnect_limit: - self.connection.reconnect_limit = FLAGS.qpid_reconnect_limit - if FLAGS.qpid_reconnect_interval_max: + self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms + self.connection.reconnect = self.conf.qpid_reconnect + if self.conf.qpid_reconnect_timeout: + self.connection.reconnect_timeout = ( + self.conf.qpid_reconnect_timeout) + if self.conf.qpid_reconnect_limit: + self.connection.reconnect_limit = self.conf.qpid_reconnect_limit + if self.conf.qpid_reconnect_interval_max: self.connection.reconnect_interval_max = ( - FLAGS.qpid_reconnect_interval_max) - if FLAGS.qpid_reconnect_interval_min: + self.conf.qpid_reconnect_interval_max) + if self.conf.qpid_reconnect_interval_min: self.connection.reconnect_interval_min = ( - FLAGS.qpid_reconnect_interval_min) - if FLAGS.qpid_reconnect_interval: - self.connection.reconnect_interval = FLAGS.qpid_reconnect_interval - self.connection.hearbeat = FLAGS.qpid_heartbeat - self.connection.protocol = FLAGS.qpid_protocol - self.connection.tcp_nodelay = FLAGS.qpid_tcp_nodelay + self.conf.qpid_reconnect_interval_min) + if self.conf.qpid_reconnect_interval: + self.connection.reconnect_interval = ( + self.conf.qpid_reconnect_interval) + self.connection.hearbeat = self.conf.qpid_heartbeat + self.connection.protocol = self.conf.qpid_protocol + self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay # Open is part of reconnect - # NOTE(WGH) not sure we need this with the reconnect flags @@ -339,7 +340,7 @@ class Connection(object): self.connection.open() except qpid.messaging.exceptions.ConnectionError, e: LOG.error(_('Unable to connect to AMQP server: %s'), e) - time.sleep(FLAGS.qpid_reconnect_interval or 1) + time.sleep(self.conf.qpid_reconnect_interval or 1) else: break @@ -386,7 +387,7 @@ class Connection(object): "%(err_str)s") % log_info) def _declare_consumer(): - consumer = consumer_cls(self.session, topic, callback) + consumer = consumer_cls(self.conf, self.session, topic, callback) self._register_consumer(consumer) return consumer @@ -435,7 +436,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publisher_send(): - publisher = cls(self.session, topic) + publisher = cls(self.conf, self.session, topic) publisher.send(msg) return self.ensure(_connect_error, _publisher_send) @@ -493,60 +494,70 @@ class Connection(object): def create_consumer(self, topic, proxy, fanout=False): """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self, Connection)) + if fanout: - consumer = FanoutConsumer(self.session, topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb) else: - consumer = TopicConsumer(self.session, topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb) + self._register_consumer(consumer) + return consumer -Connection.pool = rpc_amqp.Pool(connection_cls=Connection) - - -def create_connection(new=True): +def create_connection(conf, new=True): """Create a connection""" - return rpc_amqp.create_connection(new, Connection.pool) + return rpc_amqp.create_connection(conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) -def multicall(context, topic, msg, timeout=None): +def multicall(conf, context, topic, msg, timeout=None): """Make a call that returns multiple times.""" - return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.multicall(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def call(context, topic, msg, timeout=None): +def call(conf, context, topic, msg, timeout=None): """Sends a message on a topic and wait for a response.""" - return rpc_amqp.call(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.call(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast(context, topic, msg): +def cast(conf, context, topic, msg): """Sends a message on a topic without waiting for a response.""" - return rpc_amqp.cast(context, topic, msg, Connection.pool) + return rpc_amqp.cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast(context, topic, msg): +def fanout_cast(conf, context, topic, msg): """Sends a message on a fanout exchange without waiting for a response.""" - return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool) + return rpc_amqp.fanout_cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast_to_server(context, server_params, topic, msg): +def cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a topic to a specific server.""" - return rpc_amqp.cast_to_server(context, server_params, topic, msg, - Connection.pool) + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast_to_server(context, server_params, topic, msg): +def fanout_cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a fanout exchange to a specific server.""" - return rpc_amqp.fanout_cast_to_server(context, server_params, topic, - msg, Connection.pool) + return rpc_amqp.fanout_cast_to_server(conf, context, server_params, topic, + msg, rpc_amqp.get_connection_pool(conf, Connection)) -def notify(context, topic, msg): +def notify(conf, context, topic, msg): """Sends a notification event on a topic.""" - return rpc_amqp.notify(context, topic, msg, Connection.pool) + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) def cleanup(): return rpc_amqp.cleanup(Connection.pool) + + +def register_opts(conf): + conf.register_opts(qpid_opts) diff --git a/nova/service.py b/nova/service.py index a351406fe940..c9817bbe8333 100644 --- a/nova/service.py +++ b/nova/service.py @@ -177,6 +177,7 @@ class Service(object): LOG.audit(_('Starting %(topic)s node (version %(vcs_string)s)'), {'topic': self.topic, 'vcs_string': vcs_string}) utils.cleanup_file_locks() + rpc.register_opts(FLAGS) self.manager.init_host() self.model_disconnected = False ctxt = context.get_admin_context() @@ -393,6 +394,7 @@ class WSGIService(object): """ utils.cleanup_file_locks() + rpc.register_opts(FLAGS) if self.manager: self.manager.init_host() self.server.start() diff --git a/nova/tests/__init__.py b/nova/tests/__init__.py index fee29da6cc80..0e33cd7ace66 100644 --- a/nova/tests/__init__.py +++ b/nova/tests/__init__.py @@ -59,6 +59,9 @@ def reset_db(): def setup(): import mox # Fail fast if you don't have mox. Workaround for bug 810424 + from nova import rpc # Register rpc_backend before fake_flags sets it + FLAGS.register_opts(rpc.rpc_opts) + from nova import context from nova import db from nova.db import migration diff --git a/nova/tests/rpc/common.py b/nova/tests/rpc/common.py index 3524e5682800..d04f0561f868 100644 --- a/nova/tests/rpc/common.py +++ b/nova/tests/rpc/common.py @@ -26,19 +26,21 @@ import nose from nova import context from nova import exception +from nova import flags from nova import log as logging from nova.rpc import amqp as rpc_amqp from nova.rpc import common as rpc_common from nova import test +FLAGS = flags.FLAGS LOG = logging.getLogger(__name__) class BaseRpcTestCase(test.TestCase): def setUp(self, supports_timeouts=True): super(BaseRpcTestCase, self).setUp() - self.conn = self.rpc.create_connection(True) + self.conn = self.rpc.create_connection(FLAGS, True) self.receiver = TestReceiver() self.conn.create_consumer('test', self.receiver, False) self.conn.consume_in_thread() @@ -51,20 +53,20 @@ class BaseRpcTestCase(test.TestCase): def test_call_succeed(self): value = 42 - result = self.rpc.call(self.context, 'test', {"method": "echo", - "args": {"value": value}}) + result = self.rpc.call(FLAGS, self.context, 'test', + {"method": "echo", "args": {"value": value}}) self.assertEqual(value, result) def test_call_succeed_despite_multiple_returns_yield(self): value = 42 - result = self.rpc.call(self.context, 'test', + result = self.rpc.call(FLAGS, self.context, 'test', {"method": "echo_three_times_yield", "args": {"value": value}}) self.assertEqual(value + 2, result) def test_multicall_succeed_once(self): value = 42 - result = self.rpc.multicall(self.context, + result = self.rpc.multicall(FLAGS, self.context, 'test', {"method": "echo", "args": {"value": value}}) @@ -75,7 +77,7 @@ class BaseRpcTestCase(test.TestCase): def test_multicall_three_nones(self): value = 42 - result = self.rpc.multicall(self.context, + result = self.rpc.multicall(FLAGS, self.context, 'test', {"method": "multicall_three_nones", "args": {"value": value}}) @@ -86,7 +88,7 @@ class BaseRpcTestCase(test.TestCase): def test_multicall_succeed_three_times_yield(self): value = 42 - result = self.rpc.multicall(self.context, + result = self.rpc.multicall(FLAGS, self.context, 'test', {"method": "echo_three_times_yield", "args": {"value": value}}) @@ -96,7 +98,7 @@ class BaseRpcTestCase(test.TestCase): def test_context_passed(self): """Makes sure a context is passed through rpc call.""" value = 42 - result = self.rpc.call(self.context, + result = self.rpc.call(FLAGS, self.context, 'test', {"method": "context", "args": {"value": value}}) self.assertEqual(self.context.to_dict(), result) @@ -112,7 +114,7 @@ class BaseRpcTestCase(test.TestCase): # TODO(comstud): # so, it will replay the context and use the same REQID? # that's bizarre. - ret = self.rpc.call(context, + ret = self.rpc.call(FLAGS, context, queue, {"method": "echo", "args": {"value": value}}) @@ -120,11 +122,11 @@ class BaseRpcTestCase(test.TestCase): return value nested = Nested() - conn = self.rpc.create_connection(True) + conn = self.rpc.create_connection(FLAGS, True) conn.create_consumer('nested', nested, False) conn.consume_in_thread() value = 42 - result = self.rpc.call(self.context, + result = self.rpc.call(FLAGS, self.context, 'nested', {"method": "echo", "args": {"queue": "test", "value": value}}) @@ -139,12 +141,12 @@ class BaseRpcTestCase(test.TestCase): value = 42 self.assertRaises(rpc_common.Timeout, self.rpc.call, - self.context, + FLAGS, self.context, 'test', {"method": "block", "args": {"value": value}}, timeout=1) try: - self.rpc.call(self.context, + self.rpc.call(FLAGS, self.context, 'test', {"method": "block", "args": {"value": value}}, @@ -169,8 +171,8 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase): self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context) value = 41 - self.rpc.cast(self.context, 'test', {"method": "echo", - "args": {"value": value}}) + self.rpc.cast(FLAGS, self.context, 'test', + {"method": "echo", "args": {"value": value}}) # Wait for the cast to complete. for x in xrange(50): @@ -185,7 +187,7 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase): self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack) value = 42 - result = self.rpc.call(self.context, 'test', + result = self.rpc.call(FLAGS, self.context, 'test', {"method": "echo", "args": {"value": value}}) self.assertEqual(value, result) diff --git a/nova/tests/rpc/test_common.py b/nova/tests/rpc/test_common.py index 6220bd01a134..4b505db97484 100644 --- a/nova/tests/rpc/test_common.py +++ b/nova/tests/rpc/test_common.py @@ -93,7 +93,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, exception.NovaException)) self.assertTrue('test message' in unicode(after_exc)) #assure the traceback was added @@ -108,7 +108,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) def test_deserialize_remote_exception_user_defined_exception(self): @@ -121,7 +121,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, FakeUserDefinedException)) #assure the traceback was added self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc)) @@ -141,7 +141,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) #assure the traceback was added self.assertTrue('raise FakeIDontExistException' in unicode(after_exc)) diff --git a/nova/tests/rpc/test_kombu.py b/nova/tests/rpc/test_kombu.py index 966cb3a6905b..a66857567f72 100644 --- a/nova/tests/rpc/test_kombu.py +++ b/nova/tests/rpc/test_kombu.py @@ -53,6 +53,7 @@ def _raise_exc_stub(stubs, times, obj, method, exc_msg, class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def setUp(self): self.rpc = impl_kombu + impl_kombu.register_opts(FLAGS) super(RpcKombuTestCase, self).setUp() def tearDown(self): @@ -61,10 +62,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def test_reusing_connection(self): """Test that reusing a connection returns same one.""" - conn_context = self.rpc.create_connection(new=False) + conn_context = self.rpc.create_connection(FLAGS, new=False) conn1 = conn_context.connection conn_context.close() - conn_context = self.rpc.create_connection(new=False) + conn_context = self.rpc.create_connection(FLAGS, new=False) conn2 = conn_context.connection conn_context.close() self.assertEqual(conn1, conn2) @@ -72,7 +73,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def test_topic_send_receive(self): """Test sending to a topic exchange/queue""" - conn = self.rpc.create_connection() + conn = self.rpc.create_connection(FLAGS) message = 'topic test message' self.received_message = None @@ -89,7 +90,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def test_direct_send_receive(self): """Test sending to a direct exchange/queue""" - conn = self.rpc.create_connection() + conn = self.rpc.create_connection(FLAGS) message = 'direct test message' self.received_message = None @@ -123,10 +124,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def topic_send(_context, topic, msg): pass - MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) self.stubs.Set(impl_kombu, 'Connection', MyConnection) - impl_kombu.cast(ctxt, 'fake_topic', {'msg': 'fake'}) + impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'}) def test_cast_to_server_uses_server_params(self): """Test kombu rpc.cast""" @@ -153,10 +154,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def topic_send(_context, topic, msg): pass - MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) self.stubs.Set(impl_kombu, 'Connection', MyConnection) - impl_kombu.cast_to_server(ctxt, server_params, + impl_kombu.cast_to_server(FLAGS, ctxt, server_params, 'fake_topic', {'msg': 'fake'}) @test.skip_test("kombu memory transport seems buggy with fanout queues " @@ -192,7 +193,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, '__init__', 'foo timeout foo') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) result = conn.declare_consumer(self.rpc.DirectConsumer, 'test_topic', None) @@ -206,7 +207,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer, '__init__', 'meow') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.connection_errors = (MyException, ) result = conn.declare_consumer(self.rpc.DirectConsumer, @@ -220,7 +221,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, '__init__', 'Socket closed', exc_class=IOError) - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) result = conn.declare_consumer(self.rpc.DirectConsumer, 'test_topic', None) @@ -234,7 +235,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, '__init__', 'foo timeout foo') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') self.assertEqual(info['called'], 3) @@ -243,7 +244,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, 'send', 'foo timeout foo') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') self.assertEqual(info['called'], 3) @@ -256,7 +257,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, '__init__', 'meow') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.connection_errors = (MyException, ) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') @@ -267,7 +268,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, 'send', 'meow') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.connection_errors = (MyException, ) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') @@ -275,7 +276,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): self.assertEqual(info['called'], 2) def test_iterconsume_errors_will_reconnect(self): - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) message = 'reconnect test message' self.received_message = None @@ -305,12 +306,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): value = "This is the exception message" self.assertRaises(NotImplementedError, self.rpc.call, + FLAGS, self.context, 'test', {"method": "fail", "args": {"value": value}}) try: - self.rpc.call(self.context, + self.rpc.call(FLAGS, self.context, 'test', {"method": "fail", "args": {"value": value}}) @@ -330,12 +332,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): value = "This is the exception message" self.assertRaises(exception.ConvertedException, self.rpc.call, + FLAGS, self.context, 'test', {"method": "fail_converted", "args": {"value": value}}) try: - self.rpc.call(self.context, + self.rpc.call(FLAGS, self.context, 'test', {"method": "fail_converted", "args": {"value": value}}) diff --git a/nova/tests/rpc/test_kombu_ssl.py b/nova/tests/rpc/test_kombu_ssl.py index 2a10835cc8e3..fb5a32eb88b8 100644 --- a/nova/tests/rpc/test_kombu_ssl.py +++ b/nova/tests/rpc/test_kombu_ssl.py @@ -19,6 +19,7 @@ Unit Tests for remote procedure calls using kombu + ssl """ +from nova import flags from nova import test from nova.rpc import impl_kombu @@ -28,11 +29,14 @@ SSL_CERT = "/tmp/cert.blah.blah" SSL_CA_CERT = "/tmp/cert.ca.blah.blah" SSL_KEYFILE = "/tmp/keyfile.blah.blah" +FLAGS = flags.FLAGS + class RpcKombuSslTestCase(test.TestCase): def setUp(self): super(RpcKombuSslTestCase, self).setUp() + impl_kombu.register_opts(FLAGS) self.flags(kombu_ssl_keyfile=SSL_KEYFILE, kombu_ssl_ca_certs=SSL_CA_CERT, kombu_ssl_certfile=SSL_CERT, @@ -41,7 +45,7 @@ class RpcKombuSslTestCase(test.TestCase): def test_ssl_on_extended(self): rpc = impl_kombu - conn = rpc.create_connection(True) + conn = rpc.create_connection(FLAGS, True) c = conn.connection #This might be kombu version dependent... #Since we are now peaking into the internals of kombu... diff --git a/nova/tests/rpc/test_qpid.py b/nova/tests/rpc/test_qpid.py index 616abb1c90e0..7959f3783ba0 100644 --- a/nova/tests/rpc/test_qpid.py +++ b/nova/tests/rpc/test_qpid.py @@ -23,6 +23,7 @@ Unit Tests for remote procedure calls using qpid import mox from nova import context +from nova import flags from nova import log as logging from nova.rpc import amqp as rpc_amqp from nova import test @@ -35,6 +36,7 @@ except ImportError: impl_qpid = None +FLAGS = flags.FLAGS LOG = logging.getLogger(__name__) @@ -64,6 +66,7 @@ class RpcQpidTestCase(test.TestCase): self.mock_receiver = None if qpid: + impl_qpid.register_opts(FLAGS) self.orig_connection = qpid.messaging.Connection self.orig_session = qpid.messaging.Session self.orig_sender = qpid.messaging.Sender @@ -98,7 +101,7 @@ class RpcQpidTestCase(test.TestCase): self.mox.ReplayAll() - connection = impl_qpid.create_connection() + connection = impl_qpid.create_connection(FLAGS) connection.close() def _test_create_consumer(self, fanout): @@ -130,7 +133,7 @@ class RpcQpidTestCase(test.TestCase): self.mox.ReplayAll() - connection = impl_qpid.create_connection() + connection = impl_qpid.create_connection(FLAGS) connection.create_consumer("impl_qpid_test", lambda *_x, **_y: None, fanout) @@ -176,11 +179,11 @@ class RpcQpidTestCase(test.TestCase): try: ctx = context.RequestContext("user", "project") - args = [ctx, "impl_qpid_test", + args = [FLAGS, ctx, "impl_qpid_test", {"method": "test_method", "args": {}}] if server_params: - args.insert(1, server_params) + args.insert(2, server_params) if fanout: method = impl_qpid.fanout_cast_to_server else: @@ -218,7 +221,7 @@ class RpcQpidTestCase(test.TestCase): server_params['hostname'] + ':' + str(server_params['port'])) - MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) self.stubs.Set(impl_qpid, 'Connection', MyConnection) @test.skip_if(qpid is None, "Test requires qpid") @@ -295,7 +298,7 @@ class RpcQpidTestCase(test.TestCase): else: method = impl_qpid.call - res = method(ctx, "impl_qpid_test", + res = method(FLAGS, ctx, "impl_qpid_test", {"method": "test_method", "args": {}}) if multi: