Fix RPC responses to allow None response correctly.

Fixes bug 897155

Also adds a new fake rpc implementation that tests use by default.
This speeds up the test run by ~10% on my system.  We can decide to
ditch fake_rabbit at some point later..

Change-Id: I8877fad3d41ae055c15b1adff99e535c34e9ce92
This commit is contained in:
Chris Behrens
2011-11-29 09:01:16 -08:00
parent 7ec9c67576
commit 62a2580fce
13 changed files with 294 additions and 72 deletions

View File

@@ -266,14 +266,13 @@ class AdapterConsumer(Consumer):
# we just log the message and send an error string # we just log the message and send an error string
# back to the caller # back to the caller
LOG.warn(_('no method for message: %s') % message_data) LOG.warn(_('no method for message: %s') % message_data)
if msg_id: ctxt.reply(msg_id,
msg_reply(msg_id,
_('No method for message: %s') % message_data) _('No method for message: %s') % message_data)
return return
self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args) self.pool.spawn_n(self._process_data, ctxt, method, args)
@exception.wrap_exception() @exception.wrap_exception()
def _process_data(self, msg_id, ctxt, method, args): def _process_data(self, ctxt, method, args):
"""Thread that magically looks for a method on the proxy """Thread that magically looks for a method on the proxy
object and calls it. object and calls it.
""" """
@@ -283,23 +282,18 @@ class AdapterConsumer(Consumer):
# NOTE(vish): magic is fun! # NOTE(vish): magic is fun!
try: try:
rval = node_func(context=ctxt, **node_args) rval = node_func(context=ctxt, **node_args)
if msg_id:
# Check if the result was a generator # Check if the result was a generator
if isinstance(rval, types.GeneratorType): if isinstance(rval, types.GeneratorType):
for x in rval: for x in rval:
msg_reply(msg_id, x, None) ctxt.reply(x, None)
else: else:
msg_reply(msg_id, rval, None) ctxt.reply(rval, None)
# This final None tells multicall that it is done. # This final None tells multicall that it is done.
msg_reply(msg_id, None, None) ctxt.reply(ending=True)
elif isinstance(rval, types.GeneratorType):
# NOTE(vish): this iterates through the generator
list(rval)
except Exception as e: except Exception as e:
LOG.exception('Exception during message handling') LOG.exception('Exception during message handling')
if msg_id: ctxt.reply(None, sys.exc_info())
msg_reply(msg_id, None, sys.exc_info())
return return
@@ -447,7 +441,7 @@ class DirectPublisher(Publisher):
super(DirectPublisher, self).__init__(connection=connection) super(DirectPublisher, self).__init__(connection=connection)
def msg_reply(msg_id, reply=None, failure=None): def msg_reply(msg_id, reply=None, failure=None, ending=False):
"""Sends a reply or an error on the channel signified by msg_id. """Sends a reply or an error on the channel signified by msg_id.
Failure should be a sys.exc_info() tuple. Failure should be a sys.exc_info() tuple.
@@ -463,12 +457,17 @@ def msg_reply(msg_id, reply=None, failure=None):
with ConnectionPool.item() as conn: with ConnectionPool.item() as conn:
publisher = DirectPublisher(connection=conn, msg_id=msg_id) publisher = DirectPublisher(connection=conn, msg_id=msg_id)
try: try:
publisher.send({'result': reply, 'failure': failure}) msg = {'result': reply, 'failure': failure}
if ending:
msg['ending'] = True
publisher.send(msg)
except TypeError: except TypeError:
publisher.send( msg = {'result': dict((k, repr(v))
{'result': dict((k, repr(v))
for k, v in reply.__dict__.iteritems()), for k, v in reply.__dict__.iteritems()),
'failure': failure}) 'failure': failure}
if ending:
msg['ending'] = True
publisher.send(msg)
publisher.close() publisher.close()
@@ -508,8 +507,11 @@ class RpcContext(context.RequestContext):
self.msg_id = msg_id self.msg_id = msg_id
super(RpcContext, self).__init__(*args, **kwargs) super(RpcContext, self).__init__(*args, **kwargs)
def reply(self, *args, **kwargs): def reply(self, reply=None, failure=None, ending=False):
msg_reply(self.msg_id, *args, **kwargs) if self.msg_id:
msg_reply(self.msg_id, reply, failure, ending)
if ending:
self.msg_id = None
def multicall(context, topic, msg): def multicall(context, topic, msg):
@@ -537,8 +539,11 @@ class MulticallWaiter(object):
self._consumer = consumer self._consumer = consumer
self._results = queue.Queue() self._results = queue.Queue()
self._closed = False self._closed = False
self._got_ending = False
def close(self): def close(self):
if self._closed:
return
self._closed = True self._closed = True
self._consumer.close() self._consumer.close()
ConnectionPool.put(self._consumer.connection) ConnectionPool.put(self._consumer.connection)
@@ -548,6 +553,8 @@ class MulticallWaiter(object):
message.ack() message.ack()
if data['failure']: if data['failure']:
self._results.put(RemoteError(*data['failure'])) self._results.put(RemoteError(*data['failure']))
elif data.get('ending', False):
self._got_ending = True
else: else:
self._results.put(data['result']) self._results.put(data['result'])
@@ -555,23 +562,22 @@ class MulticallWaiter(object):
return self.wait() return self.wait()
def wait(self): def wait(self):
while True: while not self._closed:
rv = None
while rv is None and not self._closed:
try: try:
rv = self._consumer.fetch(enable_callbacks=True) rv = self._consumer.fetch(enable_callbacks=True)
except Exception: except Exception:
self.close() self.close()
raise raise
if rv is None:
time.sleep(0.01) time.sleep(0.01)
continue
if self._got_ending:
self.close()
raise StopIteration
result = self._results.get() result = self._results.get()
if isinstance(result, Exception): if isinstance(result, Exception):
self.close() self.close()
raise result raise result
if result == None:
self.close()
raise StopIteration
yield result yield result

146
nova/rpc/impl_fake.py Normal file
View File

@@ -0,0 +1,146 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 OpenStack LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Fake RPC implementation which calls proxy methods directly with no
queues. Casts will block, but this is very useful for tests.
"""
import sys
import traceback
import types
from nova import context
from nova.rpc import common as rpc_common
CONSUMERS = {}
class RpcContext(context.RequestContext):
def __init__(self, *args, **kwargs):
super(RpcContext, self).__init__(*args, **kwargs)
self._response = []
self._done = False
def reply(self, reply=None, failure=None, ending=False):
if ending:
self._done = True
if not self._done:
self._response.append((reply, failure))
class Consumer(object):
def __init__(self, topic, proxy):
self.topic = topic
self.proxy = proxy
def call(self, context, method, args):
node_func = getattr(self.proxy, method)
node_args = dict((str(k), v) for k, v in args.iteritems())
ctxt = RpcContext.from_dict(context.to_dict())
try:
rval = node_func(context=ctxt, **node_args)
# Caller might have called ctxt.reply() manually
for (reply, failure) in ctxt._response:
if failure:
raise failure[0], failure[1], failure[2]
yield reply
# if ending not 'sent'...we might have more data to
# return from the function itself
if not ctxt._done:
if isinstance(rval, types.GeneratorType):
for val in rval:
yield val
else:
yield rval
except Exception:
exc_info = sys.exc_info()
raise rpc_common.RemoteError(exc_info[0].__name__,
str(exc_info[1]),
traceback.format_exception(*exc_info))
class Connection(object):
"""Connection object."""
def __init__(self):
self.consumers = []
def create_consumer(self, topic, proxy, fanout=False):
consumer = Consumer(topic, proxy)
self.consumers.append(consumer)
if topic not in CONSUMERS:
CONSUMERS[topic] = []
CONSUMERS[topic].append(consumer)
def close(self):
for consumer in self.consumers:
CONSUMERS[consumer.topic].remove(consumer)
self.consumers = []
def consume_in_thread(self):
pass
def create_connection(new=True):
"""Create a connection"""
return Connection()
def multicall(context, topic, msg):
"""Make a call that returns multiple times."""
method = msg.get('method')
if not method:
return
args = msg.get('args', {})
try:
consumer = CONSUMERS[topic][0]
except (KeyError, IndexError):
return iter([None])
else:
return consumer.call(context, method, args)
def call(context, topic, msg):
"""Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg)
# NOTE(vish): return the last result from the multicall
rv = list(rv)
if not rv:
return
return rv[-1]
def cast(context, topic, msg):
try:
call(context, topic, msg)
except rpc_common.RemoteError:
pass
def fanout_cast(context, topic, msg):
"""Cast to all consumers of a topic"""
method = msg.get('method')
if not method:
return
args = msg.get('args', {})
for consumer in CONSUMERS.get(topic, []):
try:
consumer.call(context, method, args)
except rpc_common.RemoteError:
pass

View File

@@ -625,7 +625,7 @@ class ProxyCallback(object):
else: else:
ctxt.reply(rval, None) ctxt.reply(rval, None)
# This final None tells multicall that it is done. # This final None tells multicall that it is done.
ctxt.reply(None, None) ctxt.reply(ending=True)
except Exception as e: except Exception as e:
LOG.exception('Exception during message handling') LOG.exception('Exception during message handling')
ctxt.reply(None, sys.exc_info()) ctxt.reply(None, sys.exc_info())
@@ -668,9 +668,11 @@ class RpcContext(context.RequestContext):
self.msg_id = msg_id self.msg_id = msg_id
super(RpcContext, self).__init__(*args, **kwargs) super(RpcContext, self).__init__(*args, **kwargs)
def reply(self, *args, **kwargs): def reply(self, reply=None, failure=None, ending=False):
if self.msg_id: if self.msg_id:
msg_reply(self.msg_id, *args, **kwargs) msg_reply(self.msg_id, reply, failure, ending)
if ending:
self.msg_id = None
class MulticallWaiter(object): class MulticallWaiter(object):
@@ -679,8 +681,11 @@ class MulticallWaiter(object):
self._iterator = connection.iterconsume() self._iterator = connection.iterconsume()
self._result = None self._result = None
self._done = False self._done = False
self._got_ending = False
def done(self): def done(self):
if self._done:
return
self._done = True self._done = True
self._iterator.close() self._iterator.close()
self._iterator = None self._iterator = None
@@ -690,6 +695,8 @@ class MulticallWaiter(object):
"""The consume() callback will call this. Store the result.""" """The consume() callback will call this. Store the result."""
if data['failure']: if data['failure']:
self._result = RemoteError(*data['failure']) self._result = RemoteError(*data['failure'])
elif data.get('ending', False):
self._got_ending = True
else: else:
self._result = data['result'] self._result = data['result']
@@ -699,13 +706,13 @@ class MulticallWaiter(object):
raise StopIteration raise StopIteration
while True: while True:
self._iterator.next() self._iterator.next()
if self._got_ending:
self.done()
raise StopIteration
result = self._result result = self._result
if isinstance(result, Exception): if isinstance(result, Exception):
self.done() self.done()
raise result raise result
if result == None:
self.done()
raise StopIteration
yield result yield result
@@ -759,7 +766,7 @@ def fanout_cast(context, topic, msg):
conn.fanout_send(topic, msg) conn.fanout_send(topic, msg)
def msg_reply(msg_id, reply=None, failure=None): def msg_reply(msg_id, reply=None, failure=None, ending=False):
"""Sends a reply or an error on the channel signified by msg_id. """Sends a reply or an error on the channel signified by msg_id.
Failure should be a sys.exc_info() tuple. Failure should be a sys.exc_info() tuple.
@@ -779,4 +786,6 @@ def msg_reply(msg_id, reply=None, failure=None):
msg = {'result': dict((k, repr(v)) msg = {'result': dict((k, repr(v))
for k, v in reply.__dict__.iteritems()), for k, v in reply.__dict__.iteritems()),
'failure': failure} 'failure': failure}
if ending:
msg['ending'] = True
conn.direct_send(msg_id, msg) conn.direct_send(msg_id, msg)

View File

@@ -24,6 +24,7 @@ flags.DECLARE('volume_driver', 'nova.volume.manager')
FLAGS['volume_driver'].SetDefault('nova.volume.driver.FakeISCSIDriver') FLAGS['volume_driver'].SetDefault('nova.volume.driver.FakeISCSIDriver')
FLAGS['connection_type'].SetDefault('fake') FLAGS['connection_type'].SetDefault('fake')
FLAGS['fake_rabbit'].SetDefault(True) FLAGS['fake_rabbit'].SetDefault(True)
FLAGS['rpc_backend'].SetDefault('nova.rpc.impl_fake')
flags.DECLARE('auth_driver', 'nova.auth.manager') flags.DECLARE('auth_driver', 'nova.auth.manager')
FLAGS['auth_driver'].SetDefault('nova.auth.dbdriver.DbDriver') FLAGS['auth_driver'].SetDefault('nova.auth.dbdriver.DbDriver')
flags.DECLARE('network_size', 'nova.network.manager') flags.DECLARE('network_size', 'nova.network.manager')

View File

@@ -0,0 +1,19 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 Openstack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# NOTE(vish): this forces the fixtures from tests/__init.py:setup() to work
from nova.tests import *

View File

@@ -81,6 +81,17 @@ class _BaseRpcTestCase(test.TestCase):
for i, x in enumerate(result): for i, x in enumerate(result):
self.assertEqual(value + i, x) self.assertEqual(value + i, x)
def test_multicall_three_nones(self):
value = 42
result = self.rpc.multicall(self.context,
'test',
{"method": "multicall_three_nones",
"args": {"value": value}})
for i, x in enumerate(result):
self.assertEqual(x, None)
# i should have been 0, 1, and finally 2:
self.assertEqual(i, 2)
def test_multicall_succeed_three_times_yield(self): def test_multicall_succeed_three_times_yield(self):
value = 42 value = 42
result = self.rpc.multicall(self.context, result = self.rpc.multicall(self.context,
@@ -176,6 +187,13 @@ class TestReceiver(object):
context.reply(value) context.reply(value)
context.reply(value + 1) context.reply(value + 1)
context.reply(value + 2) context.reply(value + 2)
context.reply(ending=True)
@staticmethod
def multicall_three_nones(context, value):
yield None
yield None
yield None
@staticmethod @staticmethod
def echo_three_times_yield(context, value): def echo_three_times_yield(context, value):

View File

@@ -22,13 +22,13 @@ Unit Tests for remote procedure calls using carrot
from nova import context from nova import context
from nova import log as logging from nova import log as logging
from nova.rpc import impl_carrot from nova.rpc import impl_carrot
from nova.tests import test_rpc_common from nova.tests.rpc import common
LOG = logging.getLogger('nova.tests.rpc') LOG = logging.getLogger('nova.tests.rpc')
class RpcCarrotTestCase(test_rpc_common._BaseRpcTestCase): class RpcCarrotTestCase(common._BaseRpcTestCase):
def setUp(self): def setUp(self):
self.rpc = impl_carrot self.rpc = impl_carrot
super(RpcCarrotTestCase, self).setUp() super(RpcCarrotTestCase, self).setUp()

View File

@@ -0,0 +1,36 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Unit Tests for remote procedure calls using fake_impl
"""
from nova import log as logging
from nova.rpc import impl_fake
from nova.tests.rpc import common
LOG = logging.getLogger('nova.tests.rpc')
class RpcFakeTestCase(common._BaseRpcTestCase):
def setUp(self):
self.rpc = impl_fake
super(RpcFakeTestCase, self).setUp()
def tearDown(self):
super(RpcFakeTestCase, self).tearDown()

View File

@@ -23,13 +23,13 @@ from nova import context
from nova import log as logging from nova import log as logging
from nova import test from nova import test
from nova.rpc import impl_kombu from nova.rpc import impl_kombu
from nova.tests import test_rpc_common from nova.tests.rpc import common
LOG = logging.getLogger('nova.tests.rpc') LOG = logging.getLogger('nova.tests.rpc')
class RpcKombuTestCase(test_rpc_common._BaseRpcTestCase): class RpcKombuTestCase(common._BaseRpcTestCase):
def setUp(self): def setUp(self):
self.rpc = impl_kombu self.rpc = impl_kombu
super(RpcKombuTestCase, self).setUp() super(RpcKombuTestCase, self).setUp()

View File

@@ -16,20 +16,20 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" """
Unit Tests for remote procedure calls using queue Unit Tests for remote procedure call interfaces
""" """
from nova import context
from nova import log as logging from nova import log as logging
from nova import rpc from nova import rpc
from nova.tests import test_rpc_common from nova.tests.rpc import common
LOG = logging.getLogger('nova.tests.rpc') LOG = logging.getLogger('nova.tests.rpc')
class RpcTestCase(test_rpc_common._BaseRpcTestCase): class RpcTestCase(common._BaseRpcTestCase):
def setUp(self): def setUp(self):
self.flags(rpc_backend='nova.tests.rpc.fake')
self.rpc = rpc self.rpc = rpc
super(RpcTestCase, self).setUp() super(RpcTestCase, self).setUp()

View File

@@ -61,14 +61,9 @@ class AdminApiTestCase(test.TestCase):
self.stubs.Set(fake._FakeImageService, 'show', fake_show) self.stubs.Set(fake._FakeImageService, 'show', fake_show)
self.stubs.Set(fake._FakeImageService, 'show_by_name', fake_show) self.stubs.Set(fake._FakeImageService, 'show_by_name', fake_show)
# NOTE(vish): set up a manual wait so rpc.cast has a chance to finish # NOTE(comstud): Make 'cast' behave like a 'call' which will
rpc_cast = rpc.cast # ensure that operations complete
self.stubs.Set(rpc, 'cast', rpc.call)
def finish_cast(*args, **kwargs):
rpc_cast(*args, **kwargs)
greenthread.sleep(0.2)
self.stubs.Set(rpc, 'cast', finish_cast)
def test_block_external_ips(self): def test_block_external_ips(self):
"""Make sure provider firewall rules are created.""" """Make sure provider firewall rules are created."""

View File

@@ -132,16 +132,6 @@ def stubout_loopingcall_start(stubs):
stubs.Set(utils.LoopingCall, 'start', fake_start) stubs.Set(utils.LoopingCall, 'start', fake_start)
def stubout_loopingcall_delay(stubs):
def fake_start(self, interval, now=True):
self._running = True
eventlet.sleep(1)
self.f(*self.args, **self.kw)
# This would fail before parallel xenapi calls were fixed
assert self._running == False
stubs.Set(utils.LoopingCall, 'start', fake_start)
def _make_fake_vdi(): def _make_fake_vdi():
sr_ref = fake.get_all('SR')[0] sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False) vdi_ref = fake.create_vdi('', False, sr_ref, False)

View File

@@ -64,6 +64,7 @@ import time
gettext.install('nova', unicode=1) gettext.install('nova', unicode=1)
import eventlet
from nose import config from nose import config
from nose import core from nose import core
from nose import result from nose import result
@@ -336,6 +337,7 @@ class NovaTestRunner(core.TextTestRunner):
if __name__ == '__main__': if __name__ == '__main__':
eventlet.monkey_patch()
logging.setup() logging.setup()
# If any argument looks like a test name but doesn't have "nova.tests" in # If any argument looks like a test name but doesn't have "nova.tests" in
# front of it, automatically add that so we don't have to type as much # front of it, automatically add that so we don't have to type as much