Refactor nova.rpc config handling.

This patch does a couple of things:

1) Remove the dependency of nova.rpc on nova.flags.  This is a step
toward decoupling nova.rpc from the rest of nova so that it can be moved
to openstack-common.

2) Refactor nova.rpc so that a configuration object is passed around as
needed instead of depending on nova.flags.FLAGS.

This was done by avoiding changing the nova.rpc API as much as possible
so that existing usage of nova.rpc would not have to be touched.  So,
instead, a config object gets registered, cached, and then passed into
the rpc implementations as needed.  Getting rid of this global config
reference in nova.rpc will require changing the public API and I wanted
to avoid doing that until there was a better reason than this.

Change-Id: I9a7fa67bd12ced877c83e48e31f5ef7263be6815
This commit is contained in:
Russell Bryant 2012-04-09 18:08:50 -04:00
parent ca4aee67e3
commit 0b11668e64
16 changed files with 319 additions and 259 deletions

View File

@ -74,6 +74,7 @@ if __name__ == '__main__':
utils.default_flagfile() utils.default_flagfile()
args = flags.FLAGS(sys.argv) args = flags.FLAGS(sys.argv)
logging.setup() logging.setup()
rpc.register_opts(flags.FLAGS)
delete_queues(args[1:]) delete_queues(args[1:])
if FLAGS.delete_exchange: if FLAGS.delete_exchange:
delete_exchange(FLAGS.control_exchange) delete_exchange(FLAGS.control_exchange)

View File

@ -99,6 +99,8 @@ def main():
argv = FLAGS(sys.argv) argv = FLAGS(sys.argv)
logging.setup() logging.setup()
rpc.register_opts(FLAGS)
if int(os.environ.get('TESTING', '0')): if int(os.environ.get('TESTING', '0')):
from nova.tests import fake_flags from nova.tests import fake_flags

View File

@ -1671,6 +1671,8 @@ def main():
except Exception: except Exception:
print 'sudo failed, continuing as if nothing happened' print 'sudo failed, continuing as if nothing happened'
rpc.register_opts(FLAGS)
try: try:
argv = FLAGS(sys.argv) argv = FLAGS(sys.argv)
logging.setup() logging.setup()
@ -1679,7 +1681,6 @@ def main():
print _('Please re-run nova-manage as root.') print _('Please re-run nova-manage as root.')
sys.exit(2) sys.exit(2)
raise raise
script_name = argv.pop(0) script_name = argv.pop(0)
if len(argv) < 1: if len(argv) < 1:
print _("\nOpenStack Nova version: %(version)s (%(vcs)s)\n") % \ print _("\nOpenStack Nova version: %(version)s (%(vcs)s)\n") % \

View File

@ -17,17 +17,37 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from nova import flags
from nova.openstack.common import cfg from nova.openstack.common import cfg
from nova import utils from nova import utils
rpc_backend_opt = cfg.StrOpt('rpc_backend', rpc_opts = [
cfg.StrOpt('rpc_backend',
default='nova.rpc.impl_kombu', default='nova.rpc.impl_kombu',
help="The messaging module to use, defaults to kombu.") help="The messaging module to use, defaults to kombu."),
cfg.IntOpt('rpc_thread_pool_size',
default=64,
help='Size of RPC thread pool'),
cfg.IntOpt('rpc_conn_pool_size',
default=30,
help='Size of RPC connection pool'),
cfg.IntOpt('rpc_response_timeout',
default=60,
help='Seconds to wait for a response from call or multicall'),
cfg.IntOpt('allowed_rpc_exception_modules',
default=['nova.exception'],
help='Modules of exceptions that are permitted to be recreated'
'upon receiving exception data from an rpc call.'),
]
FLAGS = flags.FLAGS _CONF = None
FLAGS.register_opt(rpc_backend_opt)
def register_opts(conf):
global _CONF
_CONF = conf
_CONF.register_opts(rpc_opts)
_get_impl().register_opts(_CONF)
def create_connection(new=True): def create_connection(new=True):
@ -43,7 +63,7 @@ def create_connection(new=True):
:returns: An instance of nova.rpc.common.Connection :returns: An instance of nova.rpc.common.Connection
""" """
return _get_impl().create_connection(new=new) return _get_impl().create_connection(_CONF, new=new)
def call(context, topic, msg, timeout=None): def call(context, topic, msg, timeout=None):
@ -65,7 +85,7 @@ def call(context, topic, msg, timeout=None):
:raises: nova.rpc.common.Timeout if a complete response is not received :raises: nova.rpc.common.Timeout if a complete response is not received
before the timeout is reached. before the timeout is reached.
""" """
return _get_impl().call(context, topic, msg, timeout) return _get_impl().call(_CONF, context, topic, msg, timeout)
def cast(context, topic, msg): def cast(context, topic, msg):
@ -82,7 +102,7 @@ def cast(context, topic, msg):
:returns: None :returns: None
""" """
return _get_impl().cast(context, topic, msg) return _get_impl().cast(_CONF, context, topic, msg)
def fanout_cast(context, topic, msg): def fanout_cast(context, topic, msg):
@ -102,7 +122,7 @@ def fanout_cast(context, topic, msg):
:returns: None :returns: None
""" """
return _get_impl().fanout_cast(context, topic, msg) return _get_impl().fanout_cast(_CONF, context, topic, msg)
def multicall(context, topic, msg, timeout=None): def multicall(context, topic, msg, timeout=None):
@ -131,7 +151,7 @@ def multicall(context, topic, msg, timeout=None):
:raises: nova.rpc.common.Timeout if a complete response is not received :raises: nova.rpc.common.Timeout if a complete response is not received
before the timeout is reached. before the timeout is reached.
""" """
return _get_impl().multicall(context, topic, msg, timeout) return _get_impl().multicall(_CONF, context, topic, msg, timeout)
def notify(context, topic, msg): def notify(context, topic, msg):
@ -144,7 +164,7 @@ def notify(context, topic, msg):
:returns: None :returns: None
""" """
return _get_impl().notify(context, topic, msg) return _get_impl().notify(_CONF, context, topic, msg)
def cleanup(): def cleanup():
@ -172,7 +192,8 @@ def cast_to_server(context, server_params, topic, msg):
:returns: None :returns: None
""" """
return _get_impl().cast_to_server(context, server_params, topic, msg) return _get_impl().cast_to_server(_CONF, context, server_params, topic,
msg)
def fanout_cast_to_server(context, server_params, topic, msg): def fanout_cast_to_server(context, server_params, topic, msg):
@ -187,16 +208,16 @@ def fanout_cast_to_server(context, server_params, topic, msg):
:returns: None :returns: None
""" """
return _get_impl().fanout_cast_to_server(context, server_params, topic, return _get_impl().fanout_cast_to_server(_CONF, context, server_params,
msg) topic, msg)
_RPCIMPL = None _RPCIMPL = None
def _get_impl(): def _get_impl():
"""Delay import of rpc_backend until FLAGS are loaded.""" """Delay import of rpc_backend until configuration is loaded."""
global _RPCIMPL global _RPCIMPL
if _RPCIMPL is None: if _RPCIMPL is None:
_RPCIMPL = utils.import_object(FLAGS.rpc_backend) _RPCIMPL = utils.import_object(_CONF.rpc_backend)
return _RPCIMPL return _RPCIMPL

View File

@ -31,10 +31,10 @@ import uuid
from eventlet import greenpool from eventlet import greenpool
from eventlet import pools from eventlet import pools
from eventlet import semaphore
from nova import context from nova import context
from nova import exception from nova import exception
from nova import flags
from nova import log as logging from nova import log as logging
from nova.openstack.common import local from nova.openstack.common import local
import nova.rpc.common as rpc_common import nova.rpc.common as rpc_common
@ -43,27 +43,36 @@ from nova import utils
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
FLAGS = flags.FLAGS
class Pool(pools.Pool): class Pool(pools.Pool):
"""Class that implements a Pool of Connections.""" """Class that implements a Pool of Connections."""
def __init__(self, *args, **kwargs): def __init__(self, conf, connection_cls, *args, **kwargs):
self.connection_cls = kwargs.pop("connection_cls", None) self.connection_cls = connection_cls
kwargs.setdefault("max_size", FLAGS.rpc_conn_pool_size) self.conf = conf
kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size)
kwargs.setdefault("order_as_stack", True) kwargs.setdefault("order_as_stack", True)
super(Pool, self).__init__(*args, **kwargs) super(Pool, self).__init__(*args, **kwargs)
# TODO(comstud): Timeout connections not used in a while # TODO(comstud): Timeout connections not used in a while
def create(self): def create(self):
LOG.debug('Pool creating new connection') LOG.debug('Pool creating new connection')
return self.connection_cls() return self.connection_cls(self.conf)
def empty(self): def empty(self):
while self.free_items: while self.free_items:
self.get().close() self.get().close()
_pool_create_sem = semaphore.Semaphore()
def get_connection_pool(conf, connection_cls):
with _pool_create_sem:
# Make sure only one thread tries to create the connection pool.
if not connection_cls.pool:
connection_cls.pool = Pool(conf, connection_cls)
return connection_cls.pool
class ConnectionContext(rpc_common.Connection): class ConnectionContext(rpc_common.Connection):
"""The class that is actually returned to the caller of """The class that is actually returned to the caller of
create_connection(). This is a essentially a wrapper around create_connection(). This is a essentially a wrapper around
@ -75,14 +84,15 @@ class ConnectionContext(rpc_common.Connection):
the pool. the pool.
""" """
def __init__(self, connection_pool, pooled=True, server_params=None): def __init__(self, conf, connection_pool, pooled=True, server_params=None):
"""Create a new connection, or get one from the pool""" """Create a new connection, or get one from the pool"""
self.connection = None self.connection = None
self.conf = conf
self.connection_pool = connection_pool self.connection_pool = connection_pool
if pooled: if pooled:
self.connection = connection_pool.get() self.connection = connection_pool.get()
else: else:
self.connection = connection_pool.connection_cls( self.connection = connection_pool.connection_cls(conf,
server_params=server_params) server_params=server_params)
self.pooled = pooled self.pooled = pooled
@ -133,13 +143,14 @@ class ConnectionContext(rpc_common.Connection):
raise exception.InvalidRPCConnectionReuse() raise exception.InvalidRPCConnectionReuse()
def msg_reply(msg_id, connection_pool, reply=None, failure=None, ending=False): def msg_reply(conf, msg_id, connection_pool, reply=None, failure=None,
ending=False):
"""Sends a reply or an error on the channel signified by msg_id. """Sends a reply or an error on the channel signified by msg_id.
Failure should be a sys.exc_info() tuple. Failure should be a sys.exc_info() tuple.
""" """
with ConnectionContext(connection_pool) as conn: with ConnectionContext(conf, connection_pool) as conn:
if failure: if failure:
failure = rpc_common.serialize_remote_exception(failure) failure = rpc_common.serialize_remote_exception(failure)
@ -158,18 +169,19 @@ class RpcContext(context.RequestContext):
"""Context that supports replying to a rpc.call""" """Context that supports replying to a rpc.call"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.msg_id = kwargs.pop('msg_id', None) self.msg_id = kwargs.pop('msg_id', None)
self.conf = kwargs.pop('conf')
super(RpcContext, self).__init__(*args, **kwargs) super(RpcContext, self).__init__(*args, **kwargs)
def reply(self, reply=None, failure=None, ending=False, def reply(self, reply=None, failure=None, ending=False,
connection_pool=None): connection_pool=None):
if self.msg_id: if self.msg_id:
msg_reply(self.msg_id, connection_pool, reply, failure, msg_reply(self.conf, self.msg_id, connection_pool, reply, failure,
ending) ending)
if ending: if ending:
self.msg_id = None self.msg_id = None
def unpack_context(msg): def unpack_context(conf, msg):
"""Unpack context from msg.""" """Unpack context from msg."""
context_dict = {} context_dict = {}
for key in list(msg.keys()): for key in list(msg.keys()):
@ -180,6 +192,7 @@ def unpack_context(msg):
value = msg.pop(key) value = msg.pop(key)
context_dict[key[9:]] = value context_dict[key[9:]] = value
context_dict['msg_id'] = msg.pop('_msg_id', None) context_dict['msg_id'] = msg.pop('_msg_id', None)
context_dict['conf'] = conf
ctx = RpcContext.from_dict(context_dict) ctx = RpcContext.from_dict(context_dict)
rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict()) rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict())
return ctx return ctx
@ -202,10 +215,11 @@ def pack_context(msg, context):
class ProxyCallback(object): class ProxyCallback(object):
"""Calls methods on a proxy object based on method and args.""" """Calls methods on a proxy object based on method and args."""
def __init__(self, proxy, connection_pool): def __init__(self, conf, proxy, connection_pool):
self.proxy = proxy self.proxy = proxy
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size)
self.connection_pool = connection_pool self.connection_pool = connection_pool
self.conf = conf
def __call__(self, message_data): def __call__(self, message_data):
"""Consumer callback to call a method on a proxy object. """Consumer callback to call a method on a proxy object.
@ -225,7 +239,7 @@ class ProxyCallback(object):
if hasattr(local.store, 'context'): if hasattr(local.store, 'context'):
del local.store.context del local.store.context
rpc_common._safe_log(LOG.debug, _('received %s'), message_data) rpc_common._safe_log(LOG.debug, _('received %s'), message_data)
ctxt = unpack_context(message_data) ctxt = unpack_context(self.conf, message_data)
method = message_data.get('method') method = message_data.get('method')
args = message_data.get('args', {}) args = message_data.get('args', {})
if not method: if not method:
@ -262,13 +276,14 @@ class ProxyCallback(object):
class MulticallWaiter(object): class MulticallWaiter(object):
def __init__(self, connection, timeout): def __init__(self, conf, connection, timeout):
self._connection = connection self._connection = connection
self._iterator = connection.iterconsume( self._iterator = connection.iterconsume(
timeout=timeout or FLAGS.rpc_response_timeout) timeout=timeout or conf.rpc_response_timeout)
self._result = None self._result = None
self._done = False self._done = False
self._got_ending = False self._got_ending = False
self._conf = conf
def done(self): def done(self):
if self._done: if self._done:
@ -282,7 +297,8 @@ class MulticallWaiter(object):
"""The consume() callback will call this. Store the result.""" """The consume() callback will call this. Store the result."""
if data['failure']: if data['failure']:
failure = data['failure'] failure = data['failure']
self._result = rpc_common.deserialize_remote_exception(failure) self._result = rpc_common.deserialize_remote_exception(self._conf,
failure)
elif data.get('ending', False): elif data.get('ending', False):
self._got_ending = True self._got_ending = True
@ -309,12 +325,12 @@ class MulticallWaiter(object):
yield result yield result
def create_connection(new, connection_pool): def create_connection(conf, new, connection_pool):
"""Create a connection""" """Create a connection"""
return ConnectionContext(connection_pool, pooled=not new) return ConnectionContext(conf, connection_pool, pooled=not new)
def multicall(context, topic, msg, timeout, connection_pool): def multicall(conf, context, topic, msg, timeout, connection_pool):
"""Make a call that returns multiple times.""" """Make a call that returns multiple times."""
# Can't use 'with' for multicall, as it returns an iterator # Can't use 'with' for multicall, as it returns an iterator
# that will continue to use the connection. When it's done, # that will continue to use the connection. When it's done,
@ -326,16 +342,16 @@ def multicall(context, topic, msg, timeout, connection_pool):
LOG.debug(_('MSG_ID is %s') % (msg_id)) LOG.debug(_('MSG_ID is %s') % (msg_id))
pack_context(msg, context) pack_context(msg, context)
conn = ConnectionContext(connection_pool) conn = ConnectionContext(conf, connection_pool)
wait_msg = MulticallWaiter(conn, timeout) wait_msg = MulticallWaiter(conf, conn, timeout)
conn.declare_direct_consumer(msg_id, wait_msg) conn.declare_direct_consumer(msg_id, wait_msg)
conn.topic_send(topic, msg) conn.topic_send(topic, msg)
return wait_msg return wait_msg
def call(context, topic, msg, timeout, connection_pool): def call(conf, context, topic, msg, timeout, connection_pool):
"""Sends a message on a topic and wait for a response.""" """Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg, timeout, connection_pool) rv = multicall(conf, context, topic, msg, timeout, connection_pool)
# NOTE(vish): return the last result from the multicall # NOTE(vish): return the last result from the multicall
rv = list(rv) rv = list(rv)
if not rv: if not rv:
@ -343,47 +359,48 @@ def call(context, topic, msg, timeout, connection_pool):
return rv[-1] return rv[-1]
def cast(context, topic, msg, connection_pool): def cast(conf, context, topic, msg, connection_pool):
"""Sends a message on a topic without waiting for a response.""" """Sends a message on a topic without waiting for a response."""
LOG.debug(_('Making asynchronous cast on %s...'), topic) LOG.debug(_('Making asynchronous cast on %s...'), topic)
pack_context(msg, context) pack_context(msg, context)
with ConnectionContext(connection_pool) as conn: with ConnectionContext(conf, connection_pool) as conn:
conn.topic_send(topic, msg) conn.topic_send(topic, msg)
def fanout_cast(context, topic, msg, connection_pool): def fanout_cast(conf, context, topic, msg, connection_pool):
"""Sends a message on a fanout exchange without waiting for a response.""" """Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_('Making asynchronous fanout cast...')) LOG.debug(_('Making asynchronous fanout cast...'))
pack_context(msg, context) pack_context(msg, context)
with ConnectionContext(connection_pool) as conn: with ConnectionContext(conf, connection_pool) as conn:
conn.fanout_send(topic, msg) conn.fanout_send(topic, msg)
def cast_to_server(context, server_params, topic, msg, connection_pool): def cast_to_server(conf, context, server_params, topic, msg, connection_pool):
"""Sends a message on a topic to a specific server.""" """Sends a message on a topic to a specific server."""
pack_context(msg, context) pack_context(msg, context)
with ConnectionContext(connection_pool, pooled=False, with ConnectionContext(conf, connection_pool, pooled=False,
server_params=server_params) as conn: server_params=server_params) as conn:
conn.topic_send(topic, msg) conn.topic_send(topic, msg)
def fanout_cast_to_server(context, server_params, topic, msg, def fanout_cast_to_server(conf, context, server_params, topic, msg,
connection_pool): connection_pool):
"""Sends a message on a fanout exchange to a specific server.""" """Sends a message on a fanout exchange to a specific server."""
pack_context(msg, context) pack_context(msg, context)
with ConnectionContext(connection_pool, pooled=False, with ConnectionContext(conf, connection_pool, pooled=False,
server_params=server_params) as conn: server_params=server_params) as conn:
conn.fanout_send(topic, msg) conn.fanout_send(topic, msg)
def notify(context, topic, msg, connection_pool): def notify(conf, context, topic, msg, connection_pool):
"""Sends a notification event on a topic.""" """Sends a notification event on a topic."""
event_type = msg.get('event_type') event_type = msg.get('event_type')
LOG.debug(_('Sending %(event_type)s on %(topic)s'), locals()) LOG.debug(_('Sending %(event_type)s on %(topic)s'), locals())
pack_context(msg, context) pack_context(msg, context)
with ConnectionContext(connection_pool) as conn: with ConnectionContext(conf, connection_pool) as conn:
conn.notify_send(topic, msg) conn.notify_send(topic, msg)
def cleanup(connection_pool): def cleanup(connection_pool):
if connection_pool:
connection_pool.empty() connection_pool.empty()

View File

@ -22,7 +22,6 @@ import sys
import traceback import traceback
from nova import exception from nova import exception
from nova import flags
from nova import log as logging from nova import log as logging
from nova.openstack.common import cfg from nova.openstack.common import cfg
from nova import utils from nova import utils
@ -30,25 +29,6 @@ from nova import utils
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
rpc_opts = [
cfg.IntOpt('rpc_thread_pool_size',
default=64,
help='Size of RPC thread pool'),
cfg.IntOpt('rpc_conn_pool_size',
default=30,
help='Size of RPC connection pool'),
cfg.IntOpt('rpc_response_timeout',
default=60,
help='Seconds to wait for a response from call or multicall'),
cfg.IntOpt('allowed_rpc_exception_modules',
default=['nova.exception'],
help='Modules of exceptions that are permitted to be recreated'
'upon receiving exception data from an rpc call.'),
]
flags.FLAGS.register_opts(rpc_opts)
FLAGS = flags.FLAGS
class RemoteError(exception.NovaException): class RemoteError(exception.NovaException):
"""Signifies that a remote class has raised an exception. """Signifies that a remote class has raised an exception.
@ -95,7 +75,7 @@ class Connection(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def create_consumer(self, topic, proxy, fanout=False): def create_consumer(self, conf, topic, proxy, fanout=False):
"""Create a consumer on this connection. """Create a consumer on this connection.
A consumer is associated with a message queue on the backend message A consumer is associated with a message queue on the backend message
@ -104,6 +84,7 @@ class Connection(object):
off of the queue will determine which method gets called on the proxy off of the queue will determine which method gets called on the proxy
object. object.
:param conf: An openstack.common.cfg configuration object.
:param topic: This is a name associated with what to consume from. :param topic: This is a name associated with what to consume from.
Multiple instances of a service may consume from the same Multiple instances of a service may consume from the same
topic. For example, all instances of nova-compute consume topic. For example, all instances of nova-compute consume
@ -197,7 +178,7 @@ def serialize_remote_exception(failure_info):
return json_data return json_data
def deserialize_remote_exception(data): def deserialize_remote_exception(conf, data):
failure = utils.loads(str(data)) failure = utils.loads(str(data))
trace = failure.get('tb', []) trace = failure.get('tb', [])
@ -207,7 +188,7 @@ def deserialize_remote_exception(data):
# NOTE(ameade): We DO NOT want to allow just any module to be imported, in # NOTE(ameade): We DO NOT want to allow just any module to be imported, in
# order to prevent arbitrary code execution. # order to prevent arbitrary code execution.
if not module in FLAGS.allowed_rpc_exception_modules: if not module in conf.allowed_rpc_exception_modules:
return RemoteError(name, failure.get('message'), trace) return RemoteError(name, failure.get('message'), trace)
try: try:

View File

@ -27,13 +27,10 @@ import traceback
import eventlet import eventlet
from nova import context from nova import context
from nova import flags
from nova.rpc import common as rpc_common from nova.rpc import common as rpc_common
CONSUMERS = {} CONSUMERS = {}
FLAGS = flags.FLAGS
class RpcContext(context.RequestContext): class RpcContext(context.RequestContext):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -116,7 +113,7 @@ class Connection(object):
pass pass
def create_connection(new=True): def create_connection(conf, new=True):
"""Create a connection""" """Create a connection"""
return Connection() return Connection()
@ -126,7 +123,7 @@ def check_serialize(msg):
json.dumps(msg) json.dumps(msg)
def multicall(context, topic, msg, timeout=None): def multicall(conf, context, topic, msg, timeout=None):
"""Make a call that returns multiple times.""" """Make a call that returns multiple times."""
check_serialize(msg) check_serialize(msg)
@ -144,9 +141,9 @@ def multicall(context, topic, msg, timeout=None):
return consumer.call(context, method, args, timeout) return consumer.call(context, method, args, timeout)
def call(context, topic, msg, timeout=None): def call(conf, context, topic, msg, timeout=None):
"""Sends a message on a topic and wait for a response.""" """Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg, timeout) rv = multicall(conf, context, topic, msg, timeout)
# NOTE(vish): return the last result from the multicall # NOTE(vish): return the last result from the multicall
rv = list(rv) rv = list(rv)
if not rv: if not rv:
@ -154,14 +151,14 @@ def call(context, topic, msg, timeout=None):
return rv[-1] return rv[-1]
def cast(context, topic, msg): def cast(conf, context, topic, msg):
try: try:
call(context, topic, msg) call(conf, context, topic, msg)
except Exception: except Exception:
pass pass
def notify(context, topic, msg): def notify(conf, context, topic, msg):
check_serialize(msg) check_serialize(msg)
@ -169,7 +166,7 @@ def cleanup():
pass pass
def fanout_cast(context, topic, msg): def fanout_cast(conf, context, topic, msg):
"""Cast to all consumers of a topic""" """Cast to all consumers of a topic"""
check_serialize(msg) check_serialize(msg)
method = msg.get('method') method = msg.get('method')
@ -182,3 +179,7 @@ def fanout_cast(context, topic, msg):
consumer.call(context, method, args, None) consumer.call(context, method, args, None)
except Exception: except Exception:
pass pass
def register_opts(conf):
pass

View File

@ -28,7 +28,6 @@ import kombu.entity
import kombu.messaging import kombu.messaging
import kombu.connection import kombu.connection
from nova import flags
from nova.openstack.common import cfg from nova.openstack.common import cfg
from nova.rpc import amqp as rpc_amqp from nova.rpc import amqp as rpc_amqp
from nova.rpc import common as rpc_common from nova.rpc import common as rpc_common
@ -49,8 +48,6 @@ kombu_opts = [
'(valid only if SSL enabled)')), '(valid only if SSL enabled)')),
] ]
FLAGS = flags.FLAGS
FLAGS.register_opts(kombu_opts)
LOG = rpc_common.LOG LOG = rpc_common.LOG
@ -126,7 +123,7 @@ class ConsumerBase(object):
class DirectConsumer(ConsumerBase): class DirectConsumer(ConsumerBase):
"""Queue/consumer class for 'direct'""" """Queue/consumer class for 'direct'"""
def __init__(self, channel, msg_id, callback, tag, **kwargs): def __init__(self, conf, channel, msg_id, callback, tag, **kwargs):
"""Init a 'direct' queue. """Init a 'direct' queue.
'channel' is the amqp channel to use 'channel' is the amqp channel to use
@ -159,7 +156,7 @@ class DirectConsumer(ConsumerBase):
class TopicConsumer(ConsumerBase): class TopicConsumer(ConsumerBase):
"""Consumer class for 'topic'""" """Consumer class for 'topic'"""
def __init__(self, channel, topic, callback, tag, **kwargs): def __init__(self, conf, channel, topic, callback, tag, **kwargs):
"""Init a 'topic' queue. """Init a 'topic' queue.
'channel' is the amqp channel to use 'channel' is the amqp channel to use
@ -170,12 +167,12 @@ class TopicConsumer(ConsumerBase):
Other kombu options may be passed Other kombu options may be passed
""" """
# Default options # Default options
options = {'durable': FLAGS.rabbit_durable_queues, options = {'durable': conf.rabbit_durable_queues,
'auto_delete': False, 'auto_delete': False,
'exclusive': False} 'exclusive': False}
options.update(kwargs) options.update(kwargs)
exchange = kombu.entity.Exchange( exchange = kombu.entity.Exchange(
name=FLAGS.control_exchange, name=conf.control_exchange,
type='topic', type='topic',
durable=options['durable'], durable=options['durable'],
auto_delete=options['auto_delete']) auto_delete=options['auto_delete'])
@ -192,7 +189,7 @@ class TopicConsumer(ConsumerBase):
class FanoutConsumer(ConsumerBase): class FanoutConsumer(ConsumerBase):
"""Consumer class for 'fanout'""" """Consumer class for 'fanout'"""
def __init__(self, channel, topic, callback, tag, **kwargs): def __init__(self, conf, channel, topic, callback, tag, **kwargs):
"""Init a 'fanout' queue. """Init a 'fanout' queue.
'channel' is the amqp channel to use 'channel' is the amqp channel to use
@ -252,7 +249,7 @@ class Publisher(object):
class DirectPublisher(Publisher): class DirectPublisher(Publisher):
"""Publisher class for 'direct'""" """Publisher class for 'direct'"""
def __init__(self, channel, msg_id, **kwargs): def __init__(self, conf, channel, msg_id, **kwargs):
"""init a 'direct' publisher. """init a 'direct' publisher.
Kombu options may be passed as keyword args to override defaults Kombu options may be passed as keyword args to override defaults
@ -271,17 +268,17 @@ class DirectPublisher(Publisher):
class TopicPublisher(Publisher): class TopicPublisher(Publisher):
"""Publisher class for 'topic'""" """Publisher class for 'topic'"""
def __init__(self, channel, topic, **kwargs): def __init__(self, conf, channel, topic, **kwargs):
"""init a 'topic' publisher. """init a 'topic' publisher.
Kombu options may be passed as keyword args to override defaults Kombu options may be passed as keyword args to override defaults
""" """
options = {'durable': FLAGS.rabbit_durable_queues, options = {'durable': conf.rabbit_durable_queues,
'auto_delete': False, 'auto_delete': False,
'exclusive': False} 'exclusive': False}
options.update(kwargs) options.update(kwargs)
super(TopicPublisher, self).__init__(channel, super(TopicPublisher, self).__init__(channel,
FLAGS.control_exchange, conf.control_exchange,
topic, topic,
type='topic', type='topic',
**options) **options)
@ -289,7 +286,7 @@ class TopicPublisher(Publisher):
class FanoutPublisher(Publisher): class FanoutPublisher(Publisher):
"""Publisher class for 'fanout'""" """Publisher class for 'fanout'"""
def __init__(self, channel, topic, **kwargs): def __init__(self, conf, channel, topic, **kwargs):
"""init a 'fanout' publisher. """init a 'fanout' publisher.
Kombu options may be passed as keyword args to override defaults Kombu options may be passed as keyword args to override defaults
@ -308,9 +305,9 @@ class FanoutPublisher(Publisher):
class NotifyPublisher(TopicPublisher): class NotifyPublisher(TopicPublisher):
"""Publisher class for 'notify'""" """Publisher class for 'notify'"""
def __init__(self, *args, **kwargs): def __init__(self, conf, channel, topic, **kwargs):
self.durable = kwargs.pop('durable', FLAGS.rabbit_durable_queues) self.durable = kwargs.pop('durable', conf.rabbit_durable_queues)
super(NotifyPublisher, self).__init__(*args, **kwargs) super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs)
def reconnect(self, channel): def reconnect(self, channel):
super(NotifyPublisher, self).reconnect(channel) super(NotifyPublisher, self).reconnect(channel)
@ -329,15 +326,18 @@ class NotifyPublisher(TopicPublisher):
class Connection(object): class Connection(object):
"""Connection object.""" """Connection object."""
def __init__(self, server_params=None): pool = None
def __init__(self, conf, server_params=None):
self.consumers = [] self.consumers = []
self.consumer_thread = None self.consumer_thread = None
self.max_retries = FLAGS.rabbit_max_retries self.conf = conf
self.max_retries = self.conf.rabbit_max_retries
# Try forever? # Try forever?
if self.max_retries <= 0: if self.max_retries <= 0:
self.max_retries = None self.max_retries = None
self.interval_start = FLAGS.rabbit_retry_interval self.interval_start = self.conf.rabbit_retry_interval
self.interval_stepping = FLAGS.rabbit_retry_backoff self.interval_stepping = self.conf.rabbit_retry_backoff
# max retry-interval = 30 seconds # max retry-interval = 30 seconds
self.interval_max = 30 self.interval_max = 30
self.memory_transport = False self.memory_transport = False
@ -353,21 +353,21 @@ class Connection(object):
p_key = server_params_to_kombu_params.get(sp_key, sp_key) p_key = server_params_to_kombu_params.get(sp_key, sp_key)
params[p_key] = value params[p_key] = value
params.setdefault('hostname', FLAGS.rabbit_host) params.setdefault('hostname', self.conf.rabbit_host)
params.setdefault('port', FLAGS.rabbit_port) params.setdefault('port', self.conf.rabbit_port)
params.setdefault('userid', FLAGS.rabbit_userid) params.setdefault('userid', self.conf.rabbit_userid)
params.setdefault('password', FLAGS.rabbit_password) params.setdefault('password', self.conf.rabbit_password)
params.setdefault('virtual_host', FLAGS.rabbit_virtual_host) params.setdefault('virtual_host', self.conf.rabbit_virtual_host)
self.params = params self.params = params
if FLAGS.fake_rabbit: if self.conf.fake_rabbit:
self.params['transport'] = 'memory' self.params['transport'] = 'memory'
self.memory_transport = True self.memory_transport = True
else: else:
self.memory_transport = False self.memory_transport = False
if FLAGS.rabbit_use_ssl: if self.conf.rabbit_use_ssl:
self.params['ssl'] = self._fetch_ssl_params() self.params['ssl'] = self._fetch_ssl_params()
self.connection = None self.connection = None
@ -379,14 +379,14 @@ class Connection(object):
ssl_params = dict() ssl_params = dict()
# http://docs.python.org/library/ssl.html - ssl.wrap_socket # http://docs.python.org/library/ssl.html - ssl.wrap_socket
if FLAGS.kombu_ssl_version: if self.conf.kombu_ssl_version:
ssl_params['ssl_version'] = FLAGS.kombu_ssl_version ssl_params['ssl_version'] = self.conf.kombu_ssl_version
if FLAGS.kombu_ssl_keyfile: if self.conf.kombu_ssl_keyfile:
ssl_params['keyfile'] = FLAGS.kombu_ssl_keyfile ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile
if FLAGS.kombu_ssl_certfile: if self.conf.kombu_ssl_certfile:
ssl_params['certfile'] = FLAGS.kombu_ssl_certfile ssl_params['certfile'] = self.conf.kombu_ssl_certfile
if FLAGS.kombu_ssl_ca_certs: if self.conf.kombu_ssl_ca_certs:
ssl_params['ca_certs'] = FLAGS.kombu_ssl_ca_certs ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs
# We might want to allow variations in the # We might want to allow variations in the
# future with this? # future with this?
ssl_params['cert_reqs'] = ssl.CERT_REQUIRED ssl_params['cert_reqs'] = ssl.CERT_REQUIRED
@ -534,7 +534,7 @@ class Connection(object):
"%(err_str)s") % log_info) "%(err_str)s") % log_info)
def _declare_consumer(): def _declare_consumer():
consumer = consumer_cls(self.channel, topic, callback, consumer = consumer_cls(self.conf, self.channel, topic, callback,
self.consumer_num.next()) self.consumer_num.next())
self.consumers.append(consumer) self.consumers.append(consumer)
return consumer return consumer
@ -590,7 +590,7 @@ class Connection(object):
"'%(topic)s': %(err_str)s") % log_info) "'%(topic)s': %(err_str)s") % log_info)
def _publish(): def _publish():
publisher = cls(self.channel, topic, **kwargs) publisher = cls(self.conf, self.channel, topic, **kwargs)
publisher.send(msg) publisher.send(msg)
self.ensure(_error_callback, _publish) self.ensure(_error_callback, _publish)
@ -648,58 +648,66 @@ class Connection(object):
def create_consumer(self, topic, proxy, fanout=False): def create_consumer(self, topic, proxy, fanout=False):
"""Create a consumer that calls a method in a proxy object""" """Create a consumer that calls a method in a proxy object"""
proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
rpc_amqp.get_connection_pool(self, Connection))
if fanout: if fanout:
self.declare_fanout_consumer(topic, self.declare_fanout_consumer(topic, proxy_cb)
rpc_amqp.ProxyCallback(proxy, Connection.pool))
else: else:
self.declare_topic_consumer(topic, self.declare_topic_consumer(topic, proxy_cb)
rpc_amqp.ProxyCallback(proxy, Connection.pool))
Connection.pool = rpc_amqp.Pool(connection_cls=Connection) def create_connection(conf, new=True):
def create_connection(new=True):
"""Create a connection""" """Create a connection"""
return rpc_amqp.create_connection(new, Connection.pool) return rpc_amqp.create_connection(conf, new,
rpc_amqp.get_connection_pool(conf, Connection))
def multicall(context, topic, msg, timeout=None): def multicall(conf, context, topic, msg, timeout=None):
"""Make a call that returns multiple times.""" """Make a call that returns multiple times."""
return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool) return rpc_amqp.multicall(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def call(context, topic, msg, timeout=None): def call(conf, context, topic, msg, timeout=None):
"""Sends a message on a topic and wait for a response.""" """Sends a message on a topic and wait for a response."""
return rpc_amqp.call(context, topic, msg, timeout, Connection.pool) return rpc_amqp.call(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def cast(context, topic, msg): def cast(conf, context, topic, msg):
"""Sends a message on a topic without waiting for a response.""" """Sends a message on a topic without waiting for a response."""
return rpc_amqp.cast(context, topic, msg, Connection.pool) return rpc_amqp.cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast(context, topic, msg): def fanout_cast(conf, context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response.""" """Sends a message on a fanout exchange without waiting for a response."""
return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool) return rpc_amqp.fanout_cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cast_to_server(context, server_params, topic, msg): def cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a topic to a specific server.""" """Sends a message on a topic to a specific server."""
return rpc_amqp.cast_to_server(context, server_params, topic, msg, return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
Connection.pool) rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast_to_server(context, server_params, topic, msg): def fanout_cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a fanout exchange to a specific server.""" """Sends a message on a fanout exchange to a specific server."""
return rpc_amqp.cast_to_server(context, server_params, topic, msg, return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
Connection.pool) rpc_amqp.get_connection_pool(conf, Connection))
def notify(context, topic, msg): def notify(conf, context, topic, msg):
"""Sends a notification event on a topic.""" """Sends a notification event on a topic."""
return rpc_amqp.notify(context, topic, msg, Connection.pool) return rpc_amqp.notify(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cleanup(): def cleanup():
return rpc_amqp.cleanup(Connection.pool) return rpc_amqp.cleanup(Connection.pool)
def register_opts(conf):
conf.register_opts(kombu_opts)

View File

@ -25,7 +25,6 @@ import greenlet
import qpid.messaging import qpid.messaging
import qpid.messaging.exceptions import qpid.messaging.exceptions
from nova import flags
from nova import log as logging from nova import log as logging
from nova.openstack.common import cfg from nova.openstack.common import cfg
from nova.rpc import amqp as rpc_amqp from nova.rpc import amqp as rpc_amqp
@ -78,9 +77,6 @@ qpid_opts = [
help='Disable Nagle algorithm'), help='Disable Nagle algorithm'),
] ]
FLAGS = flags.FLAGS
FLAGS.register_opts(qpid_opts)
class ConsumerBase(object): class ConsumerBase(object):
"""Consumer base class.""" """Consumer base class."""
@ -147,7 +143,7 @@ class ConsumerBase(object):
class DirectConsumer(ConsumerBase): class DirectConsumer(ConsumerBase):
"""Queue/consumer class for 'direct'""" """Queue/consumer class for 'direct'"""
def __init__(self, session, msg_id, callback): def __init__(self, conf, session, msg_id, callback):
"""Init a 'direct' queue. """Init a 'direct' queue.
'session' is the amqp session to use 'session' is the amqp session to use
@ -165,7 +161,7 @@ class DirectConsumer(ConsumerBase):
class TopicConsumer(ConsumerBase): class TopicConsumer(ConsumerBase):
"""Consumer class for 'topic'""" """Consumer class for 'topic'"""
def __init__(self, session, topic, callback): def __init__(self, conf, session, topic, callback):
"""Init a 'topic' queue. """Init a 'topic' queue.
'session' is the amqp session to use 'session' is the amqp session to use
@ -174,14 +170,14 @@ class TopicConsumer(ConsumerBase):
""" """
super(TopicConsumer, self).__init__(session, callback, super(TopicConsumer, self).__init__(session, callback,
"%s/%s" % (FLAGS.control_exchange, topic), {}, "%s/%s" % (conf.control_exchange, topic), {},
topic, {}) topic, {})
class FanoutConsumer(ConsumerBase): class FanoutConsumer(ConsumerBase):
"""Consumer class for 'fanout'""" """Consumer class for 'fanout'"""
def __init__(self, session, topic, callback): def __init__(self, conf, session, topic, callback):
"""Init a 'fanout' queue. """Init a 'fanout' queue.
'session' is the amqp session to use 'session' is the amqp session to use
@ -236,7 +232,7 @@ class Publisher(object):
class DirectPublisher(Publisher): class DirectPublisher(Publisher):
"""Publisher class for 'direct'""" """Publisher class for 'direct'"""
def __init__(self, session, msg_id): def __init__(self, conf, session, msg_id):
"""Init a 'direct' publisher.""" """Init a 'direct' publisher."""
super(DirectPublisher, self).__init__(session, msg_id, super(DirectPublisher, self).__init__(session, msg_id,
{"type": "Direct"}) {"type": "Direct"})
@ -244,16 +240,16 @@ class DirectPublisher(Publisher):
class TopicPublisher(Publisher): class TopicPublisher(Publisher):
"""Publisher class for 'topic'""" """Publisher class for 'topic'"""
def __init__(self, session, topic): def __init__(self, conf, session, topic):
"""init a 'topic' publisher. """init a 'topic' publisher.
""" """
super(TopicPublisher, self).__init__(session, super(TopicPublisher, self).__init__(session,
"%s/%s" % (FLAGS.control_exchange, topic)) "%s/%s" % (conf.control_exchange, topic))
class FanoutPublisher(Publisher): class FanoutPublisher(Publisher):
"""Publisher class for 'fanout'""" """Publisher class for 'fanout'"""
def __init__(self, session, topic): def __init__(self, conf, session, topic):
"""init a 'fanout' publisher. """init a 'fanout' publisher.
""" """
super(FanoutPublisher, self).__init__(session, super(FanoutPublisher, self).__init__(session,
@ -262,29 +258,32 @@ class FanoutPublisher(Publisher):
class NotifyPublisher(Publisher): class NotifyPublisher(Publisher):
"""Publisher class for notifications""" """Publisher class for notifications"""
def __init__(self, session, topic): def __init__(self, conf, session, topic):
"""init a 'topic' publisher. """init a 'topic' publisher.
""" """
super(NotifyPublisher, self).__init__(session, super(NotifyPublisher, self).__init__(session,
"%s/%s" % (FLAGS.control_exchange, topic), "%s/%s" % (conf.control_exchange, topic),
{"durable": True}) {"durable": True})
class Connection(object): class Connection(object):
"""Connection object.""" """Connection object."""
def __init__(self, server_params=None): pool = None
def __init__(self, conf, server_params=None):
self.session = None self.session = None
self.consumers = {} self.consumers = {}
self.consumer_thread = None self.consumer_thread = None
self.conf = conf
if server_params is None: if server_params is None:
server_params = {} server_params = {}
default_params = dict(hostname=FLAGS.qpid_hostname, default_params = dict(hostname=self.conf.qpid_hostname,
port=FLAGS.qpid_port, port=self.conf.qpid_port,
username=FLAGS.qpid_username, username=self.conf.qpid_username,
password=FLAGS.qpid_password) password=self.conf.qpid_password)
params = server_params params = server_params
for key in default_params.keys(): for key in default_params.keys():
@ -298,23 +297,25 @@ class Connection(object):
# before we call open # before we call open
self.connection.username = params['username'] self.connection.username = params['username']
self.connection.password = params['password'] self.connection.password = params['password']
self.connection.sasl_mechanisms = FLAGS.qpid_sasl_mechanisms self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms
self.connection.reconnect = FLAGS.qpid_reconnect self.connection.reconnect = self.conf.qpid_reconnect
if FLAGS.qpid_reconnect_timeout: if self.conf.qpid_reconnect_timeout:
self.connection.reconnect_timeout = FLAGS.qpid_reconnect_timeout self.connection.reconnect_timeout = (
if FLAGS.qpid_reconnect_limit: self.conf.qpid_reconnect_timeout)
self.connection.reconnect_limit = FLAGS.qpid_reconnect_limit if self.conf.qpid_reconnect_limit:
if FLAGS.qpid_reconnect_interval_max: self.connection.reconnect_limit = self.conf.qpid_reconnect_limit
if self.conf.qpid_reconnect_interval_max:
self.connection.reconnect_interval_max = ( self.connection.reconnect_interval_max = (
FLAGS.qpid_reconnect_interval_max) self.conf.qpid_reconnect_interval_max)
if FLAGS.qpid_reconnect_interval_min: if self.conf.qpid_reconnect_interval_min:
self.connection.reconnect_interval_min = ( self.connection.reconnect_interval_min = (
FLAGS.qpid_reconnect_interval_min) self.conf.qpid_reconnect_interval_min)
if FLAGS.qpid_reconnect_interval: if self.conf.qpid_reconnect_interval:
self.connection.reconnect_interval = FLAGS.qpid_reconnect_interval self.connection.reconnect_interval = (
self.connection.hearbeat = FLAGS.qpid_heartbeat self.conf.qpid_reconnect_interval)
self.connection.protocol = FLAGS.qpid_protocol self.connection.hearbeat = self.conf.qpid_heartbeat
self.connection.tcp_nodelay = FLAGS.qpid_tcp_nodelay self.connection.protocol = self.conf.qpid_protocol
self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay
# Open is part of reconnect - # Open is part of reconnect -
# NOTE(WGH) not sure we need this with the reconnect flags # NOTE(WGH) not sure we need this with the reconnect flags
@ -339,7 +340,7 @@ class Connection(object):
self.connection.open() self.connection.open()
except qpid.messaging.exceptions.ConnectionError, e: except qpid.messaging.exceptions.ConnectionError, e:
LOG.error(_('Unable to connect to AMQP server: %s'), e) LOG.error(_('Unable to connect to AMQP server: %s'), e)
time.sleep(FLAGS.qpid_reconnect_interval or 1) time.sleep(self.conf.qpid_reconnect_interval or 1)
else: else:
break break
@ -386,7 +387,7 @@ class Connection(object):
"%(err_str)s") % log_info) "%(err_str)s") % log_info)
def _declare_consumer(): def _declare_consumer():
consumer = consumer_cls(self.session, topic, callback) consumer = consumer_cls(self.conf, self.session, topic, callback)
self._register_consumer(consumer) self._register_consumer(consumer)
return consumer return consumer
@ -435,7 +436,7 @@ class Connection(object):
"'%(topic)s': %(err_str)s") % log_info) "'%(topic)s': %(err_str)s") % log_info)
def _publisher_send(): def _publisher_send():
publisher = cls(self.session, topic) publisher = cls(self.conf, self.session, topic)
publisher.send(msg) publisher.send(msg)
return self.ensure(_connect_error, _publisher_send) return self.ensure(_connect_error, _publisher_send)
@ -493,60 +494,70 @@ class Connection(object):
def create_consumer(self, topic, proxy, fanout=False): def create_consumer(self, topic, proxy, fanout=False):
"""Create a consumer that calls a method in a proxy object""" """Create a consumer that calls a method in a proxy object"""
proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
rpc_amqp.get_connection_pool(self, Connection))
if fanout: if fanout:
consumer = FanoutConsumer(self.session, topic, consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb)
rpc_amqp.ProxyCallback(proxy, Connection.pool))
else: else:
consumer = TopicConsumer(self.session, topic, consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb)
rpc_amqp.ProxyCallback(proxy, Connection.pool))
self._register_consumer(consumer) self._register_consumer(consumer)
return consumer return consumer
Connection.pool = rpc_amqp.Pool(connection_cls=Connection) def create_connection(conf, new=True):
def create_connection(new=True):
"""Create a connection""" """Create a connection"""
return rpc_amqp.create_connection(new, Connection.pool) return rpc_amqp.create_connection(conf, new,
rpc_amqp.get_connection_pool(conf, Connection))
def multicall(context, topic, msg, timeout=None): def multicall(conf, context, topic, msg, timeout=None):
"""Make a call that returns multiple times.""" """Make a call that returns multiple times."""
return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool) return rpc_amqp.multicall(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def call(context, topic, msg, timeout=None): def call(conf, context, topic, msg, timeout=None):
"""Sends a message on a topic and wait for a response.""" """Sends a message on a topic and wait for a response."""
return rpc_amqp.call(context, topic, msg, timeout, Connection.pool) return rpc_amqp.call(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def cast(context, topic, msg): def cast(conf, context, topic, msg):
"""Sends a message on a topic without waiting for a response.""" """Sends a message on a topic without waiting for a response."""
return rpc_amqp.cast(context, topic, msg, Connection.pool) return rpc_amqp.cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast(context, topic, msg): def fanout_cast(conf, context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response.""" """Sends a message on a fanout exchange without waiting for a response."""
return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool) return rpc_amqp.fanout_cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cast_to_server(context, server_params, topic, msg): def cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a topic to a specific server.""" """Sends a message on a topic to a specific server."""
return rpc_amqp.cast_to_server(context, server_params, topic, msg, return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
Connection.pool) rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast_to_server(context, server_params, topic, msg): def fanout_cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a fanout exchange to a specific server.""" """Sends a message on a fanout exchange to a specific server."""
return rpc_amqp.fanout_cast_to_server(context, server_params, topic, return rpc_amqp.fanout_cast_to_server(conf, context, server_params, topic,
msg, Connection.pool) msg, rpc_amqp.get_connection_pool(conf, Connection))
def notify(context, topic, msg): def notify(conf, context, topic, msg):
"""Sends a notification event on a topic.""" """Sends a notification event on a topic."""
return rpc_amqp.notify(context, topic, msg, Connection.pool) return rpc_amqp.notify(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cleanup(): def cleanup():
return rpc_amqp.cleanup(Connection.pool) return rpc_amqp.cleanup(Connection.pool)
def register_opts(conf):
conf.register_opts(qpid_opts)

View File

@ -177,6 +177,7 @@ class Service(object):
LOG.audit(_('Starting %(topic)s node (version %(vcs_string)s)'), LOG.audit(_('Starting %(topic)s node (version %(vcs_string)s)'),
{'topic': self.topic, 'vcs_string': vcs_string}) {'topic': self.topic, 'vcs_string': vcs_string})
utils.cleanup_file_locks() utils.cleanup_file_locks()
rpc.register_opts(FLAGS)
self.manager.init_host() self.manager.init_host()
self.model_disconnected = False self.model_disconnected = False
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
@ -393,6 +394,7 @@ class WSGIService(object):
""" """
utils.cleanup_file_locks() utils.cleanup_file_locks()
rpc.register_opts(FLAGS)
if self.manager: if self.manager:
self.manager.init_host() self.manager.init_host()
self.server.start() self.server.start()

View File

@ -59,6 +59,9 @@ def reset_db():
def setup(): def setup():
import mox # Fail fast if you don't have mox. Workaround for bug 810424 import mox # Fail fast if you don't have mox. Workaround for bug 810424
from nova import rpc # Register rpc_backend before fake_flags sets it
FLAGS.register_opts(rpc.rpc_opts)
from nova import context from nova import context
from nova import db from nova import db
from nova.db import migration from nova.db import migration

View File

@ -26,19 +26,21 @@ import nose
from nova import context from nova import context
from nova import exception from nova import exception
from nova import flags
from nova import log as logging from nova import log as logging
from nova.rpc import amqp as rpc_amqp from nova.rpc import amqp as rpc_amqp
from nova.rpc import common as rpc_common from nova.rpc import common as rpc_common
from nova import test from nova import test
FLAGS = flags.FLAGS
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class BaseRpcTestCase(test.TestCase): class BaseRpcTestCase(test.TestCase):
def setUp(self, supports_timeouts=True): def setUp(self, supports_timeouts=True):
super(BaseRpcTestCase, self).setUp() super(BaseRpcTestCase, self).setUp()
self.conn = self.rpc.create_connection(True) self.conn = self.rpc.create_connection(FLAGS, True)
self.receiver = TestReceiver() self.receiver = TestReceiver()
self.conn.create_consumer('test', self.receiver, False) self.conn.create_consumer('test', self.receiver, False)
self.conn.consume_in_thread() self.conn.consume_in_thread()
@ -51,20 +53,20 @@ class BaseRpcTestCase(test.TestCase):
def test_call_succeed(self): def test_call_succeed(self):
value = 42 value = 42
result = self.rpc.call(self.context, 'test', {"method": "echo", result = self.rpc.call(FLAGS, self.context, 'test',
"args": {"value": value}}) {"method": "echo", "args": {"value": value}})
self.assertEqual(value, result) self.assertEqual(value, result)
def test_call_succeed_despite_multiple_returns_yield(self): def test_call_succeed_despite_multiple_returns_yield(self):
value = 42 value = 42
result = self.rpc.call(self.context, 'test', result = self.rpc.call(FLAGS, self.context, 'test',
{"method": "echo_three_times_yield", {"method": "echo_three_times_yield",
"args": {"value": value}}) "args": {"value": value}})
self.assertEqual(value + 2, result) self.assertEqual(value + 2, result)
def test_multicall_succeed_once(self): def test_multicall_succeed_once(self):
value = 42 value = 42
result = self.rpc.multicall(self.context, result = self.rpc.multicall(FLAGS, self.context,
'test', 'test',
{"method": "echo", {"method": "echo",
"args": {"value": value}}) "args": {"value": value}})
@ -75,7 +77,7 @@ class BaseRpcTestCase(test.TestCase):
def test_multicall_three_nones(self): def test_multicall_three_nones(self):
value = 42 value = 42
result = self.rpc.multicall(self.context, result = self.rpc.multicall(FLAGS, self.context,
'test', 'test',
{"method": "multicall_three_nones", {"method": "multicall_three_nones",
"args": {"value": value}}) "args": {"value": value}})
@ -86,7 +88,7 @@ class BaseRpcTestCase(test.TestCase):
def test_multicall_succeed_three_times_yield(self): def test_multicall_succeed_three_times_yield(self):
value = 42 value = 42
result = self.rpc.multicall(self.context, result = self.rpc.multicall(FLAGS, self.context,
'test', 'test',
{"method": "echo_three_times_yield", {"method": "echo_three_times_yield",
"args": {"value": value}}) "args": {"value": value}})
@ -96,7 +98,7 @@ class BaseRpcTestCase(test.TestCase):
def test_context_passed(self): def test_context_passed(self):
"""Makes sure a context is passed through rpc call.""" """Makes sure a context is passed through rpc call."""
value = 42 value = 42
result = self.rpc.call(self.context, result = self.rpc.call(FLAGS, self.context,
'test', {"method": "context", 'test', {"method": "context",
"args": {"value": value}}) "args": {"value": value}})
self.assertEqual(self.context.to_dict(), result) self.assertEqual(self.context.to_dict(), result)
@ -112,7 +114,7 @@ class BaseRpcTestCase(test.TestCase):
# TODO(comstud): # TODO(comstud):
# so, it will replay the context and use the same REQID? # so, it will replay the context and use the same REQID?
# that's bizarre. # that's bizarre.
ret = self.rpc.call(context, ret = self.rpc.call(FLAGS, context,
queue, queue,
{"method": "echo", {"method": "echo",
"args": {"value": value}}) "args": {"value": value}})
@ -120,11 +122,11 @@ class BaseRpcTestCase(test.TestCase):
return value return value
nested = Nested() nested = Nested()
conn = self.rpc.create_connection(True) conn = self.rpc.create_connection(FLAGS, True)
conn.create_consumer('nested', nested, False) conn.create_consumer('nested', nested, False)
conn.consume_in_thread() conn.consume_in_thread()
value = 42 value = 42
result = self.rpc.call(self.context, result = self.rpc.call(FLAGS, self.context,
'nested', {"method": "echo", 'nested', {"method": "echo",
"args": {"queue": "test", "args": {"queue": "test",
"value": value}}) "value": value}})
@ -139,12 +141,12 @@ class BaseRpcTestCase(test.TestCase):
value = 42 value = 42
self.assertRaises(rpc_common.Timeout, self.assertRaises(rpc_common.Timeout,
self.rpc.call, self.rpc.call,
self.context, FLAGS, self.context,
'test', 'test',
{"method": "block", {"method": "block",
"args": {"value": value}}, timeout=1) "args": {"value": value}}, timeout=1)
try: try:
self.rpc.call(self.context, self.rpc.call(FLAGS, self.context,
'test', 'test',
{"method": "block", {"method": "block",
"args": {"value": value}}, "args": {"value": value}},
@ -169,8 +171,8 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase):
self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context) self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context)
value = 41 value = 41
self.rpc.cast(self.context, 'test', {"method": "echo", self.rpc.cast(FLAGS, self.context, 'test',
"args": {"value": value}}) {"method": "echo", "args": {"value": value}})
# Wait for the cast to complete. # Wait for the cast to complete.
for x in xrange(50): for x in xrange(50):
@ -185,7 +187,7 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase):
self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack) self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack)
value = 42 value = 42
result = self.rpc.call(self.context, 'test', result = self.rpc.call(FLAGS, self.context, 'test',
{"method": "echo", {"method": "echo",
"args": {"value": value}}) "args": {"value": value}})
self.assertEqual(value, result) self.assertEqual(value, result)

View File

@ -93,7 +93,7 @@ class RpcCommonTestCase(test.TestCase):
} }
serialized = json.dumps(failure) serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized) after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, exception.NovaException)) self.assertTrue(isinstance(after_exc, exception.NovaException))
self.assertTrue('test message' in unicode(after_exc)) self.assertTrue('test message' in unicode(after_exc))
#assure the traceback was added #assure the traceback was added
@ -108,7 +108,7 @@ class RpcCommonTestCase(test.TestCase):
} }
serialized = json.dumps(failure) serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized) after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
def test_deserialize_remote_exception_user_defined_exception(self): def test_deserialize_remote_exception_user_defined_exception(self):
@ -121,7 +121,7 @@ class RpcCommonTestCase(test.TestCase):
} }
serialized = json.dumps(failure) serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized) after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, FakeUserDefinedException)) self.assertTrue(isinstance(after_exc, FakeUserDefinedException))
#assure the traceback was added #assure the traceback was added
self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc)) self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc))
@ -141,7 +141,7 @@ class RpcCommonTestCase(test.TestCase):
} }
serialized = json.dumps(failure) serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized) after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
#assure the traceback was added #assure the traceback was added
self.assertTrue('raise FakeIDontExistException' in unicode(after_exc)) self.assertTrue('raise FakeIDontExistException' in unicode(after_exc))

View File

@ -53,6 +53,7 @@ def _raise_exc_stub(stubs, times, obj, method, exc_msg,
class RpcKombuTestCase(common.BaseRpcAMQPTestCase): class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def setUp(self): def setUp(self):
self.rpc = impl_kombu self.rpc = impl_kombu
impl_kombu.register_opts(FLAGS)
super(RpcKombuTestCase, self).setUp() super(RpcKombuTestCase, self).setUp()
def tearDown(self): def tearDown(self):
@ -61,10 +62,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def test_reusing_connection(self): def test_reusing_connection(self):
"""Test that reusing a connection returns same one.""" """Test that reusing a connection returns same one."""
conn_context = self.rpc.create_connection(new=False) conn_context = self.rpc.create_connection(FLAGS, new=False)
conn1 = conn_context.connection conn1 = conn_context.connection
conn_context.close() conn_context.close()
conn_context = self.rpc.create_connection(new=False) conn_context = self.rpc.create_connection(FLAGS, new=False)
conn2 = conn_context.connection conn2 = conn_context.connection
conn_context.close() conn_context.close()
self.assertEqual(conn1, conn2) self.assertEqual(conn1, conn2)
@ -72,7 +73,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def test_topic_send_receive(self): def test_topic_send_receive(self):
"""Test sending to a topic exchange/queue""" """Test sending to a topic exchange/queue"""
conn = self.rpc.create_connection() conn = self.rpc.create_connection(FLAGS)
message = 'topic test message' message = 'topic test message'
self.received_message = None self.received_message = None
@ -89,7 +90,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def test_direct_send_receive(self): def test_direct_send_receive(self):
"""Test sending to a direct exchange/queue""" """Test sending to a direct exchange/queue"""
conn = self.rpc.create_connection() conn = self.rpc.create_connection(FLAGS)
message = 'direct test message' message = 'direct test message'
self.received_message = None self.received_message = None
@ -123,10 +124,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def topic_send(_context, topic, msg): def topic_send(_context, topic, msg):
pass pass
MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
self.stubs.Set(impl_kombu, 'Connection', MyConnection) self.stubs.Set(impl_kombu, 'Connection', MyConnection)
impl_kombu.cast(ctxt, 'fake_topic', {'msg': 'fake'}) impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'})
def test_cast_to_server_uses_server_params(self): def test_cast_to_server_uses_server_params(self):
"""Test kombu rpc.cast""" """Test kombu rpc.cast"""
@ -153,10 +154,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def topic_send(_context, topic, msg): def topic_send(_context, topic, msg):
pass pass
MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
self.stubs.Set(impl_kombu, 'Connection', MyConnection) self.stubs.Set(impl_kombu, 'Connection', MyConnection)
impl_kombu.cast_to_server(ctxt, server_params, impl_kombu.cast_to_server(FLAGS, ctxt, server_params,
'fake_topic', {'msg': 'fake'}) 'fake_topic', {'msg': 'fake'})
@test.skip_test("kombu memory transport seems buggy with fanout queues " @test.skip_test("kombu memory transport seems buggy with fanout queues "
@ -192,7 +193,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
'__init__', 'foo timeout foo') '__init__', 'foo timeout foo')
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
result = conn.declare_consumer(self.rpc.DirectConsumer, result = conn.declare_consumer(self.rpc.DirectConsumer,
'test_topic', None) 'test_topic', None)
@ -206,7 +207,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer, info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer,
'__init__', 'meow') '__init__', 'meow')
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
conn.connection_errors = (MyException, ) conn.connection_errors = (MyException, )
result = conn.declare_consumer(self.rpc.DirectConsumer, result = conn.declare_consumer(self.rpc.DirectConsumer,
@ -220,7 +221,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
'__init__', 'Socket closed', exc_class=IOError) '__init__', 'Socket closed', exc_class=IOError)
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
result = conn.declare_consumer(self.rpc.DirectConsumer, result = conn.declare_consumer(self.rpc.DirectConsumer,
'test_topic', None) 'test_topic', None)
@ -234,7 +235,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
'__init__', 'foo timeout foo') '__init__', 'foo timeout foo')
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
self.assertEqual(info['called'], 3) self.assertEqual(info['called'], 3)
@ -243,7 +244,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
'send', 'foo timeout foo') 'send', 'foo timeout foo')
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
self.assertEqual(info['called'], 3) self.assertEqual(info['called'], 3)
@ -256,7 +257,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
'__init__', 'meow') '__init__', 'meow')
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
conn.connection_errors = (MyException, ) conn.connection_errors = (MyException, )
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
@ -267,7 +268,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
'send', 'meow') 'send', 'meow')
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
conn.connection_errors = (MyException, ) conn.connection_errors = (MyException, )
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
@ -275,7 +276,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
self.assertEqual(info['called'], 2) self.assertEqual(info['called'], 2)
def test_iterconsume_errors_will_reconnect(self): def test_iterconsume_errors_will_reconnect(self):
conn = self.rpc.Connection() conn = self.rpc.Connection(FLAGS)
message = 'reconnect test message' message = 'reconnect test message'
self.received_message = None self.received_message = None
@ -305,12 +306,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
value = "This is the exception message" value = "This is the exception message"
self.assertRaises(NotImplementedError, self.assertRaises(NotImplementedError,
self.rpc.call, self.rpc.call,
FLAGS,
self.context, self.context,
'test', 'test',
{"method": "fail", {"method": "fail",
"args": {"value": value}}) "args": {"value": value}})
try: try:
self.rpc.call(self.context, self.rpc.call(FLAGS, self.context,
'test', 'test',
{"method": "fail", {"method": "fail",
"args": {"value": value}}) "args": {"value": value}})
@ -330,12 +332,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
value = "This is the exception message" value = "This is the exception message"
self.assertRaises(exception.ConvertedException, self.assertRaises(exception.ConvertedException,
self.rpc.call, self.rpc.call,
FLAGS,
self.context, self.context,
'test', 'test',
{"method": "fail_converted", {"method": "fail_converted",
"args": {"value": value}}) "args": {"value": value}})
try: try:
self.rpc.call(self.context, self.rpc.call(FLAGS, self.context,
'test', 'test',
{"method": "fail_converted", {"method": "fail_converted",
"args": {"value": value}}) "args": {"value": value}})

View File

@ -19,6 +19,7 @@
Unit Tests for remote procedure calls using kombu + ssl Unit Tests for remote procedure calls using kombu + ssl
""" """
from nova import flags
from nova import test from nova import test
from nova.rpc import impl_kombu from nova.rpc import impl_kombu
@ -28,11 +29,14 @@ SSL_CERT = "/tmp/cert.blah.blah"
SSL_CA_CERT = "/tmp/cert.ca.blah.blah" SSL_CA_CERT = "/tmp/cert.ca.blah.blah"
SSL_KEYFILE = "/tmp/keyfile.blah.blah" SSL_KEYFILE = "/tmp/keyfile.blah.blah"
FLAGS = flags.FLAGS
class RpcKombuSslTestCase(test.TestCase): class RpcKombuSslTestCase(test.TestCase):
def setUp(self): def setUp(self):
super(RpcKombuSslTestCase, self).setUp() super(RpcKombuSslTestCase, self).setUp()
impl_kombu.register_opts(FLAGS)
self.flags(kombu_ssl_keyfile=SSL_KEYFILE, self.flags(kombu_ssl_keyfile=SSL_KEYFILE,
kombu_ssl_ca_certs=SSL_CA_CERT, kombu_ssl_ca_certs=SSL_CA_CERT,
kombu_ssl_certfile=SSL_CERT, kombu_ssl_certfile=SSL_CERT,
@ -41,7 +45,7 @@ class RpcKombuSslTestCase(test.TestCase):
def test_ssl_on_extended(self): def test_ssl_on_extended(self):
rpc = impl_kombu rpc = impl_kombu
conn = rpc.create_connection(True) conn = rpc.create_connection(FLAGS, True)
c = conn.connection c = conn.connection
#This might be kombu version dependent... #This might be kombu version dependent...
#Since we are now peaking into the internals of kombu... #Since we are now peaking into the internals of kombu...

View File

@ -23,6 +23,7 @@ Unit Tests for remote procedure calls using qpid
import mox import mox
from nova import context from nova import context
from nova import flags
from nova import log as logging from nova import log as logging
from nova.rpc import amqp as rpc_amqp from nova.rpc import amqp as rpc_amqp
from nova import test from nova import test
@ -35,6 +36,7 @@ except ImportError:
impl_qpid = None impl_qpid = None
FLAGS = flags.FLAGS
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -64,6 +66,7 @@ class RpcQpidTestCase(test.TestCase):
self.mock_receiver = None self.mock_receiver = None
if qpid: if qpid:
impl_qpid.register_opts(FLAGS)
self.orig_connection = qpid.messaging.Connection self.orig_connection = qpid.messaging.Connection
self.orig_session = qpid.messaging.Session self.orig_session = qpid.messaging.Session
self.orig_sender = qpid.messaging.Sender self.orig_sender = qpid.messaging.Sender
@ -98,7 +101,7 @@ class RpcQpidTestCase(test.TestCase):
self.mox.ReplayAll() self.mox.ReplayAll()
connection = impl_qpid.create_connection() connection = impl_qpid.create_connection(FLAGS)
connection.close() connection.close()
def _test_create_consumer(self, fanout): def _test_create_consumer(self, fanout):
@ -130,7 +133,7 @@ class RpcQpidTestCase(test.TestCase):
self.mox.ReplayAll() self.mox.ReplayAll()
connection = impl_qpid.create_connection() connection = impl_qpid.create_connection(FLAGS)
connection.create_consumer("impl_qpid_test", connection.create_consumer("impl_qpid_test",
lambda *_x, **_y: None, lambda *_x, **_y: None,
fanout) fanout)
@ -176,11 +179,11 @@ class RpcQpidTestCase(test.TestCase):
try: try:
ctx = context.RequestContext("user", "project") ctx = context.RequestContext("user", "project")
args = [ctx, "impl_qpid_test", args = [FLAGS, ctx, "impl_qpid_test",
{"method": "test_method", "args": {}}] {"method": "test_method", "args": {}}]
if server_params: if server_params:
args.insert(1, server_params) args.insert(2, server_params)
if fanout: if fanout:
method = impl_qpid.fanout_cast_to_server method = impl_qpid.fanout_cast_to_server
else: else:
@ -218,7 +221,7 @@ class RpcQpidTestCase(test.TestCase):
server_params['hostname'] + ':' + server_params['hostname'] + ':' +
str(server_params['port'])) str(server_params['port']))
MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
self.stubs.Set(impl_qpid, 'Connection', MyConnection) self.stubs.Set(impl_qpid, 'Connection', MyConnection)
@test.skip_if(qpid is None, "Test requires qpid") @test.skip_if(qpid is None, "Test requires qpid")
@ -295,7 +298,7 @@ class RpcQpidTestCase(test.TestCase):
else: else:
method = impl_qpid.call method = impl_qpid.call
res = method(ctx, "impl_qpid_test", res = method(FLAGS, ctx, "impl_qpid_test",
{"method": "test_method", "args": {}}) {"method": "test_method", "args": {}})
if multi: if multi: