diff --git a/nova/fakerabbit.py b/nova/fakerabbit.py index a7dee8ca..e7e9dab7 100644 --- a/nova/fakerabbit.py +++ b/nova/fakerabbit.py @@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit") EXCHANGES = {} QUEUES = {} +CONSUMERS = {} class Message(base.BaseMessage): @@ -96,17 +97,29 @@ class Backend(base.BaseBackend): ' key %(routing_key)s') % locals()) EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key) - def declare_consumer(self, queue, callback, *args, **kwargs): - self.current_queue = queue - self.current_callback = callback + def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs): + global CONSUMERS + LOG.debug("Adding consumer %s", consumer_tag) + CONSUMERS[consumer_tag] = (queue, callback) + + def cancel(self, consumer_tag): + global CONSUMERS + LOG.debug("Removing consumer %s", consumer_tag) + del CONSUMERS[consumer_tag] def consume(self, limit=None): + global CONSUMERS + num = 0 while True: - item = self.get(self.current_queue) - if item: - self.current_callback(item) - raise StopIteration() - greenthread.sleep(0) + for (queue, callback) in CONSUMERS.itervalues(): + item = self.get(queue) + if item: + callback(item) + num += 1 + yield + if limit and num == limit: + raise StopIteration() + greenthread.sleep(0.1) def get(self, queue, no_ack=False): global QUEUES @@ -134,5 +147,7 @@ class Backend(base.BaseBackend): def reset_all(): global EXCHANGES global QUEUES + global CONSUMERS EXCHANGES = {} QUEUES = {} + CONSUMERS = {} diff --git a/nova/rpc.py b/nova/rpc.py index 2116f22c..c5277c6a 100644 --- a/nova/rpc.py +++ b/nova/rpc.py @@ -28,12 +28,15 @@ import json import sys import time import traceback +import types import uuid from carrot import connection as carrot_connection from carrot import messaging from eventlet import greenpool -from eventlet import greenthread +from eventlet import pools +from eventlet import queue +import greenlet from nova import context from nova import exception @@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc') FLAGS = flags.FLAGS -flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool') +flags.DEFINE_integer('rpc_thread_pool_size', 1024, + 'Size of RPC thread pool') +flags.DEFINE_integer('rpc_conn_pool_size', 30, + 'Size of RPC connection pool') class Connection(carrot_connection.BrokerConnection): @@ -90,6 +96,22 @@ class Connection(carrot_connection.BrokerConnection): return cls.instance() +class Pool(pools.Pool): + """Class that implements a Pool of Connections.""" + + # TODO(comstud): Timeout connections not used in a while + def create(self): + LOG.debug('Creating new connection') + return Connection.instance(new=True) + +# Create a ConnectionPool to use for RPC calls. We'll order the +# pool as a stack (LIFO), so that we can potentially loop through and +# timeout old unused connections at some point +ConnectionPool = Pool( + max_size=FLAGS.rpc_conn_pool_size, + order_as_stack=True) + + class Consumer(messaging.Consumer): """Consumer base class. @@ -131,7 +153,9 @@ class Consumer(messaging.Consumer): self.connection = Connection.recreate() self.backend = self.connection.create_backend() self.declare() - super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks) + return super(Consumer, self).fetch(no_ack, + auto_ack, + enable_callbacks) if self.failed_connection: LOG.error(_('Reconnected to queue')) self.failed_connection = False @@ -159,13 +183,13 @@ class AdapterConsumer(Consumer): self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) super(AdapterConsumer, self).__init__(connection=connection, topic=topic) + self.register_callback(self.process_data) - def receive(self, *args, **kwargs): - self.pool.spawn_n(self._receive, *args, **kwargs) + def process_data(self, message_data, message): + """Consumer callback to call a method on a proxy object. - @exception.wrap_exception - def _receive(self, message_data, message): - """Magically looks for a method on the proxy object and calls it. + Parses the message for validity and fires off a thread to call the + proxy object method. Message data should be a dictionary with two keys: method: string representing the method to call @@ -175,8 +199,8 @@ class AdapterConsumer(Consumer): """ LOG.debug(_('received %s') % message_data) - msg_id = message_data.pop('_msg_id', None) - + # This will be popped off in _unpack_context + msg_id = message_data.get('_msg_id', None) ctxt = _unpack_context(message_data) method = message_data.get('method') @@ -188,8 +212,17 @@ class AdapterConsumer(Consumer): # we just log the message and send an error string # back to the caller LOG.warn(_('no method for message: %s') % message_data) - msg_reply(msg_id, _('No method for message: %s') % message_data) + if msg_id: + msg_reply(msg_id, + _('No method for message: %s') % message_data) return + self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args) + + @exception.wrap_exception + def _process_data(self, msg_id, ctxt, method, args): + """Thread that maigcally looks for a method on the proxy + object and calls it. + """ node_func = getattr(self.proxy, str(method)) node_args = dict((str(k), v) for k, v in args.iteritems()) @@ -197,7 +230,18 @@ class AdapterConsumer(Consumer): try: rval = node_func(context=ctxt, **node_args) if msg_id: - msg_reply(msg_id, rval, None) + # Check if the result was a generator + if isinstance(rval, types.GeneratorType): + for x in rval: + msg_reply(msg_id, x, None) + else: + msg_reply(msg_id, rval, None) + + # This final None tells multicall that it is done. + msg_reply(msg_id, None, None) + elif isinstance(rval, types.GeneratorType): + # NOTE(vish): this iterates through the generator + list(rval) except Exception as e: logging.exception('Exception during message handling') if msg_id: @@ -205,11 +249,6 @@ class AdapterConsumer(Consumer): return -class Publisher(messaging.Publisher): - """Publisher base class.""" - pass - - class TopicAdapterConsumer(AdapterConsumer): """Consumes messages on a specific topic.""" @@ -242,6 +281,58 @@ class FanoutAdapterConsumer(AdapterConsumer): topic=topic, proxy=proxy) +class ConsumerSet(object): + """Groups consumers to listen on together on a single connection.""" + + def __init__(self, connection, consumer_list): + self.consumer_list = set(consumer_list) + self.consumer_set = None + self.enabled = True + self.init(connection) + + def init(self, conn): + if not conn: + conn = Connection.instance(new=True) + if self.consumer_set: + self.consumer_set.close() + self.consumer_set = messaging.ConsumerSet(conn) + for consumer in self.consumer_list: + consumer.connection = conn + # consumer.backend is set for us + self.consumer_set.add_consumer(consumer) + + def reconnect(self): + self.init(None) + + def wait(self, limit=None): + running = True + while running: + it = self.consumer_set.iterconsume(limit=limit) + if not it: + break + while True: + try: + it.next() + except StopIteration: + return + except greenlet.GreenletExit: + running = False + break + except Exception as e: + LOG.exception(_("Exception while processing consumer")) + self.reconnect() + # Break to outer loop + break + + def close(self): + self.consumer_set.close() + + +class Publisher(messaging.Publisher): + """Publisher base class.""" + pass + + class TopicPublisher(Publisher): """Publishes messages on a specific topic.""" @@ -306,16 +397,18 @@ def msg_reply(msg_id, reply=None, failure=None): LOG.error(_("Returning exception %s to caller"), message) LOG.error(tb) failure = (failure[0].__name__, str(failure[1]), tb) - conn = Connection.instance() - publisher = DirectPublisher(connection=conn, msg_id=msg_id) - try: - publisher.send({'result': reply, 'failure': failure}) - except TypeError: - publisher.send( - {'result': dict((k, repr(v)) - for k, v in reply.__dict__.iteritems()), - 'failure': failure}) - publisher.close() + + with ConnectionPool.item() as conn: + publisher = DirectPublisher(connection=conn, msg_id=msg_id) + try: + publisher.send({'result': reply, 'failure': failure}) + except TypeError: + publisher.send( + {'result': dict((k, repr(v)) + for k, v in reply.__dict__.iteritems()), + 'failure': failure}) + + publisher.close() class RemoteError(exception.Error): @@ -347,8 +440,9 @@ def _unpack_context(msg): if key.startswith('_context_'): value = msg.pop(key) context_dict[key[9:]] = value + context_dict['msg_id'] = msg.pop('_msg_id', None) LOG.debug(_('unpacked context: %s'), context_dict) - return context.RequestContext.from_dict(context_dict) + return RpcContext.from_dict(context_dict) def _pack_context(msg, context): @@ -360,70 +454,112 @@ def _pack_context(msg, context): for args at some point. """ - context = dict([('_context_%s' % key, value) - for (key, value) in context.to_dict().iteritems()]) - msg.update(context) + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.to_dict().iteritems()]) + msg.update(context_d) -def call(context, topic, msg): - """Sends a message on a topic and wait for a response.""" +class RpcContext(context.RequestContext): + def __init__(self, *args, **kwargs): + msg_id = kwargs.pop('msg_id', None) + self.msg_id = msg_id + super(RpcContext, self).__init__(*args, **kwargs) + + def reply(self, *args, **kwargs): + msg_reply(self.msg_id, *args, **kwargs) + + +def multicall(context, topic, msg): + """Make a call that returns multiple times.""" LOG.debug(_('Making asynchronous call on %s ...'), topic) msg_id = uuid.uuid4().hex msg.update({'_msg_id': msg_id}) LOG.debug(_('MSG_ID is %s') % (msg_id)) _pack_context(msg, context) - class WaitMessage(object): - def __call__(self, data, message): - """Acks message and sets result.""" - message.ack() - if data['failure']: - self.result = RemoteError(*data['failure']) - else: - self.result = data['result'] - - wait_msg = WaitMessage() - conn = Connection.instance() - consumer = DirectConsumer(connection=conn, msg_id=msg_id) + con_conn = ConnectionPool.get() + consumer = DirectConsumer(connection=con_conn, msg_id=msg_id) + wait_msg = MulticallWaiter(consumer) consumer.register_callback(wait_msg) - conn = Connection.instance() - publisher = TopicPublisher(connection=conn, topic=topic) + publisher = TopicPublisher(connection=con_conn, topic=topic) publisher.send(msg) publisher.close() - try: - consumer.wait(limit=1) - except StopIteration: - pass - consumer.close() - # NOTE(termie): this is a little bit of a change from the original - # non-eventlet code where returning a Failure - # instance from a deferred call is very similar to - # raising an exception - if isinstance(wait_msg.result, Exception): - raise wait_msg.result - return wait_msg.result + return wait_msg + + +class MulticallWaiter(object): + def __init__(self, consumer): + self._consumer = consumer + self._results = queue.Queue() + self._closed = False + + def close(self): + self._closed = True + self._consumer.close() + ConnectionPool.put(self._consumer.connection) + + def __call__(self, data, message): + """Acks message and sets result.""" + message.ack() + if data['failure']: + self._results.put(RemoteError(*data['failure'])) + else: + self._results.put(data['result']) + + def __iter__(self): + return self.wait() + + def wait(self): + while True: + rv = None + while rv is None and not self._closed: + try: + rv = self._consumer.fetch(enable_callbacks=True) + except Exception: + self.close() + raise + time.sleep(0.01) + + result = self._results.get() + if isinstance(result, Exception): + self.close() + raise result + if result == None: + self.close() + raise StopIteration + yield result + + +def call(context, topic, msg): + """Sends a message on a topic and wait for a response.""" + rv = multicall(context, topic, msg) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] def cast(context, topic, msg): """Sends a message on a topic without waiting for a response.""" LOG.debug(_('Making asynchronous cast on %s...'), topic) _pack_context(msg, context) - conn = Connection.instance() - publisher = TopicPublisher(connection=conn, topic=topic) - publisher.send(msg) - publisher.close() + with ConnectionPool.item() as conn: + publisher = TopicPublisher(connection=conn, topic=topic) + publisher.send(msg) + publisher.close() def fanout_cast(context, topic, msg): """Sends a message on a fanout exchange without waiting for a response.""" LOG.debug(_('Making asynchronous fanout cast...')) _pack_context(msg, context) - conn = Connection.instance() - publisher = FanoutPublisher(topic, connection=conn) - publisher.send(msg) - publisher.close() + with ConnectionPool.item() as conn: + publisher = FanoutPublisher(topic, connection=conn) + publisher.send(msg) + publisher.close() def generic_response(message_data, message): @@ -459,6 +595,7 @@ def send_message(topic, message, wait=True): if wait: consumer.wait() + consumer.close() if __name__ == '__main__': diff --git a/nova/tests/test_cloud.py b/nova/tests/test_cloud.py index 54c0454d..b64be662 100644 --- a/nova/tests/test_cloud.py +++ b/nova/tests/test_cloud.py @@ -17,13 +17,9 @@ # under the License. from base64 import b64decode -import json from M2Crypto import BIO from M2Crypto import RSA import os -import shutil -import tempfile -import time from eventlet import greenthread @@ -33,12 +29,10 @@ from nova import db from nova import flags from nova import log as logging from nova import rpc -from nova import service from nova import test from nova import utils from nova import exception from nova.auth import manager -from nova.compute import power_state from nova.api.ec2 import cloud from nova.api.ec2 import ec2utils from nova.image import local @@ -79,14 +73,21 @@ class CloudTestCase(test.TestCase): self.stubs.Set(local.LocalImageService, 'show', fake_show) self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show) + # NOTE(vish): set up a manual wait so rpc.cast has a chance to finish + rpc_cast = rpc.cast + + def finish_cast(*args, **kwargs): + rpc_cast(*args, **kwargs) + greenthread.sleep(0.2) + + self.stubs.Set(rpc, 'cast', finish_cast) + def tearDown(self): network_ref = db.project_get_network(self.context, self.project.id) db.network_disassociate(self.context, network_ref['id']) self.manager.delete_project(self.project) self.manager.delete_user(self.user) - self.compute.kill() - self.network.kill() super(CloudTestCase, self).tearDown() def _create_key(self, name): @@ -113,7 +114,6 @@ class CloudTestCase(test.TestCase): self.cloud.describe_addresses(self.context) self.cloud.release_address(self.context, public_ip=address) - greenthread.sleep(0.3) db.floating_ip_destroy(self.context, address) def test_associate_disassociate_address(self): @@ -129,12 +129,10 @@ class CloudTestCase(test.TestCase): self.cloud.associate_address(self.context, instance_id=ec2_id, public_ip=address) - greenthread.sleep(0.3) self.cloud.disassociate_address(self.context, public_ip=address) self.cloud.release_address(self.context, public_ip=address) - greenthread.sleep(0.3) self.network.deallocate_fixed_ip(self.context, fixed) db.instance_destroy(self.context, inst['id']) db.floating_ip_destroy(self.context, address) @@ -306,31 +304,25 @@ class CloudTestCase(test.TestCase): 'instance_type': instance_type, 'max_count': max_count} rv = self.cloud.run_instances(self.context, **kwargs) - greenthread.sleep(0.3) instance_id = rv['instancesSet'][0]['instanceId'] output = self.cloud.get_console_output(context=self.context, instance_id=[instance_id]) self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT') # TODO(soren): We need this until we can stop polling in the rpc code # for unit tests. - greenthread.sleep(0.3) rv = self.cloud.terminate_instances(self.context, [instance_id]) - greenthread.sleep(0.3) def test_ajax_console(self): kwargs = {'image_id': 'ami-1'} rv = self.cloud.run_instances(self.context, **kwargs) instance_id = rv['instancesSet'][0]['instanceId'] - greenthread.sleep(0.3) output = self.cloud.get_ajax_console(context=self.context, instance_id=[instance_id]) self.assertEquals(output['url'], '%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url) # TODO(soren): We need this until we can stop polling in the rpc code # for unit tests. - greenthread.sleep(0.3) rv = self.cloud.terminate_instances(self.context, [instance_id]) - greenthread.sleep(0.3) def test_key_generation(self): result = self._create_key('test') diff --git a/nova/tests/test_rpc.py b/nova/tests/test_rpc.py index 44d7c91e..ffd748ef 100644 --- a/nova/tests/test_rpc.py +++ b/nova/tests/test_rpc.py @@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc') class RpcTestCase(test.TestCase): - """Test cases for rpc""" def setUp(self): super(RpcTestCase, self).setUp() self.conn = rpc.Connection.instance(True) @@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase): self.context = context.get_admin_context() def test_call_succeed(self): - """Get a value through rpc call""" value = 42 result = rpc.call(self.context, 'test', {"method": "echo", "args": {"value": value}}) self.assertEqual(value, result) + def test_call_succeed_despite_multiple_returns(self): + value = 42 + result = rpc.call(self.context, 'test', {"method": "echo_three_times", + "args": {"value": value}}) + self.assertEqual(value + 2, result) + + def test_call_succeed_despite_multiple_returns_yield(self): + value = 42 + result = rpc.call(self.context, 'test', + {"method": "echo_three_times_yield", + "args": {"value": value}}) + self.assertEqual(value + 2, result) + + def test_multicall_succeed_once(self): + value = 42 + result = rpc.multicall(self.context, + 'test', + {"method": "echo", + "args": {"value": value}}) + for i, x in enumerate(result): + if i > 0: + self.fail('should only receive one response') + self.assertEqual(value + i, x) + + def test_multicall_succeed_three_times(self): + value = 42 + result = rpc.multicall(self.context, + 'test', + {"method": "echo_three_times", + "args": {"value": value}}) + for i, x in enumerate(result): + self.assertEqual(value + i, x) + + def test_multicall_succeed_three_times_yield(self): + value = 42 + result = rpc.multicall(self.context, + 'test', + {"method": "echo_three_times_yield", + "args": {"value": value}}) + for i, x in enumerate(result): + self.assertEqual(value + i, x) + def test_context_passed(self): - """Makes sure a context is passed through rpc call""" + """Makes sure a context is passed through rpc call.""" value = 42 result = rpc.call(self.context, 'test', {"method": "context", @@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase): self.assertEqual(self.context.to_dict(), result) def test_call_exception(self): - """Test that exception gets passed back properly + """Test that exception gets passed back properly. rpc.call returns a RemoteError object. The value of the exception is converted to a string, so we convert it back to an int in the test. + """ value = 42 self.assertRaises(rpc.RemoteError, @@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase): self.assertEqual(int(exc.value), value) def test_nested_calls(self): - """Test that we can do an rpc.call inside another call""" + """Test that we can do an rpc.call inside another call.""" class Nested(object): @staticmethod def echo(context, queue, value): @@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase): "value": value}}) self.assertEqual(value, result) + def test_connectionpool_single(self): + """Test that ConnectionPool recycles a single connection.""" + conn1 = rpc.ConnectionPool.get() + rpc.ConnectionPool.put(conn1) + conn2 = rpc.ConnectionPool.get() + rpc.ConnectionPool.put(conn2) + self.assertEqual(conn1, conn2) + + def test_connectionpool_double(self): + """Test that ConnectionPool returns and reuses separate connections. + + When called consecutively we should get separate connections and upon + returning them those connections should be reused for future calls + before generating a new connection. + + """ + conn1 = rpc.ConnectionPool.get() + conn2 = rpc.ConnectionPool.get() + + self.assertNotEqual(conn1, conn2) + rpc.ConnectionPool.put(conn1) + rpc.ConnectionPool.put(conn2) + + conn3 = rpc.ConnectionPool.get() + conn4 = rpc.ConnectionPool.get() + self.assertEqual(conn1, conn3) + self.assertEqual(conn2, conn4) + + def test_connectionpool_limit(self): + """Test connection pool limit and connection uniqueness.""" + max_size = FLAGS.rpc_conn_pool_size + conns = [] + + for i in xrange(max_size): + conns.append(rpc.ConnectionPool.get()) + + self.assertFalse(rpc.ConnectionPool.free_items) + self.assertEqual(rpc.ConnectionPool.current_size, + rpc.ConnectionPool.max_size) + self.assertEqual(len(set(conns)), max_size) + class TestReceiver(object): - """Simple Proxy class so the consumer has methods to call + """Simple Proxy class so the consumer has methods to call. - Uses static methods because we aren't actually storing any state""" + Uses static methods because we aren't actually storing any state. + + """ @staticmethod def echo(context, value): - """Simply returns whatever value is sent in""" + """Simply returns whatever value is sent in.""" LOG.debug(_("Received %s"), value) return value @staticmethod def context(context, value): - """Returns dictionary version of context""" + """Returns dictionary version of context.""" LOG.debug(_("Received %s"), context) return context.to_dict() + @staticmethod + def echo_three_times(context, value): + context.reply(value) + context.reply(value + 1) + context.reply(value + 2) + + @staticmethod + def echo_three_times_yield(context, value): + yield value + yield value + 1 + yield value + 2 + @staticmethod def fail(context, value): - """Raises an exception with the value sent in""" + """Raises an exception with the value sent in.""" raise Exception(value) diff --git a/nova/tests/test_xenapi.py b/nova/tests/test_xenapi.py index be1e3569..18a26789 100644 --- a/nova/tests/test_xenapi.py +++ b/nova/tests/test_xenapi.py @@ -395,6 +395,29 @@ class XenAPIVMTestCase(test.TestCase): os_type="linux") self.check_vm_params_for_linux() + def test_spawn_vhd_glance_swapdisk(self): + # Change the default host_call_plugin to one that'll return + # a swap disk + orig_func = stubs.FakeSessionForVMTests.host_call_plugin + + stubs.FakeSessionForVMTests.host_call_plugin = \ + stubs.FakeSessionForVMTests.host_call_plugin_swap + + try: + # We'll steal the above glance linux test + self.test_spawn_vhd_glance_linux() + finally: + # Make sure to put this back + stubs.FakeSessionForVMTests.host_call_plugin = orig_func + + # We should have 2 VBDs. + self.assertEqual(len(self.vm['VBDs']), 2) + # Now test that we have 1. + self.tearDown() + self.setUp() + self.test_spawn_vhd_glance_linux() + self.assertEqual(len(self.vm['VBDs']), 1) + def test_spawn_vhd_glance_windows(self): FLAGS.xenapi_image_service = 'glance' self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None, diff --git a/nova/tests/xenapi/stubs.py b/nova/tests/xenapi/stubs.py index 4833ccb0..35308d95 100644 --- a/nova/tests/xenapi/stubs.py +++ b/nova/tests/xenapi/stubs.py @@ -17,6 +17,7 @@ """Stubouts, mocks and fixtures for the test suite""" import eventlet +import json from nova.virt import xenapi_conn from nova.virt.xenapi import fake from nova.virt.xenapi import volume_utils @@ -37,7 +38,7 @@ def stubout_instance_snapshot(stubs): sr_ref=sr_ref, sharable=False) vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref) vdi_uuid = vdi_rec['uuid'] - return vdi_uuid + return [dict(vdi_type='os', vdi_uuid=vdi_uuid)] stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image) @@ -132,11 +133,30 @@ class FakeSessionForVMTests(fake.SessionBase): def __init__(self, uri): super(FakeSessionForVMTests, self).__init__(uri) - def host_call_plugin(self, _1, _2, _3, _4, _5): + def host_call_plugin(self, _1, _2, plugin, method, _5): sr_ref = fake.get_all('SR')[0] vdi_ref = fake.create_vdi('', False, sr_ref, False) vdi_rec = fake.get_record('VDI', vdi_ref) - return '%s' % vdi_rec['uuid'] + if plugin == "glance" and method == "download_vhd": + ret_str = json.dumps([dict(vdi_type='os', + vdi_uuid=vdi_rec['uuid'])]) + else: + ret_str = vdi_rec['uuid'] + return '%s' % ret_str + + def host_call_plugin_swap(self, _1, _2, plugin, method, _5): + sr_ref = fake.get_all('SR')[0] + vdi_ref = fake.create_vdi('', False, sr_ref, False) + vdi_rec = fake.get_record('VDI', vdi_ref) + if plugin == "glance" and method == "download_vhd": + swap_vdi_ref = fake.create_vdi('', False, sr_ref, False) + swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref) + ret_str = json.dumps( + [dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']), + dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])]) + else: + ret_str = vdi_rec['uuid'] + return '%s' % ret_str def VM_start(self, _1, ref, _2, _3): vm = fake.get_record('VM', ref)