trunk merge

This commit is contained in:
Sandy Walsh
2011-05-27 07:31:29 -07:00
6 changed files with 388 additions and 105 deletions

View File

@@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit")
EXCHANGES = {} EXCHANGES = {}
QUEUES = {} QUEUES = {}
CONSUMERS = {}
class Message(base.BaseMessage): class Message(base.BaseMessage):
@@ -96,17 +97,29 @@ class Backend(base.BaseBackend):
' key %(routing_key)s') % locals()) ' key %(routing_key)s') % locals())
EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key) EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key)
def declare_consumer(self, queue, callback, *args, **kwargs): def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs):
self.current_queue = queue global CONSUMERS
self.current_callback = callback LOG.debug("Adding consumer %s", consumer_tag)
CONSUMERS[consumer_tag] = (queue, callback)
def cancel(self, consumer_tag):
global CONSUMERS
LOG.debug("Removing consumer %s", consumer_tag)
del CONSUMERS[consumer_tag]
def consume(self, limit=None): def consume(self, limit=None):
global CONSUMERS
num = 0
while True: while True:
item = self.get(self.current_queue) for (queue, callback) in CONSUMERS.itervalues():
if item: item = self.get(queue)
self.current_callback(item) if item:
raise StopIteration() callback(item)
greenthread.sleep(0) num += 1
yield
if limit and num == limit:
raise StopIteration()
greenthread.sleep(0.1)
def get(self, queue, no_ack=False): def get(self, queue, no_ack=False):
global QUEUES global QUEUES
@@ -134,5 +147,7 @@ class Backend(base.BaseBackend):
def reset_all(): def reset_all():
global EXCHANGES global EXCHANGES
global QUEUES global QUEUES
global CONSUMERS
EXCHANGES = {} EXCHANGES = {}
QUEUES = {} QUEUES = {}
CONSUMERS = {}

View File

@@ -28,12 +28,15 @@ import json
import sys import sys
import time import time
import traceback import traceback
import types
import uuid import uuid
from carrot import connection as carrot_connection from carrot import connection as carrot_connection
from carrot import messaging from carrot import messaging
from eventlet import greenpool from eventlet import greenpool
from eventlet import greenthread from eventlet import pools
from eventlet import queue
import greenlet
from nova import context from nova import context
from nova import exception from nova import exception
@@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool') flags.DEFINE_integer('rpc_thread_pool_size', 1024,
'Size of RPC thread pool')
flags.DEFINE_integer('rpc_conn_pool_size', 30,
'Size of RPC connection pool')
class Connection(carrot_connection.BrokerConnection): class Connection(carrot_connection.BrokerConnection):
@@ -90,6 +96,22 @@ class Connection(carrot_connection.BrokerConnection):
return cls.instance() return cls.instance()
class Pool(pools.Pool):
"""Class that implements a Pool of Connections."""
# TODO(comstud): Timeout connections not used in a while
def create(self):
LOG.debug('Creating new connection')
return Connection.instance(new=True)
# Create a ConnectionPool to use for RPC calls. We'll order the
# pool as a stack (LIFO), so that we can potentially loop through and
# timeout old unused connections at some point
ConnectionPool = Pool(
max_size=FLAGS.rpc_conn_pool_size,
order_as_stack=True)
class Consumer(messaging.Consumer): class Consumer(messaging.Consumer):
"""Consumer base class. """Consumer base class.
@@ -131,7 +153,9 @@ class Consumer(messaging.Consumer):
self.connection = Connection.recreate() self.connection = Connection.recreate()
self.backend = self.connection.create_backend() self.backend = self.connection.create_backend()
self.declare() self.declare()
super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks) return super(Consumer, self).fetch(no_ack,
auto_ack,
enable_callbacks)
if self.failed_connection: if self.failed_connection:
LOG.error(_('Reconnected to queue')) LOG.error(_('Reconnected to queue'))
self.failed_connection = False self.failed_connection = False
@@ -159,13 +183,13 @@ class AdapterConsumer(Consumer):
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
super(AdapterConsumer, self).__init__(connection=connection, super(AdapterConsumer, self).__init__(connection=connection,
topic=topic) topic=topic)
self.register_callback(self.process_data)
def receive(self, *args, **kwargs): def process_data(self, message_data, message):
self.pool.spawn_n(self._receive, *args, **kwargs) """Consumer callback to call a method on a proxy object.
@exception.wrap_exception Parses the message for validity and fires off a thread to call the
def _receive(self, message_data, message): proxy object method.
"""Magically looks for a method on the proxy object and calls it.
Message data should be a dictionary with two keys: Message data should be a dictionary with two keys:
method: string representing the method to call method: string representing the method to call
@@ -175,8 +199,8 @@ class AdapterConsumer(Consumer):
""" """
LOG.debug(_('received %s') % message_data) LOG.debug(_('received %s') % message_data)
msg_id = message_data.pop('_msg_id', None) # This will be popped off in _unpack_context
msg_id = message_data.get('_msg_id', None)
ctxt = _unpack_context(message_data) ctxt = _unpack_context(message_data)
method = message_data.get('method') method = message_data.get('method')
@@ -188,8 +212,17 @@ class AdapterConsumer(Consumer):
# we just log the message and send an error string # we just log the message and send an error string
# back to the caller # back to the caller
LOG.warn(_('no method for message: %s') % message_data) LOG.warn(_('no method for message: %s') % message_data)
msg_reply(msg_id, _('No method for message: %s') % message_data) if msg_id:
msg_reply(msg_id,
_('No method for message: %s') % message_data)
return return
self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args)
@exception.wrap_exception
def _process_data(self, msg_id, ctxt, method, args):
"""Thread that maigcally looks for a method on the proxy
object and calls it.
"""
node_func = getattr(self.proxy, str(method)) node_func = getattr(self.proxy, str(method))
node_args = dict((str(k), v) for k, v in args.iteritems()) node_args = dict((str(k), v) for k, v in args.iteritems())
@@ -197,7 +230,18 @@ class AdapterConsumer(Consumer):
try: try:
rval = node_func(context=ctxt, **node_args) rval = node_func(context=ctxt, **node_args)
if msg_id: if msg_id:
msg_reply(msg_id, rval, None) # Check if the result was a generator
if isinstance(rval, types.GeneratorType):
for x in rval:
msg_reply(msg_id, x, None)
else:
msg_reply(msg_id, rval, None)
# This final None tells multicall that it is done.
msg_reply(msg_id, None, None)
elif isinstance(rval, types.GeneratorType):
# NOTE(vish): this iterates through the generator
list(rval)
except Exception as e: except Exception as e:
logging.exception('Exception during message handling') logging.exception('Exception during message handling')
if msg_id: if msg_id:
@@ -205,11 +249,6 @@ class AdapterConsumer(Consumer):
return return
class Publisher(messaging.Publisher):
"""Publisher base class."""
pass
class TopicAdapterConsumer(AdapterConsumer): class TopicAdapterConsumer(AdapterConsumer):
"""Consumes messages on a specific topic.""" """Consumes messages on a specific topic."""
@@ -242,6 +281,58 @@ class FanoutAdapterConsumer(AdapterConsumer):
topic=topic, proxy=proxy) topic=topic, proxy=proxy)
class ConsumerSet(object):
"""Groups consumers to listen on together on a single connection."""
def __init__(self, connection, consumer_list):
self.consumer_list = set(consumer_list)
self.consumer_set = None
self.enabled = True
self.init(connection)
def init(self, conn):
if not conn:
conn = Connection.instance(new=True)
if self.consumer_set:
self.consumer_set.close()
self.consumer_set = messaging.ConsumerSet(conn)
for consumer in self.consumer_list:
consumer.connection = conn
# consumer.backend is set for us
self.consumer_set.add_consumer(consumer)
def reconnect(self):
self.init(None)
def wait(self, limit=None):
running = True
while running:
it = self.consumer_set.iterconsume(limit=limit)
if not it:
break
while True:
try:
it.next()
except StopIteration:
return
except greenlet.GreenletExit:
running = False
break
except Exception as e:
LOG.exception(_("Exception while processing consumer"))
self.reconnect()
# Break to outer loop
break
def close(self):
self.consumer_set.close()
class Publisher(messaging.Publisher):
"""Publisher base class."""
pass
class TopicPublisher(Publisher): class TopicPublisher(Publisher):
"""Publishes messages on a specific topic.""" """Publishes messages on a specific topic."""
@@ -306,16 +397,18 @@ def msg_reply(msg_id, reply=None, failure=None):
LOG.error(_("Returning exception %s to caller"), message) LOG.error(_("Returning exception %s to caller"), message)
LOG.error(tb) LOG.error(tb)
failure = (failure[0].__name__, str(failure[1]), tb) failure = (failure[0].__name__, str(failure[1]), tb)
conn = Connection.instance()
publisher = DirectPublisher(connection=conn, msg_id=msg_id) with ConnectionPool.item() as conn:
try: publisher = DirectPublisher(connection=conn, msg_id=msg_id)
publisher.send({'result': reply, 'failure': failure}) try:
except TypeError: publisher.send({'result': reply, 'failure': failure})
publisher.send( except TypeError:
{'result': dict((k, repr(v)) publisher.send(
for k, v in reply.__dict__.iteritems()), {'result': dict((k, repr(v))
'failure': failure}) for k, v in reply.__dict__.iteritems()),
publisher.close() 'failure': failure})
publisher.close()
class RemoteError(exception.Error): class RemoteError(exception.Error):
@@ -347,8 +440,9 @@ def _unpack_context(msg):
if key.startswith('_context_'): if key.startswith('_context_'):
value = msg.pop(key) value = msg.pop(key)
context_dict[key[9:]] = value context_dict[key[9:]] = value
context_dict['msg_id'] = msg.pop('_msg_id', None)
LOG.debug(_('unpacked context: %s'), context_dict) LOG.debug(_('unpacked context: %s'), context_dict)
return context.RequestContext.from_dict(context_dict) return RpcContext.from_dict(context_dict)
def _pack_context(msg, context): def _pack_context(msg, context):
@@ -360,70 +454,112 @@ def _pack_context(msg, context):
for args at some point. for args at some point.
""" """
context = dict([('_context_%s' % key, value) context_d = dict([('_context_%s' % key, value)
for (key, value) in context.to_dict().iteritems()]) for (key, value) in context.to_dict().iteritems()])
msg.update(context) msg.update(context_d)
def call(context, topic, msg): class RpcContext(context.RequestContext):
"""Sends a message on a topic and wait for a response.""" def __init__(self, *args, **kwargs):
msg_id = kwargs.pop('msg_id', None)
self.msg_id = msg_id
super(RpcContext, self).__init__(*args, **kwargs)
def reply(self, *args, **kwargs):
msg_reply(self.msg_id, *args, **kwargs)
def multicall(context, topic, msg):
"""Make a call that returns multiple times."""
LOG.debug(_('Making asynchronous call on %s ...'), topic) LOG.debug(_('Making asynchronous call on %s ...'), topic)
msg_id = uuid.uuid4().hex msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id}) msg.update({'_msg_id': msg_id})
LOG.debug(_('MSG_ID is %s') % (msg_id)) LOG.debug(_('MSG_ID is %s') % (msg_id))
_pack_context(msg, context) _pack_context(msg, context)
class WaitMessage(object): con_conn = ConnectionPool.get()
def __call__(self, data, message): consumer = DirectConsumer(connection=con_conn, msg_id=msg_id)
"""Acks message and sets result.""" wait_msg = MulticallWaiter(consumer)
message.ack()
if data['failure']:
self.result = RemoteError(*data['failure'])
else:
self.result = data['result']
wait_msg = WaitMessage()
conn = Connection.instance()
consumer = DirectConsumer(connection=conn, msg_id=msg_id)
consumer.register_callback(wait_msg) consumer.register_callback(wait_msg)
conn = Connection.instance() publisher = TopicPublisher(connection=con_conn, topic=topic)
publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg) publisher.send(msg)
publisher.close() publisher.close()
try: return wait_msg
consumer.wait(limit=1)
except StopIteration:
pass class MulticallWaiter(object):
consumer.close() def __init__(self, consumer):
# NOTE(termie): this is a little bit of a change from the original self._consumer = consumer
# non-eventlet code where returning a Failure self._results = queue.Queue()
# instance from a deferred call is very similar to self._closed = False
# raising an exception
if isinstance(wait_msg.result, Exception): def close(self):
raise wait_msg.result self._closed = True
return wait_msg.result self._consumer.close()
ConnectionPool.put(self._consumer.connection)
def __call__(self, data, message):
"""Acks message and sets result."""
message.ack()
if data['failure']:
self._results.put(RemoteError(*data['failure']))
else:
self._results.put(data['result'])
def __iter__(self):
return self.wait()
def wait(self):
while True:
rv = None
while rv is None and not self._closed:
try:
rv = self._consumer.fetch(enable_callbacks=True)
except Exception:
self.close()
raise
time.sleep(0.01)
result = self._results.get()
if isinstance(result, Exception):
self.close()
raise result
if result == None:
self.close()
raise StopIteration
yield result
def call(context, topic, msg):
"""Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg)
# NOTE(vish): return the last result from the multicall
rv = list(rv)
if not rv:
return
return rv[-1]
def cast(context, topic, msg): def cast(context, topic, msg):
"""Sends a message on a topic without waiting for a response.""" """Sends a message on a topic without waiting for a response."""
LOG.debug(_('Making asynchronous cast on %s...'), topic) LOG.debug(_('Making asynchronous cast on %s...'), topic)
_pack_context(msg, context) _pack_context(msg, context)
conn = Connection.instance() with ConnectionPool.item() as conn:
publisher = TopicPublisher(connection=conn, topic=topic) publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg) publisher.send(msg)
publisher.close() publisher.close()
def fanout_cast(context, topic, msg): def fanout_cast(context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response.""" """Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_('Making asynchronous fanout cast...')) LOG.debug(_('Making asynchronous fanout cast...'))
_pack_context(msg, context) _pack_context(msg, context)
conn = Connection.instance() with ConnectionPool.item() as conn:
publisher = FanoutPublisher(topic, connection=conn) publisher = FanoutPublisher(topic, connection=conn)
publisher.send(msg) publisher.send(msg)
publisher.close() publisher.close()
def generic_response(message_data, message): def generic_response(message_data, message):
@@ -459,6 +595,7 @@ def send_message(topic, message, wait=True):
if wait: if wait:
consumer.wait() consumer.wait()
consumer.close()
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -17,13 +17,9 @@
# under the License. # under the License.
from base64 import b64decode from base64 import b64decode
import json
from M2Crypto import BIO from M2Crypto import BIO
from M2Crypto import RSA from M2Crypto import RSA
import os import os
import shutil
import tempfile
import time
from eventlet import greenthread from eventlet import greenthread
@@ -33,12 +29,10 @@ from nova import db
from nova import flags from nova import flags
from nova import log as logging from nova import log as logging
from nova import rpc from nova import rpc
from nova import service
from nova import test from nova import test
from nova import utils from nova import utils
from nova import exception from nova import exception
from nova.auth import manager from nova.auth import manager
from nova.compute import power_state
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils from nova.api.ec2 import ec2utils
from nova.image import local from nova.image import local
@@ -79,14 +73,21 @@ class CloudTestCase(test.TestCase):
self.stubs.Set(local.LocalImageService, 'show', fake_show) self.stubs.Set(local.LocalImageService, 'show', fake_show)
self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show) self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show)
# NOTE(vish): set up a manual wait so rpc.cast has a chance to finish
rpc_cast = rpc.cast
def finish_cast(*args, **kwargs):
rpc_cast(*args, **kwargs)
greenthread.sleep(0.2)
self.stubs.Set(rpc, 'cast', finish_cast)
def tearDown(self): def tearDown(self):
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context,
self.project.id) self.project.id)
db.network_disassociate(self.context, network_ref['id']) db.network_disassociate(self.context, network_ref['id'])
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
self.compute.kill()
self.network.kill()
super(CloudTestCase, self).tearDown() super(CloudTestCase, self).tearDown()
def _create_key(self, name): def _create_key(self, name):
@@ -113,7 +114,6 @@ class CloudTestCase(test.TestCase):
self.cloud.describe_addresses(self.context) self.cloud.describe_addresses(self.context)
self.cloud.release_address(self.context, self.cloud.release_address(self.context,
public_ip=address) public_ip=address)
greenthread.sleep(0.3)
db.floating_ip_destroy(self.context, address) db.floating_ip_destroy(self.context, address)
def test_associate_disassociate_address(self): def test_associate_disassociate_address(self):
@@ -129,12 +129,10 @@ class CloudTestCase(test.TestCase):
self.cloud.associate_address(self.context, self.cloud.associate_address(self.context,
instance_id=ec2_id, instance_id=ec2_id,
public_ip=address) public_ip=address)
greenthread.sleep(0.3)
self.cloud.disassociate_address(self.context, self.cloud.disassociate_address(self.context,
public_ip=address) public_ip=address)
self.cloud.release_address(self.context, self.cloud.release_address(self.context,
public_ip=address) public_ip=address)
greenthread.sleep(0.3)
self.network.deallocate_fixed_ip(self.context, fixed) self.network.deallocate_fixed_ip(self.context, fixed)
db.instance_destroy(self.context, inst['id']) db.instance_destroy(self.context, inst['id'])
db.floating_ip_destroy(self.context, address) db.floating_ip_destroy(self.context, address)
@@ -306,31 +304,25 @@ class CloudTestCase(test.TestCase):
'instance_type': instance_type, 'instance_type': instance_type,
'max_count': max_count} 'max_count': max_count}
rv = self.cloud.run_instances(self.context, **kwargs) rv = self.cloud.run_instances(self.context, **kwargs)
greenthread.sleep(0.3)
instance_id = rv['instancesSet'][0]['instanceId'] instance_id = rv['instancesSet'][0]['instanceId']
output = self.cloud.get_console_output(context=self.context, output = self.cloud.get_console_output(context=self.context,
instance_id=[instance_id]) instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT') self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code # TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests. # for unit tests.
greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id]) rv = self.cloud.terminate_instances(self.context, [instance_id])
greenthread.sleep(0.3)
def test_ajax_console(self): def test_ajax_console(self):
kwargs = {'image_id': 'ami-1'} kwargs = {'image_id': 'ami-1'}
rv = self.cloud.run_instances(self.context, **kwargs) rv = self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId'] instance_id = rv['instancesSet'][0]['instanceId']
greenthread.sleep(0.3)
output = self.cloud.get_ajax_console(context=self.context, output = self.cloud.get_ajax_console(context=self.context,
instance_id=[instance_id]) instance_id=[instance_id])
self.assertEquals(output['url'], self.assertEquals(output['url'],
'%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url) '%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url)
# TODO(soren): We need this until we can stop polling in the rpc code # TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests. # for unit tests.
greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id]) rv = self.cloud.terminate_instances(self.context, [instance_id])
greenthread.sleep(0.3)
def test_key_generation(self): def test_key_generation(self):
result = self._create_key('test') result = self._create_key('test')

View File

@@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc')
class RpcTestCase(test.TestCase): class RpcTestCase(test.TestCase):
"""Test cases for rpc"""
def setUp(self): def setUp(self):
super(RpcTestCase, self).setUp() super(RpcTestCase, self).setUp()
self.conn = rpc.Connection.instance(True) self.conn = rpc.Connection.instance(True)
@@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase):
self.context = context.get_admin_context() self.context = context.get_admin_context()
def test_call_succeed(self): def test_call_succeed(self):
"""Get a value through rpc call"""
value = 42 value = 42
result = rpc.call(self.context, 'test', {"method": "echo", result = rpc.call(self.context, 'test', {"method": "echo",
"args": {"value": value}}) "args": {"value": value}})
self.assertEqual(value, result) self.assertEqual(value, result)
def test_call_succeed_despite_multiple_returns(self):
value = 42
result = rpc.call(self.context, 'test', {"method": "echo_three_times",
"args": {"value": value}})
self.assertEqual(value + 2, result)
def test_call_succeed_despite_multiple_returns_yield(self):
value = 42
result = rpc.call(self.context, 'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
self.assertEqual(value + 2, result)
def test_multicall_succeed_once(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo",
"args": {"value": value}})
for i, x in enumerate(result):
if i > 0:
self.fail('should only receive one response')
self.assertEqual(value + i, x)
def test_multicall_succeed_three_times(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo_three_times",
"args": {"value": value}})
for i, x in enumerate(result):
self.assertEqual(value + i, x)
def test_multicall_succeed_three_times_yield(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
for i, x in enumerate(result):
self.assertEqual(value + i, x)
def test_context_passed(self): def test_context_passed(self):
"""Makes sure a context is passed through rpc call""" """Makes sure a context is passed through rpc call."""
value = 42 value = 42
result = rpc.call(self.context, result = rpc.call(self.context,
'test', {"method": "context", 'test', {"method": "context",
@@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase):
self.assertEqual(self.context.to_dict(), result) self.assertEqual(self.context.to_dict(), result)
def test_call_exception(self): def test_call_exception(self):
"""Test that exception gets passed back properly """Test that exception gets passed back properly.
rpc.call returns a RemoteError object. The value of the rpc.call returns a RemoteError object. The value of the
exception is converted to a string, so we convert it back exception is converted to a string, so we convert it back
to an int in the test. to an int in the test.
""" """
value = 42 value = 42
self.assertRaises(rpc.RemoteError, self.assertRaises(rpc.RemoteError,
@@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase):
self.assertEqual(int(exc.value), value) self.assertEqual(int(exc.value), value)
def test_nested_calls(self): def test_nested_calls(self):
"""Test that we can do an rpc.call inside another call""" """Test that we can do an rpc.call inside another call."""
class Nested(object): class Nested(object):
@staticmethod @staticmethod
def echo(context, queue, value): def echo(context, queue, value):
@@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase):
"value": value}}) "value": value}})
self.assertEqual(value, result) self.assertEqual(value, result)
def test_connectionpool_single(self):
"""Test that ConnectionPool recycles a single connection."""
conn1 = rpc.ConnectionPool.get()
rpc.ConnectionPool.put(conn1)
conn2 = rpc.ConnectionPool.get()
rpc.ConnectionPool.put(conn2)
self.assertEqual(conn1, conn2)
def test_connectionpool_double(self):
"""Test that ConnectionPool returns and reuses separate connections.
When called consecutively we should get separate connections and upon
returning them those connections should be reused for future calls
before generating a new connection.
"""
conn1 = rpc.ConnectionPool.get()
conn2 = rpc.ConnectionPool.get()
self.assertNotEqual(conn1, conn2)
rpc.ConnectionPool.put(conn1)
rpc.ConnectionPool.put(conn2)
conn3 = rpc.ConnectionPool.get()
conn4 = rpc.ConnectionPool.get()
self.assertEqual(conn1, conn3)
self.assertEqual(conn2, conn4)
def test_connectionpool_limit(self):
"""Test connection pool limit and connection uniqueness."""
max_size = FLAGS.rpc_conn_pool_size
conns = []
for i in xrange(max_size):
conns.append(rpc.ConnectionPool.get())
self.assertFalse(rpc.ConnectionPool.free_items)
self.assertEqual(rpc.ConnectionPool.current_size,
rpc.ConnectionPool.max_size)
self.assertEqual(len(set(conns)), max_size)
class TestReceiver(object): class TestReceiver(object):
"""Simple Proxy class so the consumer has methods to call """Simple Proxy class so the consumer has methods to call.
Uses static methods because we aren't actually storing any state""" Uses static methods because we aren't actually storing any state.
"""
@staticmethod @staticmethod
def echo(context, value): def echo(context, value):
"""Simply returns whatever value is sent in""" """Simply returns whatever value is sent in."""
LOG.debug(_("Received %s"), value) LOG.debug(_("Received %s"), value)
return value return value
@staticmethod @staticmethod
def context(context, value): def context(context, value):
"""Returns dictionary version of context""" """Returns dictionary version of context."""
LOG.debug(_("Received %s"), context) LOG.debug(_("Received %s"), context)
return context.to_dict() return context.to_dict()
@staticmethod
def echo_three_times(context, value):
context.reply(value)
context.reply(value + 1)
context.reply(value + 2)
@staticmethod
def echo_three_times_yield(context, value):
yield value
yield value + 1
yield value + 2
@staticmethod @staticmethod
def fail(context, value): def fail(context, value):
"""Raises an exception with the value sent in""" """Raises an exception with the value sent in."""
raise Exception(value) raise Exception(value)

View File

@@ -395,6 +395,29 @@ class XenAPIVMTestCase(test.TestCase):
os_type="linux") os_type="linux")
self.check_vm_params_for_linux() self.check_vm_params_for_linux()
def test_spawn_vhd_glance_swapdisk(self):
# Change the default host_call_plugin to one that'll return
# a swap disk
orig_func = stubs.FakeSessionForVMTests.host_call_plugin
stubs.FakeSessionForVMTests.host_call_plugin = \
stubs.FakeSessionForVMTests.host_call_plugin_swap
try:
# We'll steal the above glance linux test
self.test_spawn_vhd_glance_linux()
finally:
# Make sure to put this back
stubs.FakeSessionForVMTests.host_call_plugin = orig_func
# We should have 2 VBDs.
self.assertEqual(len(self.vm['VBDs']), 2)
# Now test that we have 1.
self.tearDown()
self.setUp()
self.test_spawn_vhd_glance_linux()
self.assertEqual(len(self.vm['VBDs']), 1)
def test_spawn_vhd_glance_windows(self): def test_spawn_vhd_glance_windows(self):
FLAGS.xenapi_image_service = 'glance' FLAGS.xenapi_image_service = 'glance'
self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None, self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None,

View File

@@ -17,6 +17,7 @@
"""Stubouts, mocks and fixtures for the test suite""" """Stubouts, mocks and fixtures for the test suite"""
import eventlet import eventlet
import json
from nova.virt import xenapi_conn from nova.virt import xenapi_conn
from nova.virt.xenapi import fake from nova.virt.xenapi import fake
from nova.virt.xenapi import volume_utils from nova.virt.xenapi import volume_utils
@@ -37,7 +38,7 @@ def stubout_instance_snapshot(stubs):
sr_ref=sr_ref, sharable=False) sr_ref=sr_ref, sharable=False)
vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref) vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref)
vdi_uuid = vdi_rec['uuid'] vdi_uuid = vdi_rec['uuid']
return vdi_uuid return [dict(vdi_type='os', vdi_uuid=vdi_uuid)]
stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image) stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
@@ -132,11 +133,30 @@ class FakeSessionForVMTests(fake.SessionBase):
def __init__(self, uri): def __init__(self, uri):
super(FakeSessionForVMTests, self).__init__(uri) super(FakeSessionForVMTests, self).__init__(uri)
def host_call_plugin(self, _1, _2, _3, _4, _5): def host_call_plugin(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0] sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False) vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref) vdi_rec = fake.get_record('VDI', vdi_ref)
return '<string>%s</string>' % vdi_rec['uuid'] if plugin == "glance" and method == "download_vhd":
ret_str = json.dumps([dict(vdi_type='os',
vdi_uuid=vdi_rec['uuid'])])
else:
ret_str = vdi_rec['uuid']
return '<string>%s</string>' % ret_str
def host_call_plugin_swap(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref)
if plugin == "glance" and method == "download_vhd":
swap_vdi_ref = fake.create_vdi('', False, sr_ref, False)
swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref)
ret_str = json.dumps(
[dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']),
dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])])
else:
ret_str = vdi_rec['uuid']
return '<string>%s</string>' % ret_str
def VM_start(self, _1, ref, _2, _3): def VM_start(self, _1, ref, _2, _3):
vm = fake.get_record('VM', ref) vm = fake.get_record('VM', ref)