Refactor nova.rpc config handling.

This patch does a couple of things:

1) Remove the dependency of nova.rpc on nova.flags.  This is a step
toward decoupling nova.rpc from the rest of nova so that it can be moved
to openstack-common.

2) Refactor nova.rpc so that a configuration object is passed around as
needed instead of depending on nova.flags.FLAGS.

This was done by avoiding changing the nova.rpc API as much as possible
so that existing usage of nova.rpc would not have to be touched.  So,
instead, a config object gets registered, cached, and then passed into
the rpc implementations as needed.  Getting rid of this global config
reference in nova.rpc will require changing the public API and I wanted
to avoid doing that until there was a better reason than this.

Change-Id: I9a7fa67bd12ced877c83e48e31f5ef7263be6815
This commit is contained in:
Russell Bryant 2012-04-09 18:08:50 -04:00
parent ca4aee67e3
commit 0b11668e64
16 changed files with 319 additions and 259 deletions

View File

@ -74,6 +74,7 @@ if __name__ == '__main__':
utils.default_flagfile()
args = flags.FLAGS(sys.argv)
logging.setup()
rpc.register_opts(flags.FLAGS)
delete_queues(args[1:])
if FLAGS.delete_exchange:
delete_exchange(FLAGS.control_exchange)

View File

@ -99,6 +99,8 @@ def main():
argv = FLAGS(sys.argv)
logging.setup()
rpc.register_opts(FLAGS)
if int(os.environ.get('TESTING', '0')):
from nova.tests import fake_flags

View File

@ -1671,6 +1671,8 @@ def main():
except Exception:
print 'sudo failed, continuing as if nothing happened'
rpc.register_opts(FLAGS)
try:
argv = FLAGS(sys.argv)
logging.setup()
@ -1679,7 +1681,6 @@ def main():
print _('Please re-run nova-manage as root.')
sys.exit(2)
raise
script_name = argv.pop(0)
if len(argv) < 1:
print _("\nOpenStack Nova version: %(version)s (%(vcs)s)\n") % \

View File

@ -17,17 +17,37 @@
# License for the specific language governing permissions and limitations
# under the License.
from nova import flags
from nova.openstack.common import cfg
from nova import utils
rpc_backend_opt = cfg.StrOpt('rpc_backend',
rpc_opts = [
cfg.StrOpt('rpc_backend',
default='nova.rpc.impl_kombu',
help="The messaging module to use, defaults to kombu.")
help="The messaging module to use, defaults to kombu."),
cfg.IntOpt('rpc_thread_pool_size',
default=64,
help='Size of RPC thread pool'),
cfg.IntOpt('rpc_conn_pool_size',
default=30,
help='Size of RPC connection pool'),
cfg.IntOpt('rpc_response_timeout',
default=60,
help='Seconds to wait for a response from call or multicall'),
cfg.IntOpt('allowed_rpc_exception_modules',
default=['nova.exception'],
help='Modules of exceptions that are permitted to be recreated'
'upon receiving exception data from an rpc call.'),
]
FLAGS = flags.FLAGS
FLAGS.register_opt(rpc_backend_opt)
_CONF = None
def register_opts(conf):
global _CONF
_CONF = conf
_CONF.register_opts(rpc_opts)
_get_impl().register_opts(_CONF)
def create_connection(new=True):
@ -43,7 +63,7 @@ def create_connection(new=True):
:returns: An instance of nova.rpc.common.Connection
"""
return _get_impl().create_connection(new=new)
return _get_impl().create_connection(_CONF, new=new)
def call(context, topic, msg, timeout=None):
@ -65,7 +85,7 @@ def call(context, topic, msg, timeout=None):
:raises: nova.rpc.common.Timeout if a complete response is not received
before the timeout is reached.
"""
return _get_impl().call(context, topic, msg, timeout)
return _get_impl().call(_CONF, context, topic, msg, timeout)
def cast(context, topic, msg):
@ -82,7 +102,7 @@ def cast(context, topic, msg):
:returns: None
"""
return _get_impl().cast(context, topic, msg)
return _get_impl().cast(_CONF, context, topic, msg)
def fanout_cast(context, topic, msg):
@ -102,7 +122,7 @@ def fanout_cast(context, topic, msg):
:returns: None
"""
return _get_impl().fanout_cast(context, topic, msg)
return _get_impl().fanout_cast(_CONF, context, topic, msg)
def multicall(context, topic, msg, timeout=None):
@ -131,7 +151,7 @@ def multicall(context, topic, msg, timeout=None):
:raises: nova.rpc.common.Timeout if a complete response is not received
before the timeout is reached.
"""
return _get_impl().multicall(context, topic, msg, timeout)
return _get_impl().multicall(_CONF, context, topic, msg, timeout)
def notify(context, topic, msg):
@ -144,7 +164,7 @@ def notify(context, topic, msg):
:returns: None
"""
return _get_impl().notify(context, topic, msg)
return _get_impl().notify(_CONF, context, topic, msg)
def cleanup():
@ -172,7 +192,8 @@ def cast_to_server(context, server_params, topic, msg):
:returns: None
"""
return _get_impl().cast_to_server(context, server_params, topic, msg)
return _get_impl().cast_to_server(_CONF, context, server_params, topic,
msg)
def fanout_cast_to_server(context, server_params, topic, msg):
@ -187,16 +208,16 @@ def fanout_cast_to_server(context, server_params, topic, msg):
:returns: None
"""
return _get_impl().fanout_cast_to_server(context, server_params, topic,
msg)
return _get_impl().fanout_cast_to_server(_CONF, context, server_params,
topic, msg)
_RPCIMPL = None
def _get_impl():
"""Delay import of rpc_backend until FLAGS are loaded."""
"""Delay import of rpc_backend until configuration is loaded."""
global _RPCIMPL
if _RPCIMPL is None:
_RPCIMPL = utils.import_object(FLAGS.rpc_backend)
_RPCIMPL = utils.import_object(_CONF.rpc_backend)
return _RPCIMPL

View File

@ -31,10 +31,10 @@ import uuid
from eventlet import greenpool
from eventlet import pools
from eventlet import semaphore
from nova import context
from nova import exception
from nova import flags
from nova import log as logging
from nova.openstack.common import local
import nova.rpc.common as rpc_common
@ -43,27 +43,36 @@ from nova import utils
LOG = logging.getLogger(__name__)
FLAGS = flags.FLAGS
class Pool(pools.Pool):
"""Class that implements a Pool of Connections."""
def __init__(self, *args, **kwargs):
self.connection_cls = kwargs.pop("connection_cls", None)
kwargs.setdefault("max_size", FLAGS.rpc_conn_pool_size)
def __init__(self, conf, connection_cls, *args, **kwargs):
self.connection_cls = connection_cls
self.conf = conf
kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size)
kwargs.setdefault("order_as_stack", True)
super(Pool, self).__init__(*args, **kwargs)
# TODO(comstud): Timeout connections not used in a while
def create(self):
LOG.debug('Pool creating new connection')
return self.connection_cls()
return self.connection_cls(self.conf)
def empty(self):
while self.free_items:
self.get().close()
_pool_create_sem = semaphore.Semaphore()
def get_connection_pool(conf, connection_cls):
with _pool_create_sem:
# Make sure only one thread tries to create the connection pool.
if not connection_cls.pool:
connection_cls.pool = Pool(conf, connection_cls)
return connection_cls.pool
class ConnectionContext(rpc_common.Connection):
"""The class that is actually returned to the caller of
create_connection(). This is a essentially a wrapper around
@ -75,14 +84,15 @@ class ConnectionContext(rpc_common.Connection):
the pool.
"""
def __init__(self, connection_pool, pooled=True, server_params=None):
def __init__(self, conf, connection_pool, pooled=True, server_params=None):
"""Create a new connection, or get one from the pool"""
self.connection = None
self.conf = conf
self.connection_pool = connection_pool
if pooled:
self.connection = connection_pool.get()
else:
self.connection = connection_pool.connection_cls(
self.connection = connection_pool.connection_cls(conf,
server_params=server_params)
self.pooled = pooled
@ -133,13 +143,14 @@ class ConnectionContext(rpc_common.Connection):
raise exception.InvalidRPCConnectionReuse()
def msg_reply(msg_id, connection_pool, reply=None, failure=None, ending=False):
def msg_reply(conf, msg_id, connection_pool, reply=None, failure=None,
ending=False):
"""Sends a reply or an error on the channel signified by msg_id.
Failure should be a sys.exc_info() tuple.
"""
with ConnectionContext(connection_pool) as conn:
with ConnectionContext(conf, connection_pool) as conn:
if failure:
failure = rpc_common.serialize_remote_exception(failure)
@ -158,18 +169,19 @@ class RpcContext(context.RequestContext):
"""Context that supports replying to a rpc.call"""
def __init__(self, *args, **kwargs):
self.msg_id = kwargs.pop('msg_id', None)
self.conf = kwargs.pop('conf')
super(RpcContext, self).__init__(*args, **kwargs)
def reply(self, reply=None, failure=None, ending=False,
connection_pool=None):
if self.msg_id:
msg_reply(self.msg_id, connection_pool, reply, failure,
msg_reply(self.conf, self.msg_id, connection_pool, reply, failure,
ending)
if ending:
self.msg_id = None
def unpack_context(msg):
def unpack_context(conf, msg):
"""Unpack context from msg."""
context_dict = {}
for key in list(msg.keys()):
@ -180,6 +192,7 @@ def unpack_context(msg):
value = msg.pop(key)
context_dict[key[9:]] = value
context_dict['msg_id'] = msg.pop('_msg_id', None)
context_dict['conf'] = conf
ctx = RpcContext.from_dict(context_dict)
rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict())
return ctx
@ -202,10 +215,11 @@ def pack_context(msg, context):
class ProxyCallback(object):
"""Calls methods on a proxy object based on method and args."""
def __init__(self, proxy, connection_pool):
def __init__(self, conf, proxy, connection_pool):
self.proxy = proxy
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size)
self.connection_pool = connection_pool
self.conf = conf
def __call__(self, message_data):
"""Consumer callback to call a method on a proxy object.
@ -225,7 +239,7 @@ class ProxyCallback(object):
if hasattr(local.store, 'context'):
del local.store.context
rpc_common._safe_log(LOG.debug, _('received %s'), message_data)
ctxt = unpack_context(message_data)
ctxt = unpack_context(self.conf, message_data)
method = message_data.get('method')
args = message_data.get('args', {})
if not method:
@ -262,13 +276,14 @@ class ProxyCallback(object):
class MulticallWaiter(object):
def __init__(self, connection, timeout):
def __init__(self, conf, connection, timeout):
self._connection = connection
self._iterator = connection.iterconsume(
timeout=timeout or FLAGS.rpc_response_timeout)
timeout=timeout or conf.rpc_response_timeout)
self._result = None
self._done = False
self._got_ending = False
self._conf = conf
def done(self):
if self._done:
@ -282,7 +297,8 @@ class MulticallWaiter(object):
"""The consume() callback will call this. Store the result."""
if data['failure']:
failure = data['failure']
self._result = rpc_common.deserialize_remote_exception(failure)
self._result = rpc_common.deserialize_remote_exception(self._conf,
failure)
elif data.get('ending', False):
self._got_ending = True
@ -309,12 +325,12 @@ class MulticallWaiter(object):
yield result
def create_connection(new, connection_pool):
def create_connection(conf, new, connection_pool):
"""Create a connection"""
return ConnectionContext(connection_pool, pooled=not new)
return ConnectionContext(conf, connection_pool, pooled=not new)
def multicall(context, topic, msg, timeout, connection_pool):
def multicall(conf, context, topic, msg, timeout, connection_pool):
"""Make a call that returns multiple times."""
# Can't use 'with' for multicall, as it returns an iterator
# that will continue to use the connection. When it's done,
@ -326,16 +342,16 @@ def multicall(context, topic, msg, timeout, connection_pool):
LOG.debug(_('MSG_ID is %s') % (msg_id))
pack_context(msg, context)
conn = ConnectionContext(connection_pool)
wait_msg = MulticallWaiter(conn, timeout)
conn = ConnectionContext(conf, connection_pool)
wait_msg = MulticallWaiter(conf, conn, timeout)
conn.declare_direct_consumer(msg_id, wait_msg)
conn.topic_send(topic, msg)
return wait_msg
def call(context, topic, msg, timeout, connection_pool):
def call(conf, context, topic, msg, timeout, connection_pool):
"""Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg, timeout, connection_pool)
rv = multicall(conf, context, topic, msg, timeout, connection_pool)
# NOTE(vish): return the last result from the multicall
rv = list(rv)
if not rv:
@ -343,47 +359,48 @@ def call(context, topic, msg, timeout, connection_pool):
return rv[-1]
def cast(context, topic, msg, connection_pool):
def cast(conf, context, topic, msg, connection_pool):
"""Sends a message on a topic without waiting for a response."""
LOG.debug(_('Making asynchronous cast on %s...'), topic)
pack_context(msg, context)
with ConnectionContext(connection_pool) as conn:
with ConnectionContext(conf, connection_pool) as conn:
conn.topic_send(topic, msg)
def fanout_cast(context, topic, msg, connection_pool):
def fanout_cast(conf, context, topic, msg, connection_pool):
"""Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_('Making asynchronous fanout cast...'))
pack_context(msg, context)
with ConnectionContext(connection_pool) as conn:
with ConnectionContext(conf, connection_pool) as conn:
conn.fanout_send(topic, msg)
def cast_to_server(context, server_params, topic, msg, connection_pool):
def cast_to_server(conf, context, server_params, topic, msg, connection_pool):
"""Sends a message on a topic to a specific server."""
pack_context(msg, context)
with ConnectionContext(connection_pool, pooled=False,
with ConnectionContext(conf, connection_pool, pooled=False,
server_params=server_params) as conn:
conn.topic_send(topic, msg)
def fanout_cast_to_server(context, server_params, topic, msg,
def fanout_cast_to_server(conf, context, server_params, topic, msg,
connection_pool):
"""Sends a message on a fanout exchange to a specific server."""
pack_context(msg, context)
with ConnectionContext(connection_pool, pooled=False,
with ConnectionContext(conf, connection_pool, pooled=False,
server_params=server_params) as conn:
conn.fanout_send(topic, msg)
def notify(context, topic, msg, connection_pool):
def notify(conf, context, topic, msg, connection_pool):
"""Sends a notification event on a topic."""
event_type = msg.get('event_type')
LOG.debug(_('Sending %(event_type)s on %(topic)s'), locals())
pack_context(msg, context)
with ConnectionContext(connection_pool) as conn:
with ConnectionContext(conf, connection_pool) as conn:
conn.notify_send(topic, msg)
def cleanup(connection_pool):
if connection_pool:
connection_pool.empty()

View File

@ -22,7 +22,6 @@ import sys
import traceback
from nova import exception
from nova import flags
from nova import log as logging
from nova.openstack.common import cfg
from nova import utils
@ -30,25 +29,6 @@ from nova import utils
LOG = logging.getLogger(__name__)
rpc_opts = [
cfg.IntOpt('rpc_thread_pool_size',
default=64,
help='Size of RPC thread pool'),
cfg.IntOpt('rpc_conn_pool_size',
default=30,
help='Size of RPC connection pool'),
cfg.IntOpt('rpc_response_timeout',
default=60,
help='Seconds to wait for a response from call or multicall'),
cfg.IntOpt('allowed_rpc_exception_modules',
default=['nova.exception'],
help='Modules of exceptions that are permitted to be recreated'
'upon receiving exception data from an rpc call.'),
]
flags.FLAGS.register_opts(rpc_opts)
FLAGS = flags.FLAGS
class RemoteError(exception.NovaException):
"""Signifies that a remote class has raised an exception.
@ -95,7 +75,7 @@ class Connection(object):
"""
raise NotImplementedError()
def create_consumer(self, topic, proxy, fanout=False):
def create_consumer(self, conf, topic, proxy, fanout=False):
"""Create a consumer on this connection.
A consumer is associated with a message queue on the backend message
@ -104,6 +84,7 @@ class Connection(object):
off of the queue will determine which method gets called on the proxy
object.
:param conf: An openstack.common.cfg configuration object.
:param topic: This is a name associated with what to consume from.
Multiple instances of a service may consume from the same
topic. For example, all instances of nova-compute consume
@ -197,7 +178,7 @@ def serialize_remote_exception(failure_info):
return json_data
def deserialize_remote_exception(data):
def deserialize_remote_exception(conf, data):
failure = utils.loads(str(data))
trace = failure.get('tb', [])
@ -207,7 +188,7 @@ def deserialize_remote_exception(data):
# NOTE(ameade): We DO NOT want to allow just any module to be imported, in
# order to prevent arbitrary code execution.
if not module in FLAGS.allowed_rpc_exception_modules:
if not module in conf.allowed_rpc_exception_modules:
return RemoteError(name, failure.get('message'), trace)
try:

View File

@ -27,13 +27,10 @@ import traceback
import eventlet
from nova import context
from nova import flags
from nova.rpc import common as rpc_common
CONSUMERS = {}
FLAGS = flags.FLAGS
class RpcContext(context.RequestContext):
def __init__(self, *args, **kwargs):
@ -116,7 +113,7 @@ class Connection(object):
pass
def create_connection(new=True):
def create_connection(conf, new=True):
"""Create a connection"""
return Connection()
@ -126,7 +123,7 @@ def check_serialize(msg):
json.dumps(msg)
def multicall(context, topic, msg, timeout=None):
def multicall(conf, context, topic, msg, timeout=None):
"""Make a call that returns multiple times."""
check_serialize(msg)
@ -144,9 +141,9 @@ def multicall(context, topic, msg, timeout=None):
return consumer.call(context, method, args, timeout)
def call(context, topic, msg, timeout=None):
def call(conf, context, topic, msg, timeout=None):
"""Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg, timeout)
rv = multicall(conf, context, topic, msg, timeout)
# NOTE(vish): return the last result from the multicall
rv = list(rv)
if not rv:
@ -154,14 +151,14 @@ def call(context, topic, msg, timeout=None):
return rv[-1]
def cast(context, topic, msg):
def cast(conf, context, topic, msg):
try:
call(context, topic, msg)
call(conf, context, topic, msg)
except Exception:
pass
def notify(context, topic, msg):
def notify(conf, context, topic, msg):
check_serialize(msg)
@ -169,7 +166,7 @@ def cleanup():
pass
def fanout_cast(context, topic, msg):
def fanout_cast(conf, context, topic, msg):
"""Cast to all consumers of a topic"""
check_serialize(msg)
method = msg.get('method')
@ -182,3 +179,7 @@ def fanout_cast(context, topic, msg):
consumer.call(context, method, args, None)
except Exception:
pass
def register_opts(conf):
pass

View File

@ -28,7 +28,6 @@ import kombu.entity
import kombu.messaging
import kombu.connection
from nova import flags
from nova.openstack.common import cfg
from nova.rpc import amqp as rpc_amqp
from nova.rpc import common as rpc_common
@ -49,8 +48,6 @@ kombu_opts = [
'(valid only if SSL enabled)')),
]
FLAGS = flags.FLAGS
FLAGS.register_opts(kombu_opts)
LOG = rpc_common.LOG
@ -126,7 +123,7 @@ class ConsumerBase(object):
class DirectConsumer(ConsumerBase):
"""Queue/consumer class for 'direct'"""
def __init__(self, channel, msg_id, callback, tag, **kwargs):
def __init__(self, conf, channel, msg_id, callback, tag, **kwargs):
"""Init a 'direct' queue.
'channel' is the amqp channel to use
@ -159,7 +156,7 @@ class DirectConsumer(ConsumerBase):
class TopicConsumer(ConsumerBase):
"""Consumer class for 'topic'"""
def __init__(self, channel, topic, callback, tag, **kwargs):
def __init__(self, conf, channel, topic, callback, tag, **kwargs):
"""Init a 'topic' queue.
'channel' is the amqp channel to use
@ -170,12 +167,12 @@ class TopicConsumer(ConsumerBase):
Other kombu options may be passed
"""
# Default options
options = {'durable': FLAGS.rabbit_durable_queues,
options = {'durable': conf.rabbit_durable_queues,
'auto_delete': False,
'exclusive': False}
options.update(kwargs)
exchange = kombu.entity.Exchange(
name=FLAGS.control_exchange,
name=conf.control_exchange,
type='topic',
durable=options['durable'],
auto_delete=options['auto_delete'])
@ -192,7 +189,7 @@ class TopicConsumer(ConsumerBase):
class FanoutConsumer(ConsumerBase):
"""Consumer class for 'fanout'"""
def __init__(self, channel, topic, callback, tag, **kwargs):
def __init__(self, conf, channel, topic, callback, tag, **kwargs):
"""Init a 'fanout' queue.
'channel' is the amqp channel to use
@ -252,7 +249,7 @@ class Publisher(object):
class DirectPublisher(Publisher):
"""Publisher class for 'direct'"""
def __init__(self, channel, msg_id, **kwargs):
def __init__(self, conf, channel, msg_id, **kwargs):
"""init a 'direct' publisher.
Kombu options may be passed as keyword args to override defaults
@ -271,17 +268,17 @@ class DirectPublisher(Publisher):
class TopicPublisher(Publisher):
"""Publisher class for 'topic'"""
def __init__(self, channel, topic, **kwargs):
def __init__(self, conf, channel, topic, **kwargs):
"""init a 'topic' publisher.
Kombu options may be passed as keyword args to override defaults
"""
options = {'durable': FLAGS.rabbit_durable_queues,
options = {'durable': conf.rabbit_durable_queues,
'auto_delete': False,
'exclusive': False}
options.update(kwargs)
super(TopicPublisher, self).__init__(channel,
FLAGS.control_exchange,
conf.control_exchange,
topic,
type='topic',
**options)
@ -289,7 +286,7 @@ class TopicPublisher(Publisher):
class FanoutPublisher(Publisher):
"""Publisher class for 'fanout'"""
def __init__(self, channel, topic, **kwargs):
def __init__(self, conf, channel, topic, **kwargs):
"""init a 'fanout' publisher.
Kombu options may be passed as keyword args to override defaults
@ -308,9 +305,9 @@ class FanoutPublisher(Publisher):
class NotifyPublisher(TopicPublisher):
"""Publisher class for 'notify'"""
def __init__(self, *args, **kwargs):
self.durable = kwargs.pop('durable', FLAGS.rabbit_durable_queues)
super(NotifyPublisher, self).__init__(*args, **kwargs)
def __init__(self, conf, channel, topic, **kwargs):
self.durable = kwargs.pop('durable', conf.rabbit_durable_queues)
super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs)
def reconnect(self, channel):
super(NotifyPublisher, self).reconnect(channel)
@ -329,15 +326,18 @@ class NotifyPublisher(TopicPublisher):
class Connection(object):
"""Connection object."""
def __init__(self, server_params=None):
pool = None
def __init__(self, conf, server_params=None):
self.consumers = []
self.consumer_thread = None
self.max_retries = FLAGS.rabbit_max_retries
self.conf = conf
self.max_retries = self.conf.rabbit_max_retries
# Try forever?
if self.max_retries <= 0:
self.max_retries = None
self.interval_start = FLAGS.rabbit_retry_interval
self.interval_stepping = FLAGS.rabbit_retry_backoff
self.interval_start = self.conf.rabbit_retry_interval
self.interval_stepping = self.conf.rabbit_retry_backoff
# max retry-interval = 30 seconds
self.interval_max = 30
self.memory_transport = False
@ -353,21 +353,21 @@ class Connection(object):
p_key = server_params_to_kombu_params.get(sp_key, sp_key)
params[p_key] = value
params.setdefault('hostname', FLAGS.rabbit_host)
params.setdefault('port', FLAGS.rabbit_port)
params.setdefault('userid', FLAGS.rabbit_userid)
params.setdefault('password', FLAGS.rabbit_password)
params.setdefault('virtual_host', FLAGS.rabbit_virtual_host)
params.setdefault('hostname', self.conf.rabbit_host)
params.setdefault('port', self.conf.rabbit_port)
params.setdefault('userid', self.conf.rabbit_userid)
params.setdefault('password', self.conf.rabbit_password)
params.setdefault('virtual_host', self.conf.rabbit_virtual_host)
self.params = params
if FLAGS.fake_rabbit:
if self.conf.fake_rabbit:
self.params['transport'] = 'memory'
self.memory_transport = True
else:
self.memory_transport = False
if FLAGS.rabbit_use_ssl:
if self.conf.rabbit_use_ssl:
self.params['ssl'] = self._fetch_ssl_params()
self.connection = None
@ -379,14 +379,14 @@ class Connection(object):
ssl_params = dict()
# http://docs.python.org/library/ssl.html - ssl.wrap_socket
if FLAGS.kombu_ssl_version:
ssl_params['ssl_version'] = FLAGS.kombu_ssl_version
if FLAGS.kombu_ssl_keyfile:
ssl_params['keyfile'] = FLAGS.kombu_ssl_keyfile
if FLAGS.kombu_ssl_certfile:
ssl_params['certfile'] = FLAGS.kombu_ssl_certfile
if FLAGS.kombu_ssl_ca_certs:
ssl_params['ca_certs'] = FLAGS.kombu_ssl_ca_certs
if self.conf.kombu_ssl_version:
ssl_params['ssl_version'] = self.conf.kombu_ssl_version
if self.conf.kombu_ssl_keyfile:
ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile
if self.conf.kombu_ssl_certfile:
ssl_params['certfile'] = self.conf.kombu_ssl_certfile
if self.conf.kombu_ssl_ca_certs:
ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs
# We might want to allow variations in the
# future with this?
ssl_params['cert_reqs'] = ssl.CERT_REQUIRED
@ -534,7 +534,7 @@ class Connection(object):
"%(err_str)s") % log_info)
def _declare_consumer():
consumer = consumer_cls(self.channel, topic, callback,
consumer = consumer_cls(self.conf, self.channel, topic, callback,
self.consumer_num.next())
self.consumers.append(consumer)
return consumer
@ -590,7 +590,7 @@ class Connection(object):
"'%(topic)s': %(err_str)s") % log_info)
def _publish():
publisher = cls(self.channel, topic, **kwargs)
publisher = cls(self.conf, self.channel, topic, **kwargs)
publisher.send(msg)
self.ensure(_error_callback, _publish)
@ -648,58 +648,66 @@ class Connection(object):
def create_consumer(self, topic, proxy, fanout=False):
"""Create a consumer that calls a method in a proxy object"""
proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
rpc_amqp.get_connection_pool(self, Connection))
if fanout:
self.declare_fanout_consumer(topic,
rpc_amqp.ProxyCallback(proxy, Connection.pool))
self.declare_fanout_consumer(topic, proxy_cb)
else:
self.declare_topic_consumer(topic,
rpc_amqp.ProxyCallback(proxy, Connection.pool))
self.declare_topic_consumer(topic, proxy_cb)
Connection.pool = rpc_amqp.Pool(connection_cls=Connection)
def create_connection(new=True):
def create_connection(conf, new=True):
"""Create a connection"""
return rpc_amqp.create_connection(new, Connection.pool)
return rpc_amqp.create_connection(conf, new,
rpc_amqp.get_connection_pool(conf, Connection))
def multicall(context, topic, msg, timeout=None):
def multicall(conf, context, topic, msg, timeout=None):
"""Make a call that returns multiple times."""
return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool)
return rpc_amqp.multicall(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def call(context, topic, msg, timeout=None):
def call(conf, context, topic, msg, timeout=None):
"""Sends a message on a topic and wait for a response."""
return rpc_amqp.call(context, topic, msg, timeout, Connection.pool)
return rpc_amqp.call(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def cast(context, topic, msg):
def cast(conf, context, topic, msg):
"""Sends a message on a topic without waiting for a response."""
return rpc_amqp.cast(context, topic, msg, Connection.pool)
return rpc_amqp.cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast(context, topic, msg):
def fanout_cast(conf, context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response."""
return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool)
return rpc_amqp.fanout_cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cast_to_server(context, server_params, topic, msg):
def cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a topic to a specific server."""
return rpc_amqp.cast_to_server(context, server_params, topic, msg,
Connection.pool)
return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast_to_server(context, server_params, topic, msg):
def fanout_cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a fanout exchange to a specific server."""
return rpc_amqp.cast_to_server(context, server_params, topic, msg,
Connection.pool)
return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def notify(context, topic, msg):
def notify(conf, context, topic, msg):
"""Sends a notification event on a topic."""
return rpc_amqp.notify(context, topic, msg, Connection.pool)
return rpc_amqp.notify(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cleanup():
return rpc_amqp.cleanup(Connection.pool)
def register_opts(conf):
conf.register_opts(kombu_opts)

View File

@ -25,7 +25,6 @@ import greenlet
import qpid.messaging
import qpid.messaging.exceptions
from nova import flags
from nova import log as logging
from nova.openstack.common import cfg
from nova.rpc import amqp as rpc_amqp
@ -78,9 +77,6 @@ qpid_opts = [
help='Disable Nagle algorithm'),
]
FLAGS = flags.FLAGS
FLAGS.register_opts(qpid_opts)
class ConsumerBase(object):
"""Consumer base class."""
@ -147,7 +143,7 @@ class ConsumerBase(object):
class DirectConsumer(ConsumerBase):
"""Queue/consumer class for 'direct'"""
def __init__(self, session, msg_id, callback):
def __init__(self, conf, session, msg_id, callback):
"""Init a 'direct' queue.
'session' is the amqp session to use
@ -165,7 +161,7 @@ class DirectConsumer(ConsumerBase):
class TopicConsumer(ConsumerBase):
"""Consumer class for 'topic'"""
def __init__(self, session, topic, callback):
def __init__(self, conf, session, topic, callback):
"""Init a 'topic' queue.
'session' is the amqp session to use
@ -174,14 +170,14 @@ class TopicConsumer(ConsumerBase):
"""
super(TopicConsumer, self).__init__(session, callback,
"%s/%s" % (FLAGS.control_exchange, topic), {},
"%s/%s" % (conf.control_exchange, topic), {},
topic, {})
class FanoutConsumer(ConsumerBase):
"""Consumer class for 'fanout'"""
def __init__(self, session, topic, callback):
def __init__(self, conf, session, topic, callback):
"""Init a 'fanout' queue.
'session' is the amqp session to use
@ -236,7 +232,7 @@ class Publisher(object):
class DirectPublisher(Publisher):
"""Publisher class for 'direct'"""
def __init__(self, session, msg_id):
def __init__(self, conf, session, msg_id):
"""Init a 'direct' publisher."""
super(DirectPublisher, self).__init__(session, msg_id,
{"type": "Direct"})
@ -244,16 +240,16 @@ class DirectPublisher(Publisher):
class TopicPublisher(Publisher):
"""Publisher class for 'topic'"""
def __init__(self, session, topic):
def __init__(self, conf, session, topic):
"""init a 'topic' publisher.
"""
super(TopicPublisher, self).__init__(session,
"%s/%s" % (FLAGS.control_exchange, topic))
"%s/%s" % (conf.control_exchange, topic))
class FanoutPublisher(Publisher):
"""Publisher class for 'fanout'"""
def __init__(self, session, topic):
def __init__(self, conf, session, topic):
"""init a 'fanout' publisher.
"""
super(FanoutPublisher, self).__init__(session,
@ -262,29 +258,32 @@ class FanoutPublisher(Publisher):
class NotifyPublisher(Publisher):
"""Publisher class for notifications"""
def __init__(self, session, topic):
def __init__(self, conf, session, topic):
"""init a 'topic' publisher.
"""
super(NotifyPublisher, self).__init__(session,
"%s/%s" % (FLAGS.control_exchange, topic),
"%s/%s" % (conf.control_exchange, topic),
{"durable": True})
class Connection(object):
"""Connection object."""
def __init__(self, server_params=None):
pool = None
def __init__(self, conf, server_params=None):
self.session = None
self.consumers = {}
self.consumer_thread = None
self.conf = conf
if server_params is None:
server_params = {}
default_params = dict(hostname=FLAGS.qpid_hostname,
port=FLAGS.qpid_port,
username=FLAGS.qpid_username,
password=FLAGS.qpid_password)
default_params = dict(hostname=self.conf.qpid_hostname,
port=self.conf.qpid_port,
username=self.conf.qpid_username,
password=self.conf.qpid_password)
params = server_params
for key in default_params.keys():
@ -298,23 +297,25 @@ class Connection(object):
# before we call open
self.connection.username = params['username']
self.connection.password = params['password']
self.connection.sasl_mechanisms = FLAGS.qpid_sasl_mechanisms
self.connection.reconnect = FLAGS.qpid_reconnect
if FLAGS.qpid_reconnect_timeout:
self.connection.reconnect_timeout = FLAGS.qpid_reconnect_timeout
if FLAGS.qpid_reconnect_limit:
self.connection.reconnect_limit = FLAGS.qpid_reconnect_limit
if FLAGS.qpid_reconnect_interval_max:
self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms
self.connection.reconnect = self.conf.qpid_reconnect
if self.conf.qpid_reconnect_timeout:
self.connection.reconnect_timeout = (
self.conf.qpid_reconnect_timeout)
if self.conf.qpid_reconnect_limit:
self.connection.reconnect_limit = self.conf.qpid_reconnect_limit
if self.conf.qpid_reconnect_interval_max:
self.connection.reconnect_interval_max = (
FLAGS.qpid_reconnect_interval_max)
if FLAGS.qpid_reconnect_interval_min:
self.conf.qpid_reconnect_interval_max)
if self.conf.qpid_reconnect_interval_min:
self.connection.reconnect_interval_min = (
FLAGS.qpid_reconnect_interval_min)
if FLAGS.qpid_reconnect_interval:
self.connection.reconnect_interval = FLAGS.qpid_reconnect_interval
self.connection.hearbeat = FLAGS.qpid_heartbeat
self.connection.protocol = FLAGS.qpid_protocol
self.connection.tcp_nodelay = FLAGS.qpid_tcp_nodelay
self.conf.qpid_reconnect_interval_min)
if self.conf.qpid_reconnect_interval:
self.connection.reconnect_interval = (
self.conf.qpid_reconnect_interval)
self.connection.hearbeat = self.conf.qpid_heartbeat
self.connection.protocol = self.conf.qpid_protocol
self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay
# Open is part of reconnect -
# NOTE(WGH) not sure we need this with the reconnect flags
@ -339,7 +340,7 @@ class Connection(object):
self.connection.open()
except qpid.messaging.exceptions.ConnectionError, e:
LOG.error(_('Unable to connect to AMQP server: %s'), e)
time.sleep(FLAGS.qpid_reconnect_interval or 1)
time.sleep(self.conf.qpid_reconnect_interval or 1)
else:
break
@ -386,7 +387,7 @@ class Connection(object):
"%(err_str)s") % log_info)
def _declare_consumer():
consumer = consumer_cls(self.session, topic, callback)
consumer = consumer_cls(self.conf, self.session, topic, callback)
self._register_consumer(consumer)
return consumer
@ -435,7 +436,7 @@ class Connection(object):
"'%(topic)s': %(err_str)s") % log_info)
def _publisher_send():
publisher = cls(self.session, topic)
publisher = cls(self.conf, self.session, topic)
publisher.send(msg)
return self.ensure(_connect_error, _publisher_send)
@ -493,60 +494,70 @@ class Connection(object):
def create_consumer(self, topic, proxy, fanout=False):
"""Create a consumer that calls a method in a proxy object"""
proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
rpc_amqp.get_connection_pool(self, Connection))
if fanout:
consumer = FanoutConsumer(self.session, topic,
rpc_amqp.ProxyCallback(proxy, Connection.pool))
consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb)
else:
consumer = TopicConsumer(self.session, topic,
rpc_amqp.ProxyCallback(proxy, Connection.pool))
consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb)
self._register_consumer(consumer)
return consumer
Connection.pool = rpc_amqp.Pool(connection_cls=Connection)
def create_connection(new=True):
def create_connection(conf, new=True):
"""Create a connection"""
return rpc_amqp.create_connection(new, Connection.pool)
return rpc_amqp.create_connection(conf, new,
rpc_amqp.get_connection_pool(conf, Connection))
def multicall(context, topic, msg, timeout=None):
def multicall(conf, context, topic, msg, timeout=None):
"""Make a call that returns multiple times."""
return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool)
return rpc_amqp.multicall(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def call(context, topic, msg, timeout=None):
def call(conf, context, topic, msg, timeout=None):
"""Sends a message on a topic and wait for a response."""
return rpc_amqp.call(context, topic, msg, timeout, Connection.pool)
return rpc_amqp.call(conf, context, topic, msg, timeout,
rpc_amqp.get_connection_pool(conf, Connection))
def cast(context, topic, msg):
def cast(conf, context, topic, msg):
"""Sends a message on a topic without waiting for a response."""
return rpc_amqp.cast(context, topic, msg, Connection.pool)
return rpc_amqp.cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast(context, topic, msg):
def fanout_cast(conf, context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response."""
return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool)
return rpc_amqp.fanout_cast(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cast_to_server(context, server_params, topic, msg):
def cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a topic to a specific server."""
return rpc_amqp.cast_to_server(context, server_params, topic, msg,
Connection.pool)
return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def fanout_cast_to_server(context, server_params, topic, msg):
def fanout_cast_to_server(conf, context, server_params, topic, msg):
"""Sends a message on a fanout exchange to a specific server."""
return rpc_amqp.fanout_cast_to_server(context, server_params, topic,
msg, Connection.pool)
return rpc_amqp.fanout_cast_to_server(conf, context, server_params, topic,
msg, rpc_amqp.get_connection_pool(conf, Connection))
def notify(context, topic, msg):
def notify(conf, context, topic, msg):
"""Sends a notification event on a topic."""
return rpc_amqp.notify(context, topic, msg, Connection.pool)
return rpc_amqp.notify(conf, context, topic, msg,
rpc_amqp.get_connection_pool(conf, Connection))
def cleanup():
return rpc_amqp.cleanup(Connection.pool)
def register_opts(conf):
conf.register_opts(qpid_opts)

View File

@ -177,6 +177,7 @@ class Service(object):
LOG.audit(_('Starting %(topic)s node (version %(vcs_string)s)'),
{'topic': self.topic, 'vcs_string': vcs_string})
utils.cleanup_file_locks()
rpc.register_opts(FLAGS)
self.manager.init_host()
self.model_disconnected = False
ctxt = context.get_admin_context()
@ -393,6 +394,7 @@ class WSGIService(object):
"""
utils.cleanup_file_locks()
rpc.register_opts(FLAGS)
if self.manager:
self.manager.init_host()
self.server.start()

View File

@ -59,6 +59,9 @@ def reset_db():
def setup():
import mox # Fail fast if you don't have mox. Workaround for bug 810424
from nova import rpc # Register rpc_backend before fake_flags sets it
FLAGS.register_opts(rpc.rpc_opts)
from nova import context
from nova import db
from nova.db import migration

View File

@ -26,19 +26,21 @@ import nose
from nova import context
from nova import exception
from nova import flags
from nova import log as logging
from nova.rpc import amqp as rpc_amqp
from nova.rpc import common as rpc_common
from nova import test
FLAGS = flags.FLAGS
LOG = logging.getLogger(__name__)
class BaseRpcTestCase(test.TestCase):
def setUp(self, supports_timeouts=True):
super(BaseRpcTestCase, self).setUp()
self.conn = self.rpc.create_connection(True)
self.conn = self.rpc.create_connection(FLAGS, True)
self.receiver = TestReceiver()
self.conn.create_consumer('test', self.receiver, False)
self.conn.consume_in_thread()
@ -51,20 +53,20 @@ class BaseRpcTestCase(test.TestCase):
def test_call_succeed(self):
value = 42
result = self.rpc.call(self.context, 'test', {"method": "echo",
"args": {"value": value}})
result = self.rpc.call(FLAGS, self.context, 'test',
{"method": "echo", "args": {"value": value}})
self.assertEqual(value, result)
def test_call_succeed_despite_multiple_returns_yield(self):
value = 42
result = self.rpc.call(self.context, 'test',
result = self.rpc.call(FLAGS, self.context, 'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
self.assertEqual(value + 2, result)
def test_multicall_succeed_once(self):
value = 42
result = self.rpc.multicall(self.context,
result = self.rpc.multicall(FLAGS, self.context,
'test',
{"method": "echo",
"args": {"value": value}})
@ -75,7 +77,7 @@ class BaseRpcTestCase(test.TestCase):
def test_multicall_three_nones(self):
value = 42
result = self.rpc.multicall(self.context,
result = self.rpc.multicall(FLAGS, self.context,
'test',
{"method": "multicall_three_nones",
"args": {"value": value}})
@ -86,7 +88,7 @@ class BaseRpcTestCase(test.TestCase):
def test_multicall_succeed_three_times_yield(self):
value = 42
result = self.rpc.multicall(self.context,
result = self.rpc.multicall(FLAGS, self.context,
'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
@ -96,7 +98,7 @@ class BaseRpcTestCase(test.TestCase):
def test_context_passed(self):
"""Makes sure a context is passed through rpc call."""
value = 42
result = self.rpc.call(self.context,
result = self.rpc.call(FLAGS, self.context,
'test', {"method": "context",
"args": {"value": value}})
self.assertEqual(self.context.to_dict(), result)
@ -112,7 +114,7 @@ class BaseRpcTestCase(test.TestCase):
# TODO(comstud):
# so, it will replay the context and use the same REQID?
# that's bizarre.
ret = self.rpc.call(context,
ret = self.rpc.call(FLAGS, context,
queue,
{"method": "echo",
"args": {"value": value}})
@ -120,11 +122,11 @@ class BaseRpcTestCase(test.TestCase):
return value
nested = Nested()
conn = self.rpc.create_connection(True)
conn = self.rpc.create_connection(FLAGS, True)
conn.create_consumer('nested', nested, False)
conn.consume_in_thread()
value = 42
result = self.rpc.call(self.context,
result = self.rpc.call(FLAGS, self.context,
'nested', {"method": "echo",
"args": {"queue": "test",
"value": value}})
@ -139,12 +141,12 @@ class BaseRpcTestCase(test.TestCase):
value = 42
self.assertRaises(rpc_common.Timeout,
self.rpc.call,
self.context,
FLAGS, self.context,
'test',
{"method": "block",
"args": {"value": value}}, timeout=1)
try:
self.rpc.call(self.context,
self.rpc.call(FLAGS, self.context,
'test',
{"method": "block",
"args": {"value": value}},
@ -169,8 +171,8 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase):
self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context)
value = 41
self.rpc.cast(self.context, 'test', {"method": "echo",
"args": {"value": value}})
self.rpc.cast(FLAGS, self.context, 'test',
{"method": "echo", "args": {"value": value}})
# Wait for the cast to complete.
for x in xrange(50):
@ -185,7 +187,7 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase):
self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack)
value = 42
result = self.rpc.call(self.context, 'test',
result = self.rpc.call(FLAGS, self.context, 'test',
{"method": "echo",
"args": {"value": value}})
self.assertEqual(value, result)

View File

@ -93,7 +93,7 @@ class RpcCommonTestCase(test.TestCase):
}
serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized)
after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, exception.NovaException))
self.assertTrue('test message' in unicode(after_exc))
#assure the traceback was added
@ -108,7 +108,7 @@ class RpcCommonTestCase(test.TestCase):
}
serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized)
after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
def test_deserialize_remote_exception_user_defined_exception(self):
@ -121,7 +121,7 @@ class RpcCommonTestCase(test.TestCase):
}
serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized)
after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, FakeUserDefinedException))
#assure the traceback was added
self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc))
@ -141,7 +141,7 @@ class RpcCommonTestCase(test.TestCase):
}
serialized = json.dumps(failure)
after_exc = rpc_common.deserialize_remote_exception(serialized)
after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
#assure the traceback was added
self.assertTrue('raise FakeIDontExistException' in unicode(after_exc))

View File

@ -53,6 +53,7 @@ def _raise_exc_stub(stubs, times, obj, method, exc_msg,
class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def setUp(self):
self.rpc = impl_kombu
impl_kombu.register_opts(FLAGS)
super(RpcKombuTestCase, self).setUp()
def tearDown(self):
@ -61,10 +62,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def test_reusing_connection(self):
"""Test that reusing a connection returns same one."""
conn_context = self.rpc.create_connection(new=False)
conn_context = self.rpc.create_connection(FLAGS, new=False)
conn1 = conn_context.connection
conn_context.close()
conn_context = self.rpc.create_connection(new=False)
conn_context = self.rpc.create_connection(FLAGS, new=False)
conn2 = conn_context.connection
conn_context.close()
self.assertEqual(conn1, conn2)
@ -72,7 +73,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def test_topic_send_receive(self):
"""Test sending to a topic exchange/queue"""
conn = self.rpc.create_connection()
conn = self.rpc.create_connection(FLAGS)
message = 'topic test message'
self.received_message = None
@ -89,7 +90,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def test_direct_send_receive(self):
"""Test sending to a direct exchange/queue"""
conn = self.rpc.create_connection()
conn = self.rpc.create_connection(FLAGS)
message = 'direct test message'
self.received_message = None
@ -123,10 +124,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def topic_send(_context, topic, msg):
pass
MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection)
MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
self.stubs.Set(impl_kombu, 'Connection', MyConnection)
impl_kombu.cast(ctxt, 'fake_topic', {'msg': 'fake'})
impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'})
def test_cast_to_server_uses_server_params(self):
"""Test kombu rpc.cast"""
@ -153,10 +154,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
def topic_send(_context, topic, msg):
pass
MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection)
MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
self.stubs.Set(impl_kombu, 'Connection', MyConnection)
impl_kombu.cast_to_server(ctxt, server_params,
impl_kombu.cast_to_server(FLAGS, ctxt, server_params,
'fake_topic', {'msg': 'fake'})
@test.skip_test("kombu memory transport seems buggy with fanout queues "
@ -192,7 +193,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
'__init__', 'foo timeout foo')
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
result = conn.declare_consumer(self.rpc.DirectConsumer,
'test_topic', None)
@ -206,7 +207,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer,
'__init__', 'meow')
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
conn.connection_errors = (MyException, )
result = conn.declare_consumer(self.rpc.DirectConsumer,
@ -220,7 +221,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
'__init__', 'Socket closed', exc_class=IOError)
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
result = conn.declare_consumer(self.rpc.DirectConsumer,
'test_topic', None)
@ -234,7 +235,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
'__init__', 'foo timeout foo')
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
self.assertEqual(info['called'], 3)
@ -243,7 +244,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
'send', 'foo timeout foo')
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
self.assertEqual(info['called'], 3)
@ -256,7 +257,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
'__init__', 'meow')
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
conn.connection_errors = (MyException, )
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
@ -267,7 +268,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
'send', 'meow')
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
conn.connection_errors = (MyException, )
conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
@ -275,7 +276,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
self.assertEqual(info['called'], 2)
def test_iterconsume_errors_will_reconnect(self):
conn = self.rpc.Connection()
conn = self.rpc.Connection(FLAGS)
message = 'reconnect test message'
self.received_message = None
@ -305,12 +306,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
value = "This is the exception message"
self.assertRaises(NotImplementedError,
self.rpc.call,
FLAGS,
self.context,
'test',
{"method": "fail",
"args": {"value": value}})
try:
self.rpc.call(self.context,
self.rpc.call(FLAGS, self.context,
'test',
{"method": "fail",
"args": {"value": value}})
@ -330,12 +332,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
value = "This is the exception message"
self.assertRaises(exception.ConvertedException,
self.rpc.call,
FLAGS,
self.context,
'test',
{"method": "fail_converted",
"args": {"value": value}})
try:
self.rpc.call(self.context,
self.rpc.call(FLAGS, self.context,
'test',
{"method": "fail_converted",
"args": {"value": value}})

View File

@ -19,6 +19,7 @@
Unit Tests for remote procedure calls using kombu + ssl
"""
from nova import flags
from nova import test
from nova.rpc import impl_kombu
@ -28,11 +29,14 @@ SSL_CERT = "/tmp/cert.blah.blah"
SSL_CA_CERT = "/tmp/cert.ca.blah.blah"
SSL_KEYFILE = "/tmp/keyfile.blah.blah"
FLAGS = flags.FLAGS
class RpcKombuSslTestCase(test.TestCase):
def setUp(self):
super(RpcKombuSslTestCase, self).setUp()
impl_kombu.register_opts(FLAGS)
self.flags(kombu_ssl_keyfile=SSL_KEYFILE,
kombu_ssl_ca_certs=SSL_CA_CERT,
kombu_ssl_certfile=SSL_CERT,
@ -41,7 +45,7 @@ class RpcKombuSslTestCase(test.TestCase):
def test_ssl_on_extended(self):
rpc = impl_kombu
conn = rpc.create_connection(True)
conn = rpc.create_connection(FLAGS, True)
c = conn.connection
#This might be kombu version dependent...
#Since we are now peaking into the internals of kombu...

View File

@ -23,6 +23,7 @@ Unit Tests for remote procedure calls using qpid
import mox
from nova import context
from nova import flags
from nova import log as logging
from nova.rpc import amqp as rpc_amqp
from nova import test
@ -35,6 +36,7 @@ except ImportError:
impl_qpid = None
FLAGS = flags.FLAGS
LOG = logging.getLogger(__name__)
@ -64,6 +66,7 @@ class RpcQpidTestCase(test.TestCase):
self.mock_receiver = None
if qpid:
impl_qpid.register_opts(FLAGS)
self.orig_connection = qpid.messaging.Connection
self.orig_session = qpid.messaging.Session
self.orig_sender = qpid.messaging.Sender
@ -98,7 +101,7 @@ class RpcQpidTestCase(test.TestCase):
self.mox.ReplayAll()
connection = impl_qpid.create_connection()
connection = impl_qpid.create_connection(FLAGS)
connection.close()
def _test_create_consumer(self, fanout):
@ -130,7 +133,7 @@ class RpcQpidTestCase(test.TestCase):
self.mox.ReplayAll()
connection = impl_qpid.create_connection()
connection = impl_qpid.create_connection(FLAGS)
connection.create_consumer("impl_qpid_test",
lambda *_x, **_y: None,
fanout)
@ -176,11 +179,11 @@ class RpcQpidTestCase(test.TestCase):
try:
ctx = context.RequestContext("user", "project")
args = [ctx, "impl_qpid_test",
args = [FLAGS, ctx, "impl_qpid_test",
{"method": "test_method", "args": {}}]
if server_params:
args.insert(1, server_params)
args.insert(2, server_params)
if fanout:
method = impl_qpid.fanout_cast_to_server
else:
@ -218,7 +221,7 @@ class RpcQpidTestCase(test.TestCase):
server_params['hostname'] + ':' +
str(server_params['port']))
MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection)
MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
self.stubs.Set(impl_qpid, 'Connection', MyConnection)
@test.skip_if(qpid is None, "Test requires qpid")
@ -295,7 +298,7 @@ class RpcQpidTestCase(test.TestCase):
else:
method = impl_qpid.call
res = method(ctx, "impl_qpid_test",
res = method(FLAGS, ctx, "impl_qpid_test",
{"method": "test_method", "args": {}})
if multi: