diff --git a/nova/fakerabbit.py b/nova/fakerabbit.py
index a7dee8ca..e7e9dab7 100644
--- a/nova/fakerabbit.py
+++ b/nova/fakerabbit.py
@@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit")
EXCHANGES = {}
QUEUES = {}
+CONSUMERS = {}
class Message(base.BaseMessage):
@@ -96,17 +97,29 @@ class Backend(base.BaseBackend):
' key %(routing_key)s') % locals())
EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key)
- def declare_consumer(self, queue, callback, *args, **kwargs):
- self.current_queue = queue
- self.current_callback = callback
+ def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs):
+ global CONSUMERS
+ LOG.debug("Adding consumer %s", consumer_tag)
+ CONSUMERS[consumer_tag] = (queue, callback)
+
+ def cancel(self, consumer_tag):
+ global CONSUMERS
+ LOG.debug("Removing consumer %s", consumer_tag)
+ del CONSUMERS[consumer_tag]
def consume(self, limit=None):
+ global CONSUMERS
+ num = 0
while True:
- item = self.get(self.current_queue)
- if item:
- self.current_callback(item)
- raise StopIteration()
- greenthread.sleep(0)
+ for (queue, callback) in CONSUMERS.itervalues():
+ item = self.get(queue)
+ if item:
+ callback(item)
+ num += 1
+ yield
+ if limit and num == limit:
+ raise StopIteration()
+ greenthread.sleep(0.1)
def get(self, queue, no_ack=False):
global QUEUES
@@ -134,5 +147,7 @@ class Backend(base.BaseBackend):
def reset_all():
global EXCHANGES
global QUEUES
+ global CONSUMERS
EXCHANGES = {}
QUEUES = {}
+ CONSUMERS = {}
diff --git a/nova/rpc.py b/nova/rpc.py
index 2116f22c..c5277c6a 100644
--- a/nova/rpc.py
+++ b/nova/rpc.py
@@ -28,12 +28,15 @@ import json
import sys
import time
import traceback
+import types
import uuid
from carrot import connection as carrot_connection
from carrot import messaging
from eventlet import greenpool
-from eventlet import greenthread
+from eventlet import pools
+from eventlet import queue
+import greenlet
from nova import context
from nova import exception
@@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc')
FLAGS = flags.FLAGS
-flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool')
+flags.DEFINE_integer('rpc_thread_pool_size', 1024,
+ 'Size of RPC thread pool')
+flags.DEFINE_integer('rpc_conn_pool_size', 30,
+ 'Size of RPC connection pool')
class Connection(carrot_connection.BrokerConnection):
@@ -90,6 +96,22 @@ class Connection(carrot_connection.BrokerConnection):
return cls.instance()
+class Pool(pools.Pool):
+ """Class that implements a Pool of Connections."""
+
+ # TODO(comstud): Timeout connections not used in a while
+ def create(self):
+ LOG.debug('Creating new connection')
+ return Connection.instance(new=True)
+
+# Create a ConnectionPool to use for RPC calls. We'll order the
+# pool as a stack (LIFO), so that we can potentially loop through and
+# timeout old unused connections at some point
+ConnectionPool = Pool(
+ max_size=FLAGS.rpc_conn_pool_size,
+ order_as_stack=True)
+
+
class Consumer(messaging.Consumer):
"""Consumer base class.
@@ -131,7 +153,9 @@ class Consumer(messaging.Consumer):
self.connection = Connection.recreate()
self.backend = self.connection.create_backend()
self.declare()
- super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks)
+ return super(Consumer, self).fetch(no_ack,
+ auto_ack,
+ enable_callbacks)
if self.failed_connection:
LOG.error(_('Reconnected to queue'))
self.failed_connection = False
@@ -159,13 +183,13 @@ class AdapterConsumer(Consumer):
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
super(AdapterConsumer, self).__init__(connection=connection,
topic=topic)
+ self.register_callback(self.process_data)
- def receive(self, *args, **kwargs):
- self.pool.spawn_n(self._receive, *args, **kwargs)
+ def process_data(self, message_data, message):
+ """Consumer callback to call a method on a proxy object.
- @exception.wrap_exception
- def _receive(self, message_data, message):
- """Magically looks for a method on the proxy object and calls it.
+ Parses the message for validity and fires off a thread to call the
+ proxy object method.
Message data should be a dictionary with two keys:
method: string representing the method to call
@@ -175,8 +199,8 @@ class AdapterConsumer(Consumer):
"""
LOG.debug(_('received %s') % message_data)
- msg_id = message_data.pop('_msg_id', None)
-
+ # This will be popped off in _unpack_context
+ msg_id = message_data.get('_msg_id', None)
ctxt = _unpack_context(message_data)
method = message_data.get('method')
@@ -188,8 +212,17 @@ class AdapterConsumer(Consumer):
# we just log the message and send an error string
# back to the caller
LOG.warn(_('no method for message: %s') % message_data)
- msg_reply(msg_id, _('No method for message: %s') % message_data)
+ if msg_id:
+ msg_reply(msg_id,
+ _('No method for message: %s') % message_data)
return
+ self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args)
+
+ @exception.wrap_exception
+ def _process_data(self, msg_id, ctxt, method, args):
+ """Thread that maigcally looks for a method on the proxy
+ object and calls it.
+ """
node_func = getattr(self.proxy, str(method))
node_args = dict((str(k), v) for k, v in args.iteritems())
@@ -197,7 +230,18 @@ class AdapterConsumer(Consumer):
try:
rval = node_func(context=ctxt, **node_args)
if msg_id:
- msg_reply(msg_id, rval, None)
+ # Check if the result was a generator
+ if isinstance(rval, types.GeneratorType):
+ for x in rval:
+ msg_reply(msg_id, x, None)
+ else:
+ msg_reply(msg_id, rval, None)
+
+ # This final None tells multicall that it is done.
+ msg_reply(msg_id, None, None)
+ elif isinstance(rval, types.GeneratorType):
+ # NOTE(vish): this iterates through the generator
+ list(rval)
except Exception as e:
logging.exception('Exception during message handling')
if msg_id:
@@ -205,11 +249,6 @@ class AdapterConsumer(Consumer):
return
-class Publisher(messaging.Publisher):
- """Publisher base class."""
- pass
-
-
class TopicAdapterConsumer(AdapterConsumer):
"""Consumes messages on a specific topic."""
@@ -242,6 +281,58 @@ class FanoutAdapterConsumer(AdapterConsumer):
topic=topic, proxy=proxy)
+class ConsumerSet(object):
+ """Groups consumers to listen on together on a single connection."""
+
+ def __init__(self, connection, consumer_list):
+ self.consumer_list = set(consumer_list)
+ self.consumer_set = None
+ self.enabled = True
+ self.init(connection)
+
+ def init(self, conn):
+ if not conn:
+ conn = Connection.instance(new=True)
+ if self.consumer_set:
+ self.consumer_set.close()
+ self.consumer_set = messaging.ConsumerSet(conn)
+ for consumer in self.consumer_list:
+ consumer.connection = conn
+ # consumer.backend is set for us
+ self.consumer_set.add_consumer(consumer)
+
+ def reconnect(self):
+ self.init(None)
+
+ def wait(self, limit=None):
+ running = True
+ while running:
+ it = self.consumer_set.iterconsume(limit=limit)
+ if not it:
+ break
+ while True:
+ try:
+ it.next()
+ except StopIteration:
+ return
+ except greenlet.GreenletExit:
+ running = False
+ break
+ except Exception as e:
+ LOG.exception(_("Exception while processing consumer"))
+ self.reconnect()
+ # Break to outer loop
+ break
+
+ def close(self):
+ self.consumer_set.close()
+
+
+class Publisher(messaging.Publisher):
+ """Publisher base class."""
+ pass
+
+
class TopicPublisher(Publisher):
"""Publishes messages on a specific topic."""
@@ -306,16 +397,18 @@ def msg_reply(msg_id, reply=None, failure=None):
LOG.error(_("Returning exception %s to caller"), message)
LOG.error(tb)
failure = (failure[0].__name__, str(failure[1]), tb)
- conn = Connection.instance()
- publisher = DirectPublisher(connection=conn, msg_id=msg_id)
- try:
- publisher.send({'result': reply, 'failure': failure})
- except TypeError:
- publisher.send(
- {'result': dict((k, repr(v))
- for k, v in reply.__dict__.iteritems()),
- 'failure': failure})
- publisher.close()
+
+ with ConnectionPool.item() as conn:
+ publisher = DirectPublisher(connection=conn, msg_id=msg_id)
+ try:
+ publisher.send({'result': reply, 'failure': failure})
+ except TypeError:
+ publisher.send(
+ {'result': dict((k, repr(v))
+ for k, v in reply.__dict__.iteritems()),
+ 'failure': failure})
+
+ publisher.close()
class RemoteError(exception.Error):
@@ -347,8 +440,9 @@ def _unpack_context(msg):
if key.startswith('_context_'):
value = msg.pop(key)
context_dict[key[9:]] = value
+ context_dict['msg_id'] = msg.pop('_msg_id', None)
LOG.debug(_('unpacked context: %s'), context_dict)
- return context.RequestContext.from_dict(context_dict)
+ return RpcContext.from_dict(context_dict)
def _pack_context(msg, context):
@@ -360,70 +454,112 @@ def _pack_context(msg, context):
for args at some point.
"""
- context = dict([('_context_%s' % key, value)
- for (key, value) in context.to_dict().iteritems()])
- msg.update(context)
+ context_d = dict([('_context_%s' % key, value)
+ for (key, value) in context.to_dict().iteritems()])
+ msg.update(context_d)
-def call(context, topic, msg):
- """Sends a message on a topic and wait for a response."""
+class RpcContext(context.RequestContext):
+ def __init__(self, *args, **kwargs):
+ msg_id = kwargs.pop('msg_id', None)
+ self.msg_id = msg_id
+ super(RpcContext, self).__init__(*args, **kwargs)
+
+ def reply(self, *args, **kwargs):
+ msg_reply(self.msg_id, *args, **kwargs)
+
+
+def multicall(context, topic, msg):
+ """Make a call that returns multiple times."""
LOG.debug(_('Making asynchronous call on %s ...'), topic)
msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id})
LOG.debug(_('MSG_ID is %s') % (msg_id))
_pack_context(msg, context)
- class WaitMessage(object):
- def __call__(self, data, message):
- """Acks message and sets result."""
- message.ack()
- if data['failure']:
- self.result = RemoteError(*data['failure'])
- else:
- self.result = data['result']
-
- wait_msg = WaitMessage()
- conn = Connection.instance()
- consumer = DirectConsumer(connection=conn, msg_id=msg_id)
+ con_conn = ConnectionPool.get()
+ consumer = DirectConsumer(connection=con_conn, msg_id=msg_id)
+ wait_msg = MulticallWaiter(consumer)
consumer.register_callback(wait_msg)
- conn = Connection.instance()
- publisher = TopicPublisher(connection=conn, topic=topic)
+ publisher = TopicPublisher(connection=con_conn, topic=topic)
publisher.send(msg)
publisher.close()
- try:
- consumer.wait(limit=1)
- except StopIteration:
- pass
- consumer.close()
- # NOTE(termie): this is a little bit of a change from the original
- # non-eventlet code where returning a Failure
- # instance from a deferred call is very similar to
- # raising an exception
- if isinstance(wait_msg.result, Exception):
- raise wait_msg.result
- return wait_msg.result
+ return wait_msg
+
+
+class MulticallWaiter(object):
+ def __init__(self, consumer):
+ self._consumer = consumer
+ self._results = queue.Queue()
+ self._closed = False
+
+ def close(self):
+ self._closed = True
+ self._consumer.close()
+ ConnectionPool.put(self._consumer.connection)
+
+ def __call__(self, data, message):
+ """Acks message and sets result."""
+ message.ack()
+ if data['failure']:
+ self._results.put(RemoteError(*data['failure']))
+ else:
+ self._results.put(data['result'])
+
+ def __iter__(self):
+ return self.wait()
+
+ def wait(self):
+ while True:
+ rv = None
+ while rv is None and not self._closed:
+ try:
+ rv = self._consumer.fetch(enable_callbacks=True)
+ except Exception:
+ self.close()
+ raise
+ time.sleep(0.01)
+
+ result = self._results.get()
+ if isinstance(result, Exception):
+ self.close()
+ raise result
+ if result == None:
+ self.close()
+ raise StopIteration
+ yield result
+
+
+def call(context, topic, msg):
+ """Sends a message on a topic and wait for a response."""
+ rv = multicall(context, topic, msg)
+ # NOTE(vish): return the last result from the multicall
+ rv = list(rv)
+ if not rv:
+ return
+ return rv[-1]
def cast(context, topic, msg):
"""Sends a message on a topic without waiting for a response."""
LOG.debug(_('Making asynchronous cast on %s...'), topic)
_pack_context(msg, context)
- conn = Connection.instance()
- publisher = TopicPublisher(connection=conn, topic=topic)
- publisher.send(msg)
- publisher.close()
+ with ConnectionPool.item() as conn:
+ publisher = TopicPublisher(connection=conn, topic=topic)
+ publisher.send(msg)
+ publisher.close()
def fanout_cast(context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_('Making asynchronous fanout cast...'))
_pack_context(msg, context)
- conn = Connection.instance()
- publisher = FanoutPublisher(topic, connection=conn)
- publisher.send(msg)
- publisher.close()
+ with ConnectionPool.item() as conn:
+ publisher = FanoutPublisher(topic, connection=conn)
+ publisher.send(msg)
+ publisher.close()
def generic_response(message_data, message):
@@ -459,6 +595,7 @@ def send_message(topic, message, wait=True):
if wait:
consumer.wait()
+ consumer.close()
if __name__ == '__main__':
diff --git a/nova/tests/test_cloud.py b/nova/tests/test_cloud.py
index 54c0454d..b64be662 100644
--- a/nova/tests/test_cloud.py
+++ b/nova/tests/test_cloud.py
@@ -17,13 +17,9 @@
# under the License.
from base64 import b64decode
-import json
from M2Crypto import BIO
from M2Crypto import RSA
import os
-import shutil
-import tempfile
-import time
from eventlet import greenthread
@@ -33,12 +29,10 @@ from nova import db
from nova import flags
from nova import log as logging
from nova import rpc
-from nova import service
from nova import test
from nova import utils
from nova import exception
from nova.auth import manager
-from nova.compute import power_state
from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils
from nova.image import local
@@ -79,14 +73,21 @@ class CloudTestCase(test.TestCase):
self.stubs.Set(local.LocalImageService, 'show', fake_show)
self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show)
+ # NOTE(vish): set up a manual wait so rpc.cast has a chance to finish
+ rpc_cast = rpc.cast
+
+ def finish_cast(*args, **kwargs):
+ rpc_cast(*args, **kwargs)
+ greenthread.sleep(0.2)
+
+ self.stubs.Set(rpc, 'cast', finish_cast)
+
def tearDown(self):
network_ref = db.project_get_network(self.context,
self.project.id)
db.network_disassociate(self.context, network_ref['id'])
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
- self.compute.kill()
- self.network.kill()
super(CloudTestCase, self).tearDown()
def _create_key(self, name):
@@ -113,7 +114,6 @@ class CloudTestCase(test.TestCase):
self.cloud.describe_addresses(self.context)
self.cloud.release_address(self.context,
public_ip=address)
- greenthread.sleep(0.3)
db.floating_ip_destroy(self.context, address)
def test_associate_disassociate_address(self):
@@ -129,12 +129,10 @@ class CloudTestCase(test.TestCase):
self.cloud.associate_address(self.context,
instance_id=ec2_id,
public_ip=address)
- greenthread.sleep(0.3)
self.cloud.disassociate_address(self.context,
public_ip=address)
self.cloud.release_address(self.context,
public_ip=address)
- greenthread.sleep(0.3)
self.network.deallocate_fixed_ip(self.context, fixed)
db.instance_destroy(self.context, inst['id'])
db.floating_ip_destroy(self.context, address)
@@ -306,31 +304,25 @@ class CloudTestCase(test.TestCase):
'instance_type': instance_type,
'max_count': max_count}
rv = self.cloud.run_instances(self.context, **kwargs)
- greenthread.sleep(0.3)
instance_id = rv['instancesSet'][0]['instanceId']
output = self.cloud.get_console_output(context=self.context,
instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests.
- greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id])
- greenthread.sleep(0.3)
def test_ajax_console(self):
kwargs = {'image_id': 'ami-1'}
rv = self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId']
- greenthread.sleep(0.3)
output = self.cloud.get_ajax_console(context=self.context,
instance_id=[instance_id])
self.assertEquals(output['url'],
'%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url)
# TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests.
- greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id])
- greenthread.sleep(0.3)
def test_key_generation(self):
result = self._create_key('test')
diff --git a/nova/tests/test_rpc.py b/nova/tests/test_rpc.py
index 44d7c91e..ffd748ef 100644
--- a/nova/tests/test_rpc.py
+++ b/nova/tests/test_rpc.py
@@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc')
class RpcTestCase(test.TestCase):
- """Test cases for rpc"""
def setUp(self):
super(RpcTestCase, self).setUp()
self.conn = rpc.Connection.instance(True)
@@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase):
self.context = context.get_admin_context()
def test_call_succeed(self):
- """Get a value through rpc call"""
value = 42
result = rpc.call(self.context, 'test', {"method": "echo",
"args": {"value": value}})
self.assertEqual(value, result)
+ def test_call_succeed_despite_multiple_returns(self):
+ value = 42
+ result = rpc.call(self.context, 'test', {"method": "echo_three_times",
+ "args": {"value": value}})
+ self.assertEqual(value + 2, result)
+
+ def test_call_succeed_despite_multiple_returns_yield(self):
+ value = 42
+ result = rpc.call(self.context, 'test',
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ self.assertEqual(value + 2, result)
+
+ def test_multicall_succeed_once(self):
+ value = 42
+ result = rpc.multicall(self.context,
+ 'test',
+ {"method": "echo",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ if i > 0:
+ self.fail('should only receive one response')
+ self.assertEqual(value + i, x)
+
+ def test_multicall_succeed_three_times(self):
+ value = 42
+ result = rpc.multicall(self.context,
+ 'test',
+ {"method": "echo_three_times",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(value + i, x)
+
+ def test_multicall_succeed_three_times_yield(self):
+ value = 42
+ result = rpc.multicall(self.context,
+ 'test',
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(value + i, x)
+
def test_context_passed(self):
- """Makes sure a context is passed through rpc call"""
+ """Makes sure a context is passed through rpc call."""
value = 42
result = rpc.call(self.context,
'test', {"method": "context",
@@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase):
self.assertEqual(self.context.to_dict(), result)
def test_call_exception(self):
- """Test that exception gets passed back properly
+ """Test that exception gets passed back properly.
rpc.call returns a RemoteError object. The value of the
exception is converted to a string, so we convert it back
to an int in the test.
+
"""
value = 42
self.assertRaises(rpc.RemoteError,
@@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase):
self.assertEqual(int(exc.value), value)
def test_nested_calls(self):
- """Test that we can do an rpc.call inside another call"""
+ """Test that we can do an rpc.call inside another call."""
class Nested(object):
@staticmethod
def echo(context, queue, value):
@@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase):
"value": value}})
self.assertEqual(value, result)
+ def test_connectionpool_single(self):
+ """Test that ConnectionPool recycles a single connection."""
+ conn1 = rpc.ConnectionPool.get()
+ rpc.ConnectionPool.put(conn1)
+ conn2 = rpc.ConnectionPool.get()
+ rpc.ConnectionPool.put(conn2)
+ self.assertEqual(conn1, conn2)
+
+ def test_connectionpool_double(self):
+ """Test that ConnectionPool returns and reuses separate connections.
+
+ When called consecutively we should get separate connections and upon
+ returning them those connections should be reused for future calls
+ before generating a new connection.
+
+ """
+ conn1 = rpc.ConnectionPool.get()
+ conn2 = rpc.ConnectionPool.get()
+
+ self.assertNotEqual(conn1, conn2)
+ rpc.ConnectionPool.put(conn1)
+ rpc.ConnectionPool.put(conn2)
+
+ conn3 = rpc.ConnectionPool.get()
+ conn4 = rpc.ConnectionPool.get()
+ self.assertEqual(conn1, conn3)
+ self.assertEqual(conn2, conn4)
+
+ def test_connectionpool_limit(self):
+ """Test connection pool limit and connection uniqueness."""
+ max_size = FLAGS.rpc_conn_pool_size
+ conns = []
+
+ for i in xrange(max_size):
+ conns.append(rpc.ConnectionPool.get())
+
+ self.assertFalse(rpc.ConnectionPool.free_items)
+ self.assertEqual(rpc.ConnectionPool.current_size,
+ rpc.ConnectionPool.max_size)
+ self.assertEqual(len(set(conns)), max_size)
+
class TestReceiver(object):
- """Simple Proxy class so the consumer has methods to call
+ """Simple Proxy class so the consumer has methods to call.
- Uses static methods because we aren't actually storing any state"""
+ Uses static methods because we aren't actually storing any state.
+
+ """
@staticmethod
def echo(context, value):
- """Simply returns whatever value is sent in"""
+ """Simply returns whatever value is sent in."""
LOG.debug(_("Received %s"), value)
return value
@staticmethod
def context(context, value):
- """Returns dictionary version of context"""
+ """Returns dictionary version of context."""
LOG.debug(_("Received %s"), context)
return context.to_dict()
+ @staticmethod
+ def echo_three_times(context, value):
+ context.reply(value)
+ context.reply(value + 1)
+ context.reply(value + 2)
+
+ @staticmethod
+ def echo_three_times_yield(context, value):
+ yield value
+ yield value + 1
+ yield value + 2
+
@staticmethod
def fail(context, value):
- """Raises an exception with the value sent in"""
+ """Raises an exception with the value sent in."""
raise Exception(value)
diff --git a/nova/tests/test_xenapi.py b/nova/tests/test_xenapi.py
index be1e3569..18a26789 100644
--- a/nova/tests/test_xenapi.py
+++ b/nova/tests/test_xenapi.py
@@ -395,6 +395,29 @@ class XenAPIVMTestCase(test.TestCase):
os_type="linux")
self.check_vm_params_for_linux()
+ def test_spawn_vhd_glance_swapdisk(self):
+ # Change the default host_call_plugin to one that'll return
+ # a swap disk
+ orig_func = stubs.FakeSessionForVMTests.host_call_plugin
+
+ stubs.FakeSessionForVMTests.host_call_plugin = \
+ stubs.FakeSessionForVMTests.host_call_plugin_swap
+
+ try:
+ # We'll steal the above glance linux test
+ self.test_spawn_vhd_glance_linux()
+ finally:
+ # Make sure to put this back
+ stubs.FakeSessionForVMTests.host_call_plugin = orig_func
+
+ # We should have 2 VBDs.
+ self.assertEqual(len(self.vm['VBDs']), 2)
+ # Now test that we have 1.
+ self.tearDown()
+ self.setUp()
+ self.test_spawn_vhd_glance_linux()
+ self.assertEqual(len(self.vm['VBDs']), 1)
+
def test_spawn_vhd_glance_windows(self):
FLAGS.xenapi_image_service = 'glance'
self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None,
diff --git a/nova/tests/xenapi/stubs.py b/nova/tests/xenapi/stubs.py
index 4833ccb0..35308d95 100644
--- a/nova/tests/xenapi/stubs.py
+++ b/nova/tests/xenapi/stubs.py
@@ -17,6 +17,7 @@
"""Stubouts, mocks and fixtures for the test suite"""
import eventlet
+import json
from nova.virt import xenapi_conn
from nova.virt.xenapi import fake
from nova.virt.xenapi import volume_utils
@@ -37,7 +38,7 @@ def stubout_instance_snapshot(stubs):
sr_ref=sr_ref, sharable=False)
vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref)
vdi_uuid = vdi_rec['uuid']
- return vdi_uuid
+ return [dict(vdi_type='os', vdi_uuid=vdi_uuid)]
stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
@@ -132,11 +133,30 @@ class FakeSessionForVMTests(fake.SessionBase):
def __init__(self, uri):
super(FakeSessionForVMTests, self).__init__(uri)
- def host_call_plugin(self, _1, _2, _3, _4, _5):
+ def host_call_plugin(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref)
- return '%s' % vdi_rec['uuid']
+ if plugin == "glance" and method == "download_vhd":
+ ret_str = json.dumps([dict(vdi_type='os',
+ vdi_uuid=vdi_rec['uuid'])])
+ else:
+ ret_str = vdi_rec['uuid']
+ return '%s' % ret_str
+
+ def host_call_plugin_swap(self, _1, _2, plugin, method, _5):
+ sr_ref = fake.get_all('SR')[0]
+ vdi_ref = fake.create_vdi('', False, sr_ref, False)
+ vdi_rec = fake.get_record('VDI', vdi_ref)
+ if plugin == "glance" and method == "download_vhd":
+ swap_vdi_ref = fake.create_vdi('', False, sr_ref, False)
+ swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref)
+ ret_str = json.dumps(
+ [dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']),
+ dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])])
+ else:
+ ret_str = vdi_rec['uuid']
+ return '%s' % ret_str
def VM_start(self, _1, ref, _2, _3):
vm = fake.get_record('VM', ref)