Merge from nova trunk
This commit is contained in:
		
							
								
								
									
										3
									
								
								Authors
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								Authors
									
									
									
									
									
								
							@@ -1,4 +1,5 @@
 | 
			
		||||
Alex Meade <alex.meade@rackspace.com>
 | 
			
		||||
Andrey Brindeyev <abrindeyev@griddynamics.com>
 | 
			
		||||
Andy Smith <code@term.ie>
 | 
			
		||||
Andy Southgate <andy.southgate@citrix.com>
 | 
			
		||||
Anne Gentle <anne@openstack.org>
 | 
			
		||||
@@ -16,6 +17,7 @@ Christian Berendt <berendt@b1-systems.de>
 | 
			
		||||
Chuck Short <zulcss@ubuntu.com>
 | 
			
		||||
Cory Wright <corywright@gmail.com>
 | 
			
		||||
Dan Prince <dan.prince@rackspace.com>
 | 
			
		||||
Dave Walker <DaveWalker@ubuntu.com>
 | 
			
		||||
David Pravec <David.Pravec@danix.org>
 | 
			
		||||
Dean Troyer <dtroyer@gmail.com>
 | 
			
		||||
Devin Carlen <devin.carlen@gmail.com>
 | 
			
		||||
@@ -65,6 +67,7 @@ Nachi Ueno <ueno.nachi@lab.ntt.co.jp>
 | 
			
		||||
Naveed Massjouni <naveedm9@gmail.com>
 | 
			
		||||
Nirmal Ranganathan <nirmal.ranganathan@rackspace.com>
 | 
			
		||||
Paul Voccio <paul@openstack.org>
 | 
			
		||||
Renuka Apte <renuka.apte@citrix.com>
 | 
			
		||||
Ricardo Carrillo Cruz <emaildericky@gmail.com>
 | 
			
		||||
Rick Clark <rick@openstack.org>
 | 
			
		||||
Rick Harris <rconradharris@gmail.com>
 | 
			
		||||
 
 | 
			
		||||
@@ -35,6 +35,7 @@ include nova/tests/bundle/1mb.manifest.xml
 | 
			
		||||
include nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml
 | 
			
		||||
include nova/tests/bundle/1mb.part.0
 | 
			
		||||
include nova/tests/bundle/1mb.part.1
 | 
			
		||||
include nova/tests/public_key/*
 | 
			
		||||
include nova/tests/db/nova.austin.sqlite
 | 
			
		||||
include plugins/xenapi/README
 | 
			
		||||
include plugins/xenapi/etc/xapi.d/plugins/objectstore
 | 
			
		||||
 
 | 
			
		||||
@@ -108,6 +108,13 @@ def main():
 | 
			
		||||
    interface = os.environ.get('DNSMASQ_INTERFACE', FLAGS.dnsmasq_interface)
 | 
			
		||||
    if int(os.environ.get('TESTING', '0')):
 | 
			
		||||
        from nova.tests import fake_flags
 | 
			
		||||
 | 
			
		||||
    #if FLAGS.fake_rabbit:
 | 
			
		||||
    #    LOG.debug(_("leasing ip"))
 | 
			
		||||
    #    network_manager = utils.import_object(FLAGS.network_manager)
 | 
			
		||||
    ##    reload(fake_flags)
 | 
			
		||||
    #    from nova.tests import fake_flags
 | 
			
		||||
 | 
			
		||||
    action = argv[1]
 | 
			
		||||
    if action in ['add', 'del', 'old']:
 | 
			
		||||
        mac = argv[2]
 | 
			
		||||
 
 | 
			
		||||
@@ -97,7 +97,7 @@ flags.DECLARE('vlan_start', 'nova.network.manager')
 | 
			
		||||
flags.DECLARE('vpn_start', 'nova.network.manager')
 | 
			
		||||
flags.DECLARE('fixed_range_v6', 'nova.network.manager')
 | 
			
		||||
flags.DECLARE('images_path', 'nova.image.local')
 | 
			
		||||
flags.DECLARE('libvirt_type', 'nova.virt.libvirt_conn')
 | 
			
		||||
flags.DECLARE('libvirt_type', 'nova.virt.libvirt.connection')
 | 
			
		||||
flags.DEFINE_flag(flags.HelpFlag())
 | 
			
		||||
flags.DEFINE_flag(flags.HelpshortFlag())
 | 
			
		||||
flags.DEFINE_flag(flags.HelpXMLFlag())
 | 
			
		||||
@@ -362,27 +362,47 @@ class ProjectCommands(object):
 | 
			
		||||
    def add(self, project_id, user_id):
 | 
			
		||||
        """Adds user to project
 | 
			
		||||
        arguments: project_id user_id"""
 | 
			
		||||
        try:
 | 
			
		||||
            self.manager.add_to_project(user_id, project_id)
 | 
			
		||||
        except exception.UserNotFound as ex:
 | 
			
		||||
            print ex
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def create(self, name, project_manager, description=None):
 | 
			
		||||
        """Creates a new project
 | 
			
		||||
        arguments: name project_manager [description]"""
 | 
			
		||||
        try:
 | 
			
		||||
            self.manager.create_project(name, project_manager, description)
 | 
			
		||||
        except exception.UserNotFound as ex:
 | 
			
		||||
            print ex
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def modify(self, name, project_manager, description=None):
 | 
			
		||||
        """Modifies a project
 | 
			
		||||
        arguments: name project_manager [description]"""
 | 
			
		||||
        try:
 | 
			
		||||
            self.manager.modify_project(name, project_manager, description)
 | 
			
		||||
        except exception.UserNotFound as ex:
 | 
			
		||||
            print ex
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def delete(self, name):
 | 
			
		||||
        """Deletes an existing project
 | 
			
		||||
        arguments: name"""
 | 
			
		||||
        try:
 | 
			
		||||
            self.manager.delete_project(name)
 | 
			
		||||
        except exception.ProjectNotFound as ex:
 | 
			
		||||
            print ex
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def environment(self, project_id, user_id, filename='novarc'):
 | 
			
		||||
        """Exports environment variables to an sourcable file
 | 
			
		||||
        arguments: project_id user_id [filename='novarc]"""
 | 
			
		||||
        try:
 | 
			
		||||
            rc = self.manager.get_environment_rc(user_id, project_id)
 | 
			
		||||
        except (exception.UserNotFound, exception.ProjectNotFound) as ex:
 | 
			
		||||
            print ex
 | 
			
		||||
            raise
 | 
			
		||||
        with open(filename, 'w') as f:
 | 
			
		||||
            f.write(rc)
 | 
			
		||||
 | 
			
		||||
@@ -397,18 +417,26 @@ class ProjectCommands(object):
 | 
			
		||||
        arguments: project_id [key] [value]"""
 | 
			
		||||
        ctxt = context.get_admin_context()
 | 
			
		||||
        if key:
 | 
			
		||||
            if value.lower() == 'unlimited':
 | 
			
		||||
                value = None
 | 
			
		||||
            try:
 | 
			
		||||
                db.quota_update(ctxt, project_id, key, value)
 | 
			
		||||
            except exception.NotFound:
 | 
			
		||||
            except exception.ProjectQuotaNotFound:
 | 
			
		||||
                db.quota_create(ctxt, project_id, key, value)
 | 
			
		||||
        project_quota = quota.get_quota(ctxt, project_id)
 | 
			
		||||
        project_quota = quota.get_project_quotas(ctxt, project_id)
 | 
			
		||||
        for key, value in project_quota.iteritems():
 | 
			
		||||
            if value is None:
 | 
			
		||||
                value = 'unlimited'
 | 
			
		||||
            print '%s: %s' % (key, value)
 | 
			
		||||
 | 
			
		||||
    def remove(self, project_id, user_id):
 | 
			
		||||
        """Removes user from project
 | 
			
		||||
        arguments: project_id user_id"""
 | 
			
		||||
        try:
 | 
			
		||||
            self.manager.remove_from_project(user_id, project_id)
 | 
			
		||||
        except (exception.UserNotFound, exception.ProjectNotFound) as ex:
 | 
			
		||||
            print ex
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def scrub(self, project_id):
 | 
			
		||||
        """Deletes data associated with project
 | 
			
		||||
@@ -427,6 +455,9 @@ class ProjectCommands(object):
 | 
			
		||||
            zip_file = self.manager.get_credentials(user_id, project_id)
 | 
			
		||||
            with open(filename, 'w') as f:
 | 
			
		||||
                f.write(zip_file)
 | 
			
		||||
        except (exception.UserNotFound, exception.ProjectNotFound) as ex:
 | 
			
		||||
            print ex
 | 
			
		||||
            raise
 | 
			
		||||
        except db.api.NoMoreNetworks:
 | 
			
		||||
            print _('No more networks available. If this is a new '
 | 
			
		||||
                    'installation, you need\nto call something like this:\n\n'
 | 
			
		||||
 
 | 
			
		||||
@@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit")
 | 
			
		||||
 | 
			
		||||
EXCHANGES = {}
 | 
			
		||||
QUEUES = {}
 | 
			
		||||
CONSUMERS = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Message(base.BaseMessage):
 | 
			
		||||
@@ -96,17 +97,29 @@ class Backend(base.BaseBackend):
 | 
			
		||||
                ' key %(routing_key)s') % locals())
 | 
			
		||||
        EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key)
 | 
			
		||||
 | 
			
		||||
    def declare_consumer(self, queue, callback, *args, **kwargs):
 | 
			
		||||
        self.current_queue = queue
 | 
			
		||||
        self.current_callback = callback
 | 
			
		||||
    def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs):
 | 
			
		||||
        global CONSUMERS
 | 
			
		||||
        LOG.debug("Adding consumer %s", consumer_tag)
 | 
			
		||||
        CONSUMERS[consumer_tag] = (queue, callback)
 | 
			
		||||
 | 
			
		||||
    def cancel(self, consumer_tag):
 | 
			
		||||
        global CONSUMERS
 | 
			
		||||
        LOG.debug("Removing consumer %s", consumer_tag)
 | 
			
		||||
        del CONSUMERS[consumer_tag]
 | 
			
		||||
 | 
			
		||||
    def consume(self, limit=None):
 | 
			
		||||
        global CONSUMERS
 | 
			
		||||
        num = 0
 | 
			
		||||
        while True:
 | 
			
		||||
            item = self.get(self.current_queue)
 | 
			
		||||
            for (queue, callback) in CONSUMERS.itervalues():
 | 
			
		||||
                item = self.get(queue)
 | 
			
		||||
                if item:
 | 
			
		||||
                self.current_callback(item)
 | 
			
		||||
                    callback(item)
 | 
			
		||||
                    num += 1
 | 
			
		||||
                    yield
 | 
			
		||||
                    if limit and num == limit:
 | 
			
		||||
                        raise StopIteration()
 | 
			
		||||
            greenthread.sleep(0)
 | 
			
		||||
            greenthread.sleep(0.1)
 | 
			
		||||
 | 
			
		||||
    def get(self, queue, no_ack=False):
 | 
			
		||||
        global QUEUES
 | 
			
		||||
@@ -134,5 +147,7 @@ class Backend(base.BaseBackend):
 | 
			
		||||
def reset_all():
 | 
			
		||||
    global EXCHANGES
 | 
			
		||||
    global QUEUES
 | 
			
		||||
    global CONSUMERS
 | 
			
		||||
    EXCHANGES = {}
 | 
			
		||||
    QUEUES = {}
 | 
			
		||||
    CONSUMERS = {}
 | 
			
		||||
 
 | 
			
		||||
@@ -110,7 +110,7 @@ class FlagValues(gflags.FlagValues):
 | 
			
		||||
        return name in self.__dict__['__dirty']
 | 
			
		||||
 | 
			
		||||
    def ClearDirty(self):
 | 
			
		||||
        self.__dict__['__is_dirty'] = []
 | 
			
		||||
        self.__dict__['__dirty'] = []
 | 
			
		||||
 | 
			
		||||
    def WasAlreadyParsed(self):
 | 
			
		||||
        return self.__dict__['__was_already_parsed']
 | 
			
		||||
@@ -119,11 +119,12 @@ class FlagValues(gflags.FlagValues):
 | 
			
		||||
        if '__stored_argv' not in self.__dict__:
 | 
			
		||||
            return
 | 
			
		||||
        new_flags = FlagValues(self)
 | 
			
		||||
        for k in self.__dict__['__dirty']:
 | 
			
		||||
        for k in self.FlagDict().iterkeys():
 | 
			
		||||
            new_flags[k] = gflags.FlagValues.__getitem__(self, k)
 | 
			
		||||
 | 
			
		||||
        new_flags.Reset()
 | 
			
		||||
        new_flags(self.__dict__['__stored_argv'])
 | 
			
		||||
        for k in self.__dict__['__dirty']:
 | 
			
		||||
        for k in new_flags.FlagDict().iterkeys():
 | 
			
		||||
            setattr(self, k, getattr(new_flags, k))
 | 
			
		||||
        self.ClearDirty()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										231
									
								
								nova/rpc.py
									
									
									
									
									
								
							
							
						
						
									
										231
									
								
								nova/rpc.py
									
									
									
									
									
								
							@@ -28,12 +28,15 @@ import json
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
import traceback
 | 
			
		||||
import types
 | 
			
		||||
import uuid
 | 
			
		||||
 | 
			
		||||
from carrot import connection as carrot_connection
 | 
			
		||||
from carrot import messaging
 | 
			
		||||
from eventlet import greenpool
 | 
			
		||||
from eventlet import greenthread
 | 
			
		||||
from eventlet import pools
 | 
			
		||||
from eventlet import queue
 | 
			
		||||
import greenlet
 | 
			
		||||
 | 
			
		||||
from nova import context
 | 
			
		||||
from nova import exception
 | 
			
		||||
@@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
FLAGS = flags.FLAGS
 | 
			
		||||
flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool')
 | 
			
		||||
flags.DEFINE_integer('rpc_thread_pool_size', 1024,
 | 
			
		||||
                     'Size of RPC thread pool')
 | 
			
		||||
flags.DEFINE_integer('rpc_conn_pool_size', 30,
 | 
			
		||||
                     'Size of RPC connection pool')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Connection(carrot_connection.BrokerConnection):
 | 
			
		||||
@@ -90,6 +96,22 @@ class Connection(carrot_connection.BrokerConnection):
 | 
			
		||||
        return cls.instance()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Pool(pools.Pool):
 | 
			
		||||
    """Class that implements a Pool of Connections."""
 | 
			
		||||
 | 
			
		||||
    # TODO(comstud): Timeout connections not used in a while
 | 
			
		||||
    def create(self):
 | 
			
		||||
        LOG.debug('Creating new connection')
 | 
			
		||||
        return Connection.instance(new=True)
 | 
			
		||||
 | 
			
		||||
# Create a ConnectionPool to use for RPC calls.  We'll order the
 | 
			
		||||
# pool as a stack (LIFO), so that we can potentially loop through and
 | 
			
		||||
# timeout old unused connections at some point
 | 
			
		||||
ConnectionPool = Pool(
 | 
			
		||||
        max_size=FLAGS.rpc_conn_pool_size,
 | 
			
		||||
        order_as_stack=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Consumer(messaging.Consumer):
 | 
			
		||||
    """Consumer base class.
 | 
			
		||||
 | 
			
		||||
@@ -131,7 +153,9 @@ class Consumer(messaging.Consumer):
 | 
			
		||||
                self.connection = Connection.recreate()
 | 
			
		||||
                self.backend = self.connection.create_backend()
 | 
			
		||||
                self.declare()
 | 
			
		||||
            super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks)
 | 
			
		||||
            return super(Consumer, self).fetch(no_ack,
 | 
			
		||||
                                               auto_ack,
 | 
			
		||||
                                               enable_callbacks)
 | 
			
		||||
            if self.failed_connection:
 | 
			
		||||
                LOG.error(_('Reconnected to queue'))
 | 
			
		||||
                self.failed_connection = False
 | 
			
		||||
@@ -159,13 +183,13 @@ class AdapterConsumer(Consumer):
 | 
			
		||||
        self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
 | 
			
		||||
        super(AdapterConsumer, self).__init__(connection=connection,
 | 
			
		||||
                                              topic=topic)
 | 
			
		||||
        self.register_callback(self.process_data)
 | 
			
		||||
 | 
			
		||||
    def receive(self, *args, **kwargs):
 | 
			
		||||
        self.pool.spawn_n(self._receive, *args, **kwargs)
 | 
			
		||||
    def process_data(self, message_data, message):
 | 
			
		||||
        """Consumer callback to call a method on a proxy object.
 | 
			
		||||
 | 
			
		||||
    @exception.wrap_exception
 | 
			
		||||
    def _receive(self, message_data, message):
 | 
			
		||||
        """Magically looks for a method on the proxy object and calls it.
 | 
			
		||||
        Parses the message for validity and fires off a thread to call the
 | 
			
		||||
        proxy object method.
 | 
			
		||||
 | 
			
		||||
        Message data should be a dictionary with two keys:
 | 
			
		||||
            method: string representing the method to call
 | 
			
		||||
@@ -175,8 +199,8 @@ class AdapterConsumer(Consumer):
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        LOG.debug(_('received %s') % message_data)
 | 
			
		||||
        msg_id = message_data.pop('_msg_id', None)
 | 
			
		||||
 | 
			
		||||
        # This will be popped off in _unpack_context
 | 
			
		||||
        msg_id = message_data.get('_msg_id', None)
 | 
			
		||||
        ctxt = _unpack_context(message_data)
 | 
			
		||||
 | 
			
		||||
        method = message_data.get('method')
 | 
			
		||||
@@ -188,8 +212,17 @@ class AdapterConsumer(Consumer):
 | 
			
		||||
            #             we just log the message and send an error string
 | 
			
		||||
            #             back to the caller
 | 
			
		||||
            LOG.warn(_('no method for message: %s') % message_data)
 | 
			
		||||
            msg_reply(msg_id, _('No method for message: %s') % message_data)
 | 
			
		||||
            if msg_id:
 | 
			
		||||
                msg_reply(msg_id,
 | 
			
		||||
                          _('No method for message: %s') % message_data)
 | 
			
		||||
            return
 | 
			
		||||
        self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args)
 | 
			
		||||
 | 
			
		||||
    @exception.wrap_exception
 | 
			
		||||
    def _process_data(self, msg_id, ctxt, method, args):
 | 
			
		||||
        """Thread that maigcally looks for a method on the proxy
 | 
			
		||||
        object and calls it.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        node_func = getattr(self.proxy, str(method))
 | 
			
		||||
        node_args = dict((str(k), v) for k, v in args.iteritems())
 | 
			
		||||
@@ -197,7 +230,18 @@ class AdapterConsumer(Consumer):
 | 
			
		||||
        try:
 | 
			
		||||
            rval = node_func(context=ctxt, **node_args)
 | 
			
		||||
            if msg_id:
 | 
			
		||||
                # Check if the result was a generator
 | 
			
		||||
                if isinstance(rval, types.GeneratorType):
 | 
			
		||||
                    for x in rval:
 | 
			
		||||
                        msg_reply(msg_id, x, None)
 | 
			
		||||
                else:
 | 
			
		||||
                    msg_reply(msg_id, rval, None)
 | 
			
		||||
 | 
			
		||||
                # This final None tells multicall that it is done.
 | 
			
		||||
                msg_reply(msg_id, None, None)
 | 
			
		||||
            elif isinstance(rval, types.GeneratorType):
 | 
			
		||||
                # NOTE(vish): this iterates through the generator
 | 
			
		||||
                list(rval)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logging.exception('Exception during message handling')
 | 
			
		||||
            if msg_id:
 | 
			
		||||
@@ -205,11 +249,6 @@ class AdapterConsumer(Consumer):
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Publisher(messaging.Publisher):
 | 
			
		||||
    """Publisher base class."""
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TopicAdapterConsumer(AdapterConsumer):
 | 
			
		||||
    """Consumes messages on a specific topic."""
 | 
			
		||||
 | 
			
		||||
@@ -242,6 +281,58 @@ class FanoutAdapterConsumer(AdapterConsumer):
 | 
			
		||||
                                    topic=topic, proxy=proxy)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConsumerSet(object):
 | 
			
		||||
    """Groups consumers to listen on together on a single connection."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, connection, consumer_list):
 | 
			
		||||
        self.consumer_list = set(consumer_list)
 | 
			
		||||
        self.consumer_set = None
 | 
			
		||||
        self.enabled = True
 | 
			
		||||
        self.init(connection)
 | 
			
		||||
 | 
			
		||||
    def init(self, conn):
 | 
			
		||||
        if not conn:
 | 
			
		||||
            conn = Connection.instance(new=True)
 | 
			
		||||
        if self.consumer_set:
 | 
			
		||||
            self.consumer_set.close()
 | 
			
		||||
        self.consumer_set = messaging.ConsumerSet(conn)
 | 
			
		||||
        for consumer in self.consumer_list:
 | 
			
		||||
            consumer.connection = conn
 | 
			
		||||
            # consumer.backend is set for us
 | 
			
		||||
            self.consumer_set.add_consumer(consumer)
 | 
			
		||||
 | 
			
		||||
    def reconnect(self):
 | 
			
		||||
        self.init(None)
 | 
			
		||||
 | 
			
		||||
    def wait(self, limit=None):
 | 
			
		||||
        running = True
 | 
			
		||||
        while running:
 | 
			
		||||
            it = self.consumer_set.iterconsume(limit=limit)
 | 
			
		||||
            if not it:
 | 
			
		||||
                break
 | 
			
		||||
            while True:
 | 
			
		||||
                try:
 | 
			
		||||
                    it.next()
 | 
			
		||||
                except StopIteration:
 | 
			
		||||
                    return
 | 
			
		||||
                except greenlet.GreenletExit:
 | 
			
		||||
                    running = False
 | 
			
		||||
                    break
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    LOG.exception(_("Exception while processing consumer"))
 | 
			
		||||
                    self.reconnect()
 | 
			
		||||
                    # Break to outer loop
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self.consumer_set.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Publisher(messaging.Publisher):
 | 
			
		||||
    """Publisher base class."""
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TopicPublisher(Publisher):
 | 
			
		||||
    """Publishes messages on a specific topic."""
 | 
			
		||||
 | 
			
		||||
@@ -306,7 +397,8 @@ def msg_reply(msg_id, reply=None, failure=None):
 | 
			
		||||
        LOG.error(_("Returning exception %s to caller"), message)
 | 
			
		||||
        LOG.error(tb)
 | 
			
		||||
        failure = (failure[0].__name__, str(failure[1]), tb)
 | 
			
		||||
    conn = Connection.instance()
 | 
			
		||||
 | 
			
		||||
    with ConnectionPool.item() as conn:
 | 
			
		||||
        publisher = DirectPublisher(connection=conn, msg_id=msg_id)
 | 
			
		||||
        try:
 | 
			
		||||
            publisher.send({'result': reply, 'failure': failure})
 | 
			
		||||
@@ -315,6 +407,7 @@ def msg_reply(msg_id, reply=None, failure=None):
 | 
			
		||||
                    {'result': dict((k, repr(v))
 | 
			
		||||
                                    for k, v in reply.__dict__.iteritems()),
 | 
			
		||||
                     'failure': failure})
 | 
			
		||||
 | 
			
		||||
        publisher.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -347,8 +440,9 @@ def _unpack_context(msg):
 | 
			
		||||
        if key.startswith('_context_'):
 | 
			
		||||
            value = msg.pop(key)
 | 
			
		||||
            context_dict[key[9:]] = value
 | 
			
		||||
    context_dict['msg_id'] = msg.pop('_msg_id', None)
 | 
			
		||||
    LOG.debug(_('unpacked context: %s'), context_dict)
 | 
			
		||||
    return context.RequestContext.from_dict(context_dict)
 | 
			
		||||
    return RpcContext.from_dict(context_dict)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _pack_context(msg, context):
 | 
			
		||||
@@ -360,57 +454,99 @@ def _pack_context(msg, context):
 | 
			
		||||
    for args at some point.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    context = dict([('_context_%s' % key, value)
 | 
			
		||||
    context_d = dict([('_context_%s' % key, value)
 | 
			
		||||
                      for (key, value) in context.to_dict().iteritems()])
 | 
			
		||||
    msg.update(context)
 | 
			
		||||
    msg.update(context_d)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def call(context, topic, msg):
 | 
			
		||||
    """Sends a message on a topic and wait for a response."""
 | 
			
		||||
class RpcContext(context.RequestContext):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        msg_id = kwargs.pop('msg_id', None)
 | 
			
		||||
        self.msg_id = msg_id
 | 
			
		||||
        super(RpcContext, self).__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def reply(self, *args, **kwargs):
 | 
			
		||||
        msg_reply(self.msg_id, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def multicall(context, topic, msg):
 | 
			
		||||
    """Make a call that returns multiple times."""
 | 
			
		||||
    LOG.debug(_('Making asynchronous call on %s ...'), topic)
 | 
			
		||||
    msg_id = uuid.uuid4().hex
 | 
			
		||||
    msg.update({'_msg_id': msg_id})
 | 
			
		||||
    LOG.debug(_('MSG_ID is %s') % (msg_id))
 | 
			
		||||
    _pack_context(msg, context)
 | 
			
		||||
 | 
			
		||||
    class WaitMessage(object):
 | 
			
		||||
    con_conn = ConnectionPool.get()
 | 
			
		||||
    consumer = DirectConsumer(connection=con_conn, msg_id=msg_id)
 | 
			
		||||
    wait_msg = MulticallWaiter(consumer)
 | 
			
		||||
    consumer.register_callback(wait_msg)
 | 
			
		||||
 | 
			
		||||
    publisher = TopicPublisher(connection=con_conn, topic=topic)
 | 
			
		||||
    publisher.send(msg)
 | 
			
		||||
    publisher.close()
 | 
			
		||||
 | 
			
		||||
    return wait_msg
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MulticallWaiter(object):
 | 
			
		||||
    def __init__(self, consumer):
 | 
			
		||||
        self._consumer = consumer
 | 
			
		||||
        self._results = queue.Queue()
 | 
			
		||||
        self._closed = False
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._closed = True
 | 
			
		||||
        self._consumer.close()
 | 
			
		||||
        ConnectionPool.put(self._consumer.connection)
 | 
			
		||||
 | 
			
		||||
    def __call__(self, data, message):
 | 
			
		||||
        """Acks message and sets result."""
 | 
			
		||||
        message.ack()
 | 
			
		||||
        if data['failure']:
 | 
			
		||||
                self.result = RemoteError(*data['failure'])
 | 
			
		||||
            self._results.put(RemoteError(*data['failure']))
 | 
			
		||||
        else:
 | 
			
		||||
                self.result = data['result']
 | 
			
		||||
            self._results.put(data['result'])
 | 
			
		||||
 | 
			
		||||
    wait_msg = WaitMessage()
 | 
			
		||||
    conn = Connection.instance()
 | 
			
		||||
    consumer = DirectConsumer(connection=conn, msg_id=msg_id)
 | 
			
		||||
    consumer.register_callback(wait_msg)
 | 
			
		||||
 | 
			
		||||
    conn = Connection.instance()
 | 
			
		||||
    publisher = TopicPublisher(connection=conn, topic=topic)
 | 
			
		||||
    publisher.send(msg)
 | 
			
		||||
    publisher.close()
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        return self.wait()
 | 
			
		||||
 | 
			
		||||
    def wait(self):
 | 
			
		||||
        while True:
 | 
			
		||||
            rv = None
 | 
			
		||||
            while rv is None and not self._closed:
 | 
			
		||||
                try:
 | 
			
		||||
        consumer.wait(limit=1)
 | 
			
		||||
    except StopIteration:
 | 
			
		||||
        pass
 | 
			
		||||
    consumer.close()
 | 
			
		||||
    # NOTE(termie): this is a little bit of a change from the original
 | 
			
		||||
    #               non-eventlet code where returning a Failure
 | 
			
		||||
    #               instance from a deferred call is very similar to
 | 
			
		||||
    #               raising an exception
 | 
			
		||||
    if isinstance(wait_msg.result, Exception):
 | 
			
		||||
        raise wait_msg.result
 | 
			
		||||
    return wait_msg.result
 | 
			
		||||
                    rv = self._consumer.fetch(enable_callbacks=True)
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    self.close()
 | 
			
		||||
                    raise
 | 
			
		||||
                time.sleep(0.01)
 | 
			
		||||
 | 
			
		||||
            result = self._results.get()
 | 
			
		||||
            if isinstance(result, Exception):
 | 
			
		||||
                self.close()
 | 
			
		||||
                raise result
 | 
			
		||||
            if result == None:
 | 
			
		||||
                self.close()
 | 
			
		||||
                raise StopIteration
 | 
			
		||||
            yield result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def call(context, topic, msg):
 | 
			
		||||
    """Sends a message on a topic and wait for a response."""
 | 
			
		||||
    rv = multicall(context, topic, msg)
 | 
			
		||||
    # NOTE(vish): return the last result from the multicall
 | 
			
		||||
    rv = list(rv)
 | 
			
		||||
    if not rv:
 | 
			
		||||
        return
 | 
			
		||||
    return rv[-1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cast(context, topic, msg):
 | 
			
		||||
    """Sends a message on a topic without waiting for a response."""
 | 
			
		||||
    LOG.debug(_('Making asynchronous cast on %s...'), topic)
 | 
			
		||||
    _pack_context(msg, context)
 | 
			
		||||
    conn = Connection.instance()
 | 
			
		||||
    with ConnectionPool.item() as conn:
 | 
			
		||||
        publisher = TopicPublisher(connection=conn, topic=topic)
 | 
			
		||||
        publisher.send(msg)
 | 
			
		||||
        publisher.close()
 | 
			
		||||
@@ -420,7 +556,7 @@ def fanout_cast(context, topic, msg):
 | 
			
		||||
    """Sends a message on a fanout exchange without waiting for a response."""
 | 
			
		||||
    LOG.debug(_('Making asynchronous fanout cast...'))
 | 
			
		||||
    _pack_context(msg, context)
 | 
			
		||||
    conn = Connection.instance()
 | 
			
		||||
    with ConnectionPool.item() as conn:
 | 
			
		||||
        publisher = FanoutPublisher(topic, connection=conn)
 | 
			
		||||
        publisher.send(msg)
 | 
			
		||||
        publisher.close()
 | 
			
		||||
@@ -459,6 +595,7 @@ def send_message(topic, message, wait=True):
 | 
			
		||||
 | 
			
		||||
    if wait:
 | 
			
		||||
        consumer.wait()
 | 
			
		||||
        consumer.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
 
 | 
			
		||||
@@ -81,6 +81,12 @@ def get_zone_capabilities(context):
 | 
			
		||||
    return _call_scheduler('get_zone_capabilities', context=context)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def select(context, specs=None):
 | 
			
		||||
    """Returns a list of hosts."""
 | 
			
		||||
    return _call_scheduler('select', context=context,
 | 
			
		||||
            params={"specs": specs})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_service_capabilities(context, service_name, host, capabilities):
 | 
			
		||||
    """Send an update to all the scheduler services informing them
 | 
			
		||||
       of the capabilities of this service."""
 | 
			
		||||
@@ -105,6 +111,45 @@ def _process(func, zone):
 | 
			
		||||
    return func(nova, zone)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def call_zone_method(context, method, errors_to_ignore=None, *args, **kwargs):
 | 
			
		||||
    """Returns a list of (zone, call_result) objects."""
 | 
			
		||||
    if not isinstance(errors_to_ignore, (list, tuple)):
 | 
			
		||||
        # This will also handle the default None
 | 
			
		||||
        errors_to_ignore = [errors_to_ignore]
 | 
			
		||||
 | 
			
		||||
    pool = greenpool.GreenPool()
 | 
			
		||||
    results = []
 | 
			
		||||
    for zone in db.zone_get_all(context):
 | 
			
		||||
        try:
 | 
			
		||||
            nova = novaclient.OpenStack(zone.username, zone.password,
 | 
			
		||||
                    zone.api_url)
 | 
			
		||||
            nova.authenticate()
 | 
			
		||||
        except novaclient.exceptions.BadRequest, e:
 | 
			
		||||
            url = zone.api_url
 | 
			
		||||
            LOG.warn(_("Failed request to zone; URL=%(url)s: %(e)s")
 | 
			
		||||
                    % locals())
 | 
			
		||||
            #TODO (dabo) - add logic for failure counts per zone,
 | 
			
		||||
            # with escalation after a given number of failures.
 | 
			
		||||
            continue
 | 
			
		||||
        zone_method = getattr(nova.zones, method)
 | 
			
		||||
 | 
			
		||||
        def _error_trap(*args, **kwargs):
 | 
			
		||||
            try:
 | 
			
		||||
                return zone_method(*args, **kwargs)
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                if type(e) in errors_to_ignore:
 | 
			
		||||
                    return None
 | 
			
		||||
                # TODO (dabo) - want to be able to re-raise here.
 | 
			
		||||
                # Returning a string now; raising was causing issues.
 | 
			
		||||
                # raise e
 | 
			
		||||
                return "ERROR", "%s" % e
 | 
			
		||||
 | 
			
		||||
        res = pool.spawn(_error_trap, *args, **kwargs)
 | 
			
		||||
        results.append((zone, res))
 | 
			
		||||
    pool.waitall()
 | 
			
		||||
    return [(zone.id, res.wait()) for zone, res in results]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def child_zone_helper(zone_list, func):
 | 
			
		||||
    """Fire off a command to each zone in the list.
 | 
			
		||||
    The return is [novaclient return objects] from each child zone.
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										119
									
								
								nova/scheduler/zone_aware_scheduler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								nova/scheduler/zone_aware_scheduler.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
			
		||||
# Copyright (c) 2011 Openstack, LLC.
 | 
			
		||||
# All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
			
		||||
#    not use this file except in compliance with the License. You may obtain
 | 
			
		||||
#    a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#         http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
#    Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
#    License for the specific language governing permissions and limitations
 | 
			
		||||
#    under the License.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
The Zone Aware Scheduler is a base class Scheduler for creating instances
 | 
			
		||||
across zones. There are two expansion points to this class for:
 | 
			
		||||
1. Assigning Weights to hosts for requested instances
 | 
			
		||||
2. Filtering Hosts based on required instance capabilities
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import operator
 | 
			
		||||
 | 
			
		||||
from nova import log as logging
 | 
			
		||||
from nova.scheduler import api
 | 
			
		||||
from nova.scheduler import driver
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger('nova.scheduler.zone_aware_scheduler')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ZoneAwareScheduler(driver.Scheduler):
 | 
			
		||||
    """Base class for creating Zone Aware Schedulers."""
 | 
			
		||||
 | 
			
		||||
    def _call_zone_method(self, context, method, specs):
 | 
			
		||||
        """Call novaclient zone method. Broken out for testing."""
 | 
			
		||||
        return api.call_zone_method(context, method, specs=specs)
 | 
			
		||||
 | 
			
		||||
    def schedule_run_instance(self, context, topic='compute', specs={},
 | 
			
		||||
                                        *args, **kwargs):
 | 
			
		||||
        """This method is called from nova.compute.api to provision
 | 
			
		||||
        an instance. However we need to look at the parameters being
 | 
			
		||||
        passed in to see if this is a request to:
 | 
			
		||||
        1. Create a Build Plan and then provision, or
 | 
			
		||||
        2. Use the Build Plan information in the request parameters
 | 
			
		||||
           to simply create the instance (either in this zone or
 | 
			
		||||
           a child zone)."""
 | 
			
		||||
 | 
			
		||||
        if 'blob' in specs:
 | 
			
		||||
            return self.provision_instance(context, topic, specs)
 | 
			
		||||
 | 
			
		||||
        # Create build plan and provision ...
 | 
			
		||||
        build_plan = self.select(context, specs)
 | 
			
		||||
        for item in build_plan:
 | 
			
		||||
            self.provision_instance(context, topic, item)
 | 
			
		||||
 | 
			
		||||
    def provision_instance(context, topic, item):
 | 
			
		||||
        """Create the requested instance in this Zone or a child zone."""
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def select(self, context, *args, **kwargs):
 | 
			
		||||
        """Select returns a list of weights and zone/host information
 | 
			
		||||
        corresponding to the best hosts to service the request. Any
 | 
			
		||||
        child zone information has been encrypted so as not to reveal
 | 
			
		||||
        anything about the children."""
 | 
			
		||||
        return self._schedule(context, "compute", *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def schedule(self, context, topic, *args, **kwargs):
 | 
			
		||||
        """The schedule() contract requires we return the one
 | 
			
		||||
        best-suited host for this request.
 | 
			
		||||
        """
 | 
			
		||||
        res = self._schedule(context, topic, *args, **kwargs)
 | 
			
		||||
        # TODO(sirp): should this be a host object rather than a weight-dict?
 | 
			
		||||
        if not res:
 | 
			
		||||
            raise driver.NoValidHost(_('No hosts were available'))
 | 
			
		||||
        return res[0]
 | 
			
		||||
 | 
			
		||||
    def _schedule(self, context, topic, *args, **kwargs):
 | 
			
		||||
        """Returns a list of hosts that meet the required specs,
 | 
			
		||||
        ordered by their fitness.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        #TODO(sandy): extract these from args.
 | 
			
		||||
        num_instances = 1
 | 
			
		||||
        specs = {}
 | 
			
		||||
 | 
			
		||||
        # Filter local hosts based on requirements ...
 | 
			
		||||
        host_list = self.filter_hosts(num_instances, specs)
 | 
			
		||||
 | 
			
		||||
        # then weigh the selected hosts.
 | 
			
		||||
        # weighted = [{weight=weight, name=hostname}, ...]
 | 
			
		||||
        weighted = self.weigh_hosts(num_instances, specs, host_list)
 | 
			
		||||
 | 
			
		||||
        # Next, tack on the best weights from the child zones ...
 | 
			
		||||
        child_results = self._call_zone_method(context, "select",
 | 
			
		||||
                specs=specs)
 | 
			
		||||
        for child_zone, result in child_results:
 | 
			
		||||
            for weighting in result:
 | 
			
		||||
                # Remember the child_zone so we can get back to
 | 
			
		||||
                # it later if needed. This implicitly builds a zone
 | 
			
		||||
                # path structure.
 | 
			
		||||
                host_dict = {
 | 
			
		||||
                        "weight": weighting["weight"],
 | 
			
		||||
                        "child_zone": child_zone,
 | 
			
		||||
                        "child_blob": weighting["blob"]}
 | 
			
		||||
                weighted.append(host_dict)
 | 
			
		||||
 | 
			
		||||
        weighted.sort(key=operator.itemgetter('weight'))
 | 
			
		||||
        return weighted
 | 
			
		||||
 | 
			
		||||
    def filter_hosts(self, num, specs):
 | 
			
		||||
        """Derived classes must override this method and return
 | 
			
		||||
           a list of hosts in [(hostname, capability_dict)] format."""
 | 
			
		||||
        raise NotImplemented()
 | 
			
		||||
 | 
			
		||||
    def weigh_hosts(self, num, specs, hosts):
 | 
			
		||||
        """Derived classes must override this method and return
 | 
			
		||||
           a lists of hosts in [{weight, hostname}] format."""
 | 
			
		||||
        raise NotImplemented()
 | 
			
		||||
@@ -21,24 +21,24 @@ from nova import flags
 | 
			
		||||
FLAGS = flags.FLAGS
 | 
			
		||||
 | 
			
		||||
flags.DECLARE('volume_driver', 'nova.volume.manager')
 | 
			
		||||
FLAGS.volume_driver = 'nova.volume.driver.FakeISCSIDriver'
 | 
			
		||||
FLAGS.connection_type = 'fake'
 | 
			
		||||
FLAGS.fake_rabbit = True
 | 
			
		||||
FLAGS['volume_driver'].SetDefault('nova.volume.driver.FakeISCSIDriver')
 | 
			
		||||
FLAGS['connection_type'].SetDefault('fake')
 | 
			
		||||
FLAGS['fake_rabbit'].SetDefault(True)
 | 
			
		||||
flags.DECLARE('auth_driver', 'nova.auth.manager')
 | 
			
		||||
FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver'
 | 
			
		||||
FLAGS['auth_driver'].SetDefault('nova.auth.dbdriver.DbDriver')
 | 
			
		||||
flags.DECLARE('network_size', 'nova.network.manager')
 | 
			
		||||
flags.DECLARE('num_networks', 'nova.network.manager')
 | 
			
		||||
flags.DECLARE('fake_network', 'nova.network.manager')
 | 
			
		||||
FLAGS.network_size = 8
 | 
			
		||||
FLAGS.num_networks = 2
 | 
			
		||||
FLAGS.fake_network = True
 | 
			
		||||
FLAGS.image_service = 'nova.image.local.LocalImageService'
 | 
			
		||||
FLAGS['network_size'].SetDefault(8)
 | 
			
		||||
FLAGS['num_networks'].SetDefault(2)
 | 
			
		||||
FLAGS['fake_network'].SetDefault(True)
 | 
			
		||||
FLAGS['image_service'].SetDefault('nova.image.local.LocalImageService')
 | 
			
		||||
flags.DECLARE('num_shelves', 'nova.volume.driver')
 | 
			
		||||
flags.DECLARE('blades_per_shelf', 'nova.volume.driver')
 | 
			
		||||
flags.DECLARE('iscsi_num_targets', 'nova.volume.driver')
 | 
			
		||||
FLAGS.num_shelves = 2
 | 
			
		||||
FLAGS.blades_per_shelf = 4
 | 
			
		||||
FLAGS.iscsi_num_targets = 8
 | 
			
		||||
FLAGS.verbose = True
 | 
			
		||||
FLAGS.sqlite_db = "tests.sqlite"
 | 
			
		||||
FLAGS.use_ipv6 = True
 | 
			
		||||
FLAGS['num_shelves'].SetDefault(2)
 | 
			
		||||
FLAGS['blades_per_shelf'].SetDefault(4)
 | 
			
		||||
FLAGS['iscsi_num_targets'].SetDefault(8)
 | 
			
		||||
FLAGS['verbose'].SetDefault(True)
 | 
			
		||||
FLAGS['sqlite_db'].SetDefault("tests.sqlite")
 | 
			
		||||
FLAGS['use_ipv6'].SetDefault(True)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								nova/tests/public_key/dummy.fingerprint
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								nova/tests/public_key/dummy.fingerprint
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
1c:87:d1:d9:32:fd:62:3c:78:2b:c0:ad:c0:15:88:df
 | 
			
		||||
							
								
								
									
										1
									
								
								nova/tests/public_key/dummy.pub
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								nova/tests/public_key/dummy.pub
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
ssh-dss AAAAB3NzaC1kc3MAAACBAMGJlY9XEIm2X234pdO5yFWMp2JuOQx8U0E815IVXhmKxYCBK9ZakgZOIQmPbXoGYyV+mziDPp6HJ0wKYLQxkwLEFr51fAZjWQvRss0SinURRuLkockDfGFtD4pYJthekr/rlqMKlBSDUSpGq8jUWW60UJ18FGooFpxR7ESqQRx/AAAAFQC96LRglaUeeP+E8U/yblEJocuiWwAAAIA3XiMR8Skiz/0aBm5K50SeQznQuMJTyzt9S9uaz5QZWiFu69hOyGSFGw8fqgxEkXFJIuHobQQpGYQubLW0NdaYRqyE/Vud3JUJUb8Texld6dz8vGemyB5d1YvtSeHIo8/BGv2msOqR3u5AZTaGCBD9DhpSGOKHEdNjTtvpPd8S8gAAAIBociGZ5jf09iHLVENhyXujJbxfGRPsyNTyARJfCOGl0oFV6hEzcQyw8U/ePwjgvjc2UizMWLl8tsb2FXKHRdc2v+ND3Us+XqKQ33X3ADP4FZ/+Oj213gMyhCmvFTP0u5FmHog9My4CB7YcIWRuUR42WlhQ2IfPvKwUoTk3R+T6Og== www-data@mk
 | 
			
		||||
@@ -1,26 +0,0 @@
 | 
			
		||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
 | 
			
		||||
 | 
			
		||||
# Copyright 2010 United States Government as represented by the
 | 
			
		||||
# Administrator of the National Aeronautics and Space Administration.
 | 
			
		||||
# All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
			
		||||
#    not use this file except in compliance with the License. You may obtain
 | 
			
		||||
#    a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#         http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
#    Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
#    License for the specific language governing permissions and limitations
 | 
			
		||||
#    under the License.
 | 
			
		||||
 | 
			
		||||
from nova import flags
 | 
			
		||||
 | 
			
		||||
FLAGS = flags.FLAGS
 | 
			
		||||
 | 
			
		||||
FLAGS.connection_type = 'libvirt'
 | 
			
		||||
FLAGS.fake_rabbit = False
 | 
			
		||||
FLAGS.fake_network = False
 | 
			
		||||
FLAGS.verbose = False
 | 
			
		||||
@@ -224,6 +224,29 @@ class ApiEc2TestCase(test.TestCase):
 | 
			
		||||
        self.manager.delete_project(project)
 | 
			
		||||
        self.manager.delete_user(user)
 | 
			
		||||
 | 
			
		||||
    def test_create_duplicate_key_pair(self):
 | 
			
		||||
        """Test that, after successfully generating a keypair,
 | 
			
		||||
        requesting a second keypair with the same name fails sanely"""
 | 
			
		||||
        self.expect_http()
 | 
			
		||||
        self.mox.ReplayAll()
 | 
			
		||||
        keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
 | 
			
		||||
                          for x in range(random.randint(4, 8)))
 | 
			
		||||
        user = self.manager.create_user('fake', 'fake', 'fake')
 | 
			
		||||
        project = self.manager.create_project('fake', 'fake', 'fake')
 | 
			
		||||
        # NOTE(vish): create depends on pool, so call helper directly
 | 
			
		||||
        self.ec2.create_key_pair('test')
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            self.ec2.create_key_pair('test')
 | 
			
		||||
        except EC2ResponseError, e:
 | 
			
		||||
            if e.code == 'KeyPairExists':
 | 
			
		||||
                pass
 | 
			
		||||
            else:
 | 
			
		||||
                self.fail("Unexpected EC2ResponseError: %s "
 | 
			
		||||
                          "(expected KeyPairExists)" % e.code)
 | 
			
		||||
        else:
 | 
			
		||||
            self.fail('Exception not raised.')
 | 
			
		||||
 | 
			
		||||
    def test_get_all_security_groups(self):
 | 
			
		||||
        """Test that we can retrieve security groups"""
 | 
			
		||||
        self.expect_http()
 | 
			
		||||
 
 | 
			
		||||
@@ -17,13 +17,9 @@
 | 
			
		||||
#    under the License.
 | 
			
		||||
 | 
			
		||||
from base64 import b64decode
 | 
			
		||||
import json
 | 
			
		||||
from M2Crypto import BIO
 | 
			
		||||
from M2Crypto import RSA
 | 
			
		||||
import os
 | 
			
		||||
import shutil
 | 
			
		||||
import tempfile
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
from eventlet import greenthread
 | 
			
		||||
 | 
			
		||||
@@ -33,12 +29,10 @@ from nova import db
 | 
			
		||||
from nova import flags
 | 
			
		||||
from nova import log as logging
 | 
			
		||||
from nova import rpc
 | 
			
		||||
from nova import service
 | 
			
		||||
from nova import test
 | 
			
		||||
from nova import utils
 | 
			
		||||
from nova import exception
 | 
			
		||||
from nova.auth import manager
 | 
			
		||||
from nova.compute import power_state
 | 
			
		||||
from nova.api.ec2 import cloud
 | 
			
		||||
from nova.api.ec2 import ec2utils
 | 
			
		||||
from nova.image import local
 | 
			
		||||
@@ -80,15 +74,21 @@ class CloudTestCase(test.TestCase):
 | 
			
		||||
        self.stubs.Set(local.LocalImageService, 'show', fake_show)
 | 
			
		||||
        self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show)
 | 
			
		||||
 | 
			
		||||
        # NOTE(vish): set up a manual wait so rpc.cast has a chance to finish
 | 
			
		||||
        rpc_cast = rpc.cast
 | 
			
		||||
 | 
			
		||||
        def finish_cast(*args, **kwargs):
 | 
			
		||||
            rpc_cast(*args, **kwargs)
 | 
			
		||||
            greenthread.sleep(0.2)
 | 
			
		||||
 | 
			
		||||
        self.stubs.Set(rpc, 'cast', finish_cast)
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        network_ref = db.project_get_network(self.context,
 | 
			
		||||
                                             self.project.id)
 | 
			
		||||
        db.network_disassociate(self.context, network_ref['id'])
 | 
			
		||||
        self.manager.delete_project(self.project)
 | 
			
		||||
        self.manager.delete_user(self.user)
 | 
			
		||||
        self.volume.kill()
 | 
			
		||||
        self.compute.kill()
 | 
			
		||||
        self.network.kill()
 | 
			
		||||
        super(CloudTestCase, self).tearDown()
 | 
			
		||||
 | 
			
		||||
    def _create_key(self, name):
 | 
			
		||||
@@ -115,7 +115,6 @@ class CloudTestCase(test.TestCase):
 | 
			
		||||
        self.cloud.describe_addresses(self.context)
 | 
			
		||||
        self.cloud.release_address(self.context,
 | 
			
		||||
                                  public_ip=address)
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
        db.floating_ip_destroy(self.context, address)
 | 
			
		||||
 | 
			
		||||
    def test_associate_disassociate_address(self):
 | 
			
		||||
@@ -131,12 +130,10 @@ class CloudTestCase(test.TestCase):
 | 
			
		||||
        self.cloud.associate_address(self.context,
 | 
			
		||||
                                     instance_id=ec2_id,
 | 
			
		||||
                                     public_ip=address)
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
        self.cloud.disassociate_address(self.context,
 | 
			
		||||
                                        public_ip=address)
 | 
			
		||||
        self.cloud.release_address(self.context,
 | 
			
		||||
                                  public_ip=address)
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
        self.network.deallocate_fixed_ip(self.context, fixed)
 | 
			
		||||
        db.instance_destroy(self.context, inst['id'])
 | 
			
		||||
        db.floating_ip_destroy(self.context, address)
 | 
			
		||||
@@ -368,7 +365,6 @@ class CloudTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
    def _run_instance(self, **kwargs):
 | 
			
		||||
        rv = self.cloud.run_instances(self.context, **kwargs)
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
        instance_id = rv['instancesSet'][0]['instanceId']
 | 
			
		||||
        return instance_id
 | 
			
		||||
 | 
			
		||||
@@ -387,9 +383,7 @@ class CloudTestCase(test.TestCase):
 | 
			
		||||
        self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
 | 
			
		||||
        # TODO(soren): We need this until we can stop polling in the rpc code
 | 
			
		||||
        #              for unit tests.
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
        rv = self.cloud.terminate_instances(self.context, [instance_id])
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
 | 
			
		||||
    def test_ajax_console(self):
 | 
			
		||||
        instance_id = self._run_instance(image_id='ami-1')
 | 
			
		||||
@@ -399,9 +393,7 @@ class CloudTestCase(test.TestCase):
 | 
			
		||||
                          '%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url)
 | 
			
		||||
        # TODO(soren): We need this until we can stop polling in the rpc code
 | 
			
		||||
        #              for unit tests.
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
        rv = self.cloud.terminate_instances(self.context, [instance_id])
 | 
			
		||||
        greenthread.sleep(0.3)
 | 
			
		||||
 | 
			
		||||
    def test_key_generation(self):
 | 
			
		||||
        result = self._create_key('test')
 | 
			
		||||
@@ -425,6 +417,36 @@ class CloudTestCase(test.TestCase):
 | 
			
		||||
        self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
 | 
			
		||||
        self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
 | 
			
		||||
 | 
			
		||||
    def test_import_public_key(self):
 | 
			
		||||
        # test when user provides all values
 | 
			
		||||
        result1 = self.cloud.import_public_key(self.context,
 | 
			
		||||
                                               'testimportkey1',
 | 
			
		||||
                                               'mytestpubkey',
 | 
			
		||||
                                               'mytestfprint')
 | 
			
		||||
        self.assertTrue(result1)
 | 
			
		||||
        keydata = db.key_pair_get(self.context,
 | 
			
		||||
                                  self.context.user.id,
 | 
			
		||||
                                  'testimportkey1')
 | 
			
		||||
        self.assertEqual('mytestpubkey', keydata['public_key'])
 | 
			
		||||
        self.assertEqual('mytestfprint', keydata['fingerprint'])
 | 
			
		||||
        # test when user omits fingerprint
 | 
			
		||||
        pubkey_path = os.path.join(os.path.dirname(__file__), 'public_key')
 | 
			
		||||
        f = open(pubkey_path + '/dummy.pub', 'r')
 | 
			
		||||
        dummypub = f.readline().rstrip()
 | 
			
		||||
        f.close
 | 
			
		||||
        f = open(pubkey_path + '/dummy.fingerprint', 'r')
 | 
			
		||||
        dummyfprint = f.readline().rstrip()
 | 
			
		||||
        f.close
 | 
			
		||||
        result2 = self.cloud.import_public_key(self.context,
 | 
			
		||||
                                               'testimportkey2',
 | 
			
		||||
                                               dummypub)
 | 
			
		||||
        self.assertTrue(result2)
 | 
			
		||||
        keydata = db.key_pair_get(self.context,
 | 
			
		||||
                                  self.context.user.id,
 | 
			
		||||
                                  'testimportkey2')
 | 
			
		||||
        self.assertEqual(dummypub, keydata['public_key'])
 | 
			
		||||
        self.assertEqual(dummyfprint, keydata['fingerprint'])
 | 
			
		||||
 | 
			
		||||
    def test_delete_key_pair(self):
 | 
			
		||||
        self._create_key('test')
 | 
			
		||||
        self.cloud.delete_key_pair(self.context, 'test')
 | 
			
		||||
 
 | 
			
		||||
@@ -91,6 +91,20 @@ class FlagsTestCase(test.TestCase):
 | 
			
		||||
        self.assert_('runtime_answer' in self.global_FLAGS)
 | 
			
		||||
        self.assertEqual(self.global_FLAGS.runtime_answer, 60)
 | 
			
		||||
 | 
			
		||||
    def test_long_vs_short_flags(self):
 | 
			
		||||
        flags.DEFINE_string('duplicate_answer_long', 'val', 'desc',
 | 
			
		||||
                            flag_values=self.global_FLAGS)
 | 
			
		||||
        argv = ['flags_test', '--duplicate_answer=60', 'extra_arg']
 | 
			
		||||
        args = self.global_FLAGS(argv)
 | 
			
		||||
 | 
			
		||||
        self.assert_('duplicate_answer' not in self.global_FLAGS)
 | 
			
		||||
        self.assert_(self.global_FLAGS.duplicate_answer_long, 60)
 | 
			
		||||
 | 
			
		||||
        flags.DEFINE_integer('duplicate_answer', 60, 'desc',
 | 
			
		||||
                             flag_values=self.global_FLAGS)
 | 
			
		||||
        self.assertEqual(self.global_FLAGS.duplicate_answer, 60)
 | 
			
		||||
        self.assertEqual(self.global_FLAGS.duplicate_answer_long, 'val')
 | 
			
		||||
 | 
			
		||||
    def test_flag_leak_left(self):
 | 
			
		||||
        self.assertEqual(FLAGS.flags_unittest, 'foo')
 | 
			
		||||
        FLAGS.flags_unittest = 'bar'
 | 
			
		||||
 
 | 
			
		||||
@@ -32,7 +32,8 @@ from nova import utils
 | 
			
		||||
from nova.api.ec2 import cloud
 | 
			
		||||
from nova.auth import manager
 | 
			
		||||
from nova.compute import power_state
 | 
			
		||||
from nova.virt import libvirt_conn
 | 
			
		||||
from nova.virt.libvirt import connection
 | 
			
		||||
from nova.virt.libvirt import firewall
 | 
			
		||||
 | 
			
		||||
libvirt = None
 | 
			
		||||
FLAGS = flags.FLAGS
 | 
			
		||||
@@ -83,7 +84,7 @@ class CacheConcurrencyTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_same_fname_concurrency(self):
 | 
			
		||||
        """Ensures that the same fname cache runs at a sequentially"""
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection
 | 
			
		||||
        conn = connection.LibvirtConnection
 | 
			
		||||
        wait1 = eventlet.event.Event()
 | 
			
		||||
        done1 = eventlet.event.Event()
 | 
			
		||||
        eventlet.spawn(conn._cache_image, _concurrency,
 | 
			
		||||
@@ -104,7 +105,7 @@ class CacheConcurrencyTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_different_fname_concurrency(self):
 | 
			
		||||
        """Ensures that two different fname caches are concurrent"""
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection
 | 
			
		||||
        conn = connection.LibvirtConnection
 | 
			
		||||
        wait1 = eventlet.event.Event()
 | 
			
		||||
        done1 = eventlet.event.Event()
 | 
			
		||||
        eventlet.spawn(conn._cache_image, _concurrency,
 | 
			
		||||
@@ -125,7 +126,7 @@ class CacheConcurrencyTestCase(test.TestCase):
 | 
			
		||||
class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super(LibvirtConnTestCase, self).setUp()
 | 
			
		||||
        libvirt_conn._late_load_cheetah()
 | 
			
		||||
        connection._late_load_cheetah()
 | 
			
		||||
        self.flags(fake_call=True)
 | 
			
		||||
        self.manager = manager.AuthManager()
 | 
			
		||||
 | 
			
		||||
@@ -171,8 +172,8 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
            return False
 | 
			
		||||
        global libvirt
 | 
			
		||||
        libvirt = __import__('libvirt')
 | 
			
		||||
        libvirt_conn.libvirt = __import__('libvirt')
 | 
			
		||||
        libvirt_conn.libxml2 = __import__('libxml2')
 | 
			
		||||
        connection.libvirt = __import__('libvirt')
 | 
			
		||||
        connection.libxml2 = __import__('libxml2')
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def create_fake_libvirt_mock(self, **kwargs):
 | 
			
		||||
@@ -182,7 +183,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        class FakeLibvirtConnection(object):
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        # A fake libvirt_conn.IptablesFirewallDriver
 | 
			
		||||
        # A fake connection.IptablesFirewallDriver
 | 
			
		||||
        class FakeIptablesFirewallDriver(object):
 | 
			
		||||
 | 
			
		||||
            def __init__(self, **kwargs):
 | 
			
		||||
@@ -198,11 +199,11 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        for key, val in kwargs.items():
 | 
			
		||||
            fake.__setattr__(key, val)
 | 
			
		||||
 | 
			
		||||
        # Inevitable mocks for libvirt_conn.LibvirtConnection
 | 
			
		||||
        self.mox.StubOutWithMock(libvirt_conn.utils, 'import_class')
 | 
			
		||||
        libvirt_conn.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip)
 | 
			
		||||
        self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection, '_conn')
 | 
			
		||||
        libvirt_conn.LibvirtConnection._conn = fake
 | 
			
		||||
        # Inevitable mocks for connection.LibvirtConnection
 | 
			
		||||
        self.mox.StubOutWithMock(connection.utils, 'import_class')
 | 
			
		||||
        connection.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip)
 | 
			
		||||
        self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
 | 
			
		||||
        connection.LibvirtConnection._conn = fake
 | 
			
		||||
 | 
			
		||||
    def create_service(self, **kwargs):
 | 
			
		||||
        service_ref = {'host': kwargs.get('host', 'dummy'),
 | 
			
		||||
@@ -214,7 +215,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        return db.service_create(context.get_admin_context(), service_ref)
 | 
			
		||||
 | 
			
		||||
    def test_preparing_xml_info(self):
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(True)
 | 
			
		||||
        conn = connection.LibvirtConnection(True)
 | 
			
		||||
        instance_ref = db.instance_create(self.context, self.test_instance)
 | 
			
		||||
 | 
			
		||||
        result = conn._prepare_xml_info(instance_ref, False)
 | 
			
		||||
@@ -229,7 +230,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        self.assertTrue(len(result['nics']) == 2)
 | 
			
		||||
 | 
			
		||||
    def test_get_nic_for_xml_v4(self):
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(True)
 | 
			
		||||
        conn = connection.LibvirtConnection(True)
 | 
			
		||||
        network, mapping = _create_network_info()[0]
 | 
			
		||||
        self.flags(use_ipv6=False)
 | 
			
		||||
        params = conn._get_nic_for_xml(network, mapping)['extra_params']
 | 
			
		||||
@@ -237,7 +238,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        self.assertTrue(params.find('PROJMASKV6') == -1)
 | 
			
		||||
 | 
			
		||||
    def test_get_nic_for_xml_v6(self):
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(True)
 | 
			
		||||
        conn = connection.LibvirtConnection(True)
 | 
			
		||||
        network, mapping = _create_network_info()[0]
 | 
			
		||||
        self.flags(use_ipv6=True)
 | 
			
		||||
        params = conn._get_nic_for_xml(network, mapping)['extra_params']
 | 
			
		||||
@@ -282,7 +283,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
    def test_multi_nic(self):
 | 
			
		||||
        instance_data = dict(self.test_instance)
 | 
			
		||||
        network_info = _create_network_info(2)
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(True)
 | 
			
		||||
        conn = connection.LibvirtConnection(True)
 | 
			
		||||
        instance_ref = db.instance_create(self.context, instance_data)
 | 
			
		||||
        xml = conn.to_xml(instance_ref, False, network_info)
 | 
			
		||||
        tree = xml_to_tree(xml)
 | 
			
		||||
@@ -313,7 +314,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
                                  'instance_id': instance_ref['id']})
 | 
			
		||||
 | 
			
		||||
        self.flags(libvirt_type='lxc')
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(True)
 | 
			
		||||
        conn = connection.LibvirtConnection(True)
 | 
			
		||||
 | 
			
		||||
        uri = conn.get_uri()
 | 
			
		||||
        self.assertEquals(uri, 'lxc:///')
 | 
			
		||||
@@ -419,7 +420,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
        for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
 | 
			
		||||
            FLAGS.libvirt_type = libvirt_type
 | 
			
		||||
            conn = libvirt_conn.LibvirtConnection(True)
 | 
			
		||||
            conn = connection.LibvirtConnection(True)
 | 
			
		||||
 | 
			
		||||
            uri = conn.get_uri()
 | 
			
		||||
            self.assertEquals(uri, expected_uri)
 | 
			
		||||
@@ -446,7 +447,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        FLAGS.libvirt_uri = testuri
 | 
			
		||||
        for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
 | 
			
		||||
            FLAGS.libvirt_type = libvirt_type
 | 
			
		||||
            conn = libvirt_conn.LibvirtConnection(True)
 | 
			
		||||
            conn = connection.LibvirtConnection(True)
 | 
			
		||||
            uri = conn.get_uri()
 | 
			
		||||
            self.assertEquals(uri, testuri)
 | 
			
		||||
        db.instance_destroy(user_context, instance_ref['id'])
 | 
			
		||||
@@ -470,13 +471,13 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        self.create_fake_libvirt_mock(getVersion=getVersion,
 | 
			
		||||
                                      getType=getType,
 | 
			
		||||
                                      listDomainsID=listDomainsID)
 | 
			
		||||
        self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection,
 | 
			
		||||
        self.mox.StubOutWithMock(connection.LibvirtConnection,
 | 
			
		||||
                                 'get_cpu_info')
 | 
			
		||||
        libvirt_conn.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo')
 | 
			
		||||
        connection.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo')
 | 
			
		||||
 | 
			
		||||
        # Start test
 | 
			
		||||
        self.mox.ReplayAll()
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(False)
 | 
			
		||||
        conn = connection.LibvirtConnection(False)
 | 
			
		||||
        conn.update_available_resource(self.context, 'dummy')
 | 
			
		||||
        service_ref = db.service_get(self.context, service_ref['id'])
 | 
			
		||||
        compute_node = service_ref['compute_node'][0]
 | 
			
		||||
@@ -510,7 +511,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        self.create_fake_libvirt_mock()
 | 
			
		||||
 | 
			
		||||
        self.mox.ReplayAll()
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(False)
 | 
			
		||||
        conn = connection.LibvirtConnection(False)
 | 
			
		||||
        self.assertRaises(exception.ComputeServiceUnavailable,
 | 
			
		||||
                          conn.update_available_resource,
 | 
			
		||||
                          self.context, 'dummy')
 | 
			
		||||
@@ -545,7 +546,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        # Start test
 | 
			
		||||
        self.mox.ReplayAll()
 | 
			
		||||
        try:
 | 
			
		||||
            conn = libvirt_conn.LibvirtConnection(False)
 | 
			
		||||
            conn = connection.LibvirtConnection(False)
 | 
			
		||||
            conn.firewall_driver.setattr('setup_basic_filtering', fake_none)
 | 
			
		||||
            conn.firewall_driver.setattr('prepare_instance_filter', fake_none)
 | 
			
		||||
            conn.firewall_driver.setattr('instance_filter_exists', fake_none)
 | 
			
		||||
@@ -594,7 +595,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
        # Start test
 | 
			
		||||
        self.mox.ReplayAll()
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(False)
 | 
			
		||||
        conn = connection.LibvirtConnection(False)
 | 
			
		||||
        self.assertRaises(libvirt.libvirtError,
 | 
			
		||||
                      conn._live_migration,
 | 
			
		||||
                      self.context, instance_ref, 'dest', '',
 | 
			
		||||
@@ -623,7 +624,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
        # Start test
 | 
			
		||||
        self.mox.ReplayAll()
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(False)
 | 
			
		||||
        conn = connection.LibvirtConnection(False)
 | 
			
		||||
        conn.firewall_driver.setattr('setup_basic_filtering', fake_none)
 | 
			
		||||
        conn.firewall_driver.setattr('prepare_instance_filter', fake_none)
 | 
			
		||||
 | 
			
		||||
@@ -647,7 +648,7 @@ class LibvirtConnTestCase(test.TestCase):
 | 
			
		||||
        self.assertTrue(count)
 | 
			
		||||
 | 
			
		||||
    def test_get_host_ip_addr(self):
 | 
			
		||||
        conn = libvirt_conn.LibvirtConnection(False)
 | 
			
		||||
        conn = connection.LibvirtConnection(False)
 | 
			
		||||
        ip = conn.get_host_ip_addr()
 | 
			
		||||
        self.assertEquals(ip, FLAGS.my_ip)
 | 
			
		||||
 | 
			
		||||
@@ -671,7 +672,7 @@ class IptablesFirewallTestCase(test.TestCase):
 | 
			
		||||
        class FakeLibvirtConnection(object):
 | 
			
		||||
            pass
 | 
			
		||||
        self.fake_libvirt_connection = FakeLibvirtConnection()
 | 
			
		||||
        self.fw = libvirt_conn.IptablesFirewallDriver(
 | 
			
		||||
        self.fw = firewall.IptablesFirewallDriver(
 | 
			
		||||
                      get_connection=lambda: self.fake_libvirt_connection)
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
@@ -895,7 +896,7 @@ class NWFilterTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
        self.fake_libvirt_connection = Mock()
 | 
			
		||||
 | 
			
		||||
        self.fw = libvirt_conn.NWFilterFirewall(
 | 
			
		||||
        self.fw = firewall.NWFilterFirewall(
 | 
			
		||||
                                         lambda: self.fake_libvirt_connection)
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
@@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RpcTestCase(test.TestCase):
 | 
			
		||||
    """Test cases for rpc"""
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super(RpcTestCase, self).setUp()
 | 
			
		||||
        self.conn = rpc.Connection.instance(True)
 | 
			
		||||
@@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase):
 | 
			
		||||
        self.context = context.get_admin_context()
 | 
			
		||||
 | 
			
		||||
    def test_call_succeed(self):
 | 
			
		||||
        """Get a value through rpc call"""
 | 
			
		||||
        value = 42
 | 
			
		||||
        result = rpc.call(self.context, 'test', {"method": "echo",
 | 
			
		||||
                                                 "args": {"value": value}})
 | 
			
		||||
        self.assertEqual(value, result)
 | 
			
		||||
 | 
			
		||||
    def test_call_succeed_despite_multiple_returns(self):
 | 
			
		||||
        value = 42
 | 
			
		||||
        result = rpc.call(self.context, 'test', {"method": "echo_three_times",
 | 
			
		||||
                                                 "args": {"value": value}})
 | 
			
		||||
        self.assertEqual(value + 2, result)
 | 
			
		||||
 | 
			
		||||
    def test_call_succeed_despite_multiple_returns_yield(self):
 | 
			
		||||
        value = 42
 | 
			
		||||
        result = rpc.call(self.context, 'test',
 | 
			
		||||
                          {"method": "echo_three_times_yield",
 | 
			
		||||
                           "args": {"value": value}})
 | 
			
		||||
        self.assertEqual(value + 2, result)
 | 
			
		||||
 | 
			
		||||
    def test_multicall_succeed_once(self):
 | 
			
		||||
        value = 42
 | 
			
		||||
        result = rpc.multicall(self.context,
 | 
			
		||||
                              'test',
 | 
			
		||||
                              {"method": "echo",
 | 
			
		||||
                               "args": {"value": value}})
 | 
			
		||||
        for i, x in enumerate(result):
 | 
			
		||||
            if i > 0:
 | 
			
		||||
                self.fail('should only receive one response')
 | 
			
		||||
            self.assertEqual(value + i, x)
 | 
			
		||||
 | 
			
		||||
    def test_multicall_succeed_three_times(self):
 | 
			
		||||
        value = 42
 | 
			
		||||
        result = rpc.multicall(self.context,
 | 
			
		||||
                              'test',
 | 
			
		||||
                              {"method": "echo_three_times",
 | 
			
		||||
                               "args": {"value": value}})
 | 
			
		||||
        for i, x in enumerate(result):
 | 
			
		||||
            self.assertEqual(value + i, x)
 | 
			
		||||
 | 
			
		||||
    def test_multicall_succeed_three_times_yield(self):
 | 
			
		||||
        value = 42
 | 
			
		||||
        result = rpc.multicall(self.context,
 | 
			
		||||
                              'test',
 | 
			
		||||
                              {"method": "echo_three_times_yield",
 | 
			
		||||
                               "args": {"value": value}})
 | 
			
		||||
        for i, x in enumerate(result):
 | 
			
		||||
            self.assertEqual(value + i, x)
 | 
			
		||||
 | 
			
		||||
    def test_context_passed(self):
 | 
			
		||||
        """Makes sure a context is passed through rpc call"""
 | 
			
		||||
        """Makes sure a context is passed through rpc call."""
 | 
			
		||||
        value = 42
 | 
			
		||||
        result = rpc.call(self.context,
 | 
			
		||||
                          'test', {"method": "context",
 | 
			
		||||
@@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase):
 | 
			
		||||
        self.assertEqual(self.context.to_dict(), result)
 | 
			
		||||
 | 
			
		||||
    def test_call_exception(self):
 | 
			
		||||
        """Test that exception gets passed back properly
 | 
			
		||||
        """Test that exception gets passed back properly.
 | 
			
		||||
 | 
			
		||||
        rpc.call returns a RemoteError object.  The value of the
 | 
			
		||||
        exception is converted to a string, so we convert it back
 | 
			
		||||
        to an int in the test.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        value = 42
 | 
			
		||||
        self.assertRaises(rpc.RemoteError,
 | 
			
		||||
@@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase):
 | 
			
		||||
            self.assertEqual(int(exc.value), value)
 | 
			
		||||
 | 
			
		||||
    def test_nested_calls(self):
 | 
			
		||||
        """Test that we can do an rpc.call inside another call"""
 | 
			
		||||
        """Test that we can do an rpc.call inside another call."""
 | 
			
		||||
        class Nested(object):
 | 
			
		||||
            @staticmethod
 | 
			
		||||
            def echo(context, queue, value):
 | 
			
		||||
@@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase):
 | 
			
		||||
                                              "value": value}})
 | 
			
		||||
        self.assertEqual(value, result)
 | 
			
		||||
 | 
			
		||||
    def test_connectionpool_single(self):
 | 
			
		||||
        """Test that ConnectionPool recycles a single connection."""
 | 
			
		||||
        conn1 = rpc.ConnectionPool.get()
 | 
			
		||||
        rpc.ConnectionPool.put(conn1)
 | 
			
		||||
        conn2 = rpc.ConnectionPool.get()
 | 
			
		||||
        rpc.ConnectionPool.put(conn2)
 | 
			
		||||
        self.assertEqual(conn1, conn2)
 | 
			
		||||
 | 
			
		||||
    def test_connectionpool_double(self):
 | 
			
		||||
        """Test that ConnectionPool returns and reuses separate connections.
 | 
			
		||||
 | 
			
		||||
        When called consecutively we should get separate connections and upon
 | 
			
		||||
        returning them those connections should be reused for future calls
 | 
			
		||||
        before generating a new connection.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        conn1 = rpc.ConnectionPool.get()
 | 
			
		||||
        conn2 = rpc.ConnectionPool.get()
 | 
			
		||||
 | 
			
		||||
        self.assertNotEqual(conn1, conn2)
 | 
			
		||||
        rpc.ConnectionPool.put(conn1)
 | 
			
		||||
        rpc.ConnectionPool.put(conn2)
 | 
			
		||||
 | 
			
		||||
        conn3 = rpc.ConnectionPool.get()
 | 
			
		||||
        conn4 = rpc.ConnectionPool.get()
 | 
			
		||||
        self.assertEqual(conn1, conn3)
 | 
			
		||||
        self.assertEqual(conn2, conn4)
 | 
			
		||||
 | 
			
		||||
    def test_connectionpool_limit(self):
 | 
			
		||||
        """Test connection pool limit and connection uniqueness."""
 | 
			
		||||
        max_size = FLAGS.rpc_conn_pool_size
 | 
			
		||||
        conns = []
 | 
			
		||||
 | 
			
		||||
        for i in xrange(max_size):
 | 
			
		||||
            conns.append(rpc.ConnectionPool.get())
 | 
			
		||||
 | 
			
		||||
        self.assertFalse(rpc.ConnectionPool.free_items)
 | 
			
		||||
        self.assertEqual(rpc.ConnectionPool.current_size,
 | 
			
		||||
                rpc.ConnectionPool.max_size)
 | 
			
		||||
        self.assertEqual(len(set(conns)), max_size)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestReceiver(object):
 | 
			
		||||
    """Simple Proxy class so the consumer has methods to call
 | 
			
		||||
    """Simple Proxy class so the consumer has methods to call.
 | 
			
		||||
 | 
			
		||||
    Uses static methods because we aren't actually storing any state"""
 | 
			
		||||
    Uses static methods because we aren't actually storing any state.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def echo(context, value):
 | 
			
		||||
        """Simply returns whatever value is sent in"""
 | 
			
		||||
        """Simply returns whatever value is sent in."""
 | 
			
		||||
        LOG.debug(_("Received %s"), value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def context(context, value):
 | 
			
		||||
        """Returns dictionary version of context"""
 | 
			
		||||
        """Returns dictionary version of context."""
 | 
			
		||||
        LOG.debug(_("Received %s"), context)
 | 
			
		||||
        return context.to_dict()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def echo_three_times(context, value):
 | 
			
		||||
        context.reply(value)
 | 
			
		||||
        context.reply(value + 1)
 | 
			
		||||
        context.reply(value + 2)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def echo_three_times_yield(context, value):
 | 
			
		||||
        yield value
 | 
			
		||||
        yield value + 1
 | 
			
		||||
        yield value + 2
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def fail(context, value):
 | 
			
		||||
        """Raises an exception with the value sent in"""
 | 
			
		||||
        """Raises an exception with the value sent in."""
 | 
			
		||||
        raise Exception(value)
 | 
			
		||||
 
 | 
			
		||||
@@ -912,7 +912,8 @@ class SimpleDriverTestCase(test.TestCase):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeZone(object):
 | 
			
		||||
    def __init__(self, api_url, username, password):
 | 
			
		||||
    def __init__(self, id, api_url, username, password):
 | 
			
		||||
        self.id = id
 | 
			
		||||
        self.api_url = api_url
 | 
			
		||||
        self.username = username
 | 
			
		||||
        self.password = password
 | 
			
		||||
@@ -920,7 +921,7 @@ class FakeZone(object):
 | 
			
		||||
 | 
			
		||||
def zone_get_all(context):
 | 
			
		||||
    return [
 | 
			
		||||
                FakeZone('http://example.com', 'bob', 'xxx'),
 | 
			
		||||
                FakeZone(1, 'http://example.com', 'bob', 'xxx'),
 | 
			
		||||
           ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -1037,7 +1038,7 @@ class FakeNovaClient(object):
 | 
			
		||||
 | 
			
		||||
class DynamicNovaClientTest(test.TestCase):
 | 
			
		||||
    def test_issue_novaclient_command_found(self):
 | 
			
		||||
        zone = FakeZone('http://example.com', 'bob', 'xxx')
 | 
			
		||||
        zone = FakeZone(1, 'http://example.com', 'bob', 'xxx')
 | 
			
		||||
        self.assertEquals(api._issue_novaclient_command(
 | 
			
		||||
                    FakeNovaClient(FakeServerCollection()),
 | 
			
		||||
                    zone, "servers", "get", 100).a, 10)
 | 
			
		||||
@@ -1051,7 +1052,7 @@ class DynamicNovaClientTest(test.TestCase):
 | 
			
		||||
                    zone, "servers", "pause", 100), None)
 | 
			
		||||
 | 
			
		||||
    def test_issue_novaclient_command_not_found(self):
 | 
			
		||||
        zone = FakeZone('http://example.com', 'bob', 'xxx')
 | 
			
		||||
        zone = FakeZone(1, 'http://example.com', 'bob', 'xxx')
 | 
			
		||||
        self.assertEquals(api._issue_novaclient_command(
 | 
			
		||||
                    FakeNovaClient(FakeEmptyServerCollection()),
 | 
			
		||||
                    zone, "servers", "get", 100), None)
 | 
			
		||||
@@ -1063,3 +1064,55 @@ class DynamicNovaClientTest(test.TestCase):
 | 
			
		||||
        self.assertEquals(api._issue_novaclient_command(
 | 
			
		||||
                    FakeNovaClient(FakeEmptyServerCollection()),
 | 
			
		||||
                    zone, "servers", "any", "name"), None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeZonesProxy(object):
 | 
			
		||||
    def do_something(*args, **kwargs):
 | 
			
		||||
        return 42
 | 
			
		||||
 | 
			
		||||
    def raises_exception(*args, **kwargs):
 | 
			
		||||
        raise Exception('testing')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeNovaClientOpenStack(object):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        self.zones = FakeZonesProxy()
 | 
			
		||||
 | 
			
		||||
    def authenticate(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CallZoneMethodTest(test.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super(CallZoneMethodTest, self).setUp()
 | 
			
		||||
        self.stubs = stubout.StubOutForTesting()
 | 
			
		||||
        self.stubs.Set(db, 'zone_get_all', zone_get_all)
 | 
			
		||||
        self.stubs.Set(novaclient, 'OpenStack', FakeNovaClientOpenStack)
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        self.stubs.UnsetAll()
 | 
			
		||||
        super(CallZoneMethodTest, self).tearDown()
 | 
			
		||||
 | 
			
		||||
    def test_call_zone_method(self):
 | 
			
		||||
        context = {}
 | 
			
		||||
        method = 'do_something'
 | 
			
		||||
        results = api.call_zone_method(context, method)
 | 
			
		||||
        expected = [(1, 42)]
 | 
			
		||||
        self.assertEqual(expected, results)
 | 
			
		||||
 | 
			
		||||
    def test_call_zone_method_not_present(self):
 | 
			
		||||
        context = {}
 | 
			
		||||
        method = 'not_present'
 | 
			
		||||
        self.assertRaises(AttributeError, api.call_zone_method,
 | 
			
		||||
                          context, method)
 | 
			
		||||
 | 
			
		||||
    def test_call_zone_method_generates_exception(self):
 | 
			
		||||
        context = {}
 | 
			
		||||
        method = 'raises_exception'
 | 
			
		||||
        results = api.call_zone_method(context, method)
 | 
			
		||||
 | 
			
		||||
        # FIXME(sirp): for now the _error_trap code is catching errors and
 | 
			
		||||
        # converting them to a ("ERROR", "string") tuples. The code (and this
 | 
			
		||||
        # test) should eventually handle real exceptions.
 | 
			
		||||
        expected = [(1, ('ERROR', 'testing'))]
 | 
			
		||||
        self.assertEqual(expected, results)
 | 
			
		||||
 
 | 
			
		||||
@@ -395,6 +395,29 @@ class XenAPIVMTestCase(test.TestCase):
 | 
			
		||||
                         os_type="linux")
 | 
			
		||||
        self.check_vm_params_for_linux()
 | 
			
		||||
 | 
			
		||||
    def test_spawn_vhd_glance_swapdisk(self):
 | 
			
		||||
        # Change the default host_call_plugin to one that'll return
 | 
			
		||||
        # a swap disk
 | 
			
		||||
        orig_func = stubs.FakeSessionForVMTests.host_call_plugin
 | 
			
		||||
 | 
			
		||||
        stubs.FakeSessionForVMTests.host_call_plugin = \
 | 
			
		||||
                stubs.FakeSessionForVMTests.host_call_plugin_swap
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            # We'll steal the above glance linux test
 | 
			
		||||
            self.test_spawn_vhd_glance_linux()
 | 
			
		||||
        finally:
 | 
			
		||||
            # Make sure to put this back
 | 
			
		||||
            stubs.FakeSessionForVMTests.host_call_plugin = orig_func
 | 
			
		||||
 | 
			
		||||
        # We should have 2 VBDs.
 | 
			
		||||
        self.assertEqual(len(self.vm['VBDs']), 2)
 | 
			
		||||
        # Now test that we have 1.
 | 
			
		||||
        self.tearDown()
 | 
			
		||||
        self.setUp()
 | 
			
		||||
        self.test_spawn_vhd_glance_linux()
 | 
			
		||||
        self.assertEqual(len(self.vm['VBDs']), 1)
 | 
			
		||||
 | 
			
		||||
    def test_spawn_vhd_glance_windows(self):
 | 
			
		||||
        FLAGS.xenapi_image_service = 'glance'
 | 
			
		||||
        self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None,
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										119
									
								
								nova/tests/test_zone_aware_scheduler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								nova/tests/test_zone_aware_scheduler.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
			
		||||
# Copyright 2011 OpenStack LLC.
 | 
			
		||||
# All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
			
		||||
#    not use this file except in compliance with the License. You may obtain
 | 
			
		||||
#    a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#         http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
#    Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
#    License for the specific language governing permissions and limitations
 | 
			
		||||
#    under the License.
 | 
			
		||||
"""
 | 
			
		||||
Tests For Zone Aware Scheduler.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from nova import test
 | 
			
		||||
from nova.scheduler import driver
 | 
			
		||||
from nova.scheduler import zone_aware_scheduler
 | 
			
		||||
from nova.scheduler import zone_manager
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeZoneAwareScheduler(zone_aware_scheduler.ZoneAwareScheduler):
 | 
			
		||||
    def filter_hosts(self, num, specs):
 | 
			
		||||
        # NOTE(sirp): this is returning [(hostname, services)]
 | 
			
		||||
        return self.zone_manager.service_states.items()
 | 
			
		||||
 | 
			
		||||
    def weigh_hosts(self, num, specs, hosts):
 | 
			
		||||
        fake_weight = 99
 | 
			
		||||
        weighted = []
 | 
			
		||||
        for hostname, caps in hosts:
 | 
			
		||||
            weighted.append(dict(weight=fake_weight, name=hostname))
 | 
			
		||||
        return weighted
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeZoneManager(zone_manager.ZoneManager):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.service_states = {
 | 
			
		||||
                        'host1': {
 | 
			
		||||
                            'compute': {'ram': 1000}
 | 
			
		||||
                         },
 | 
			
		||||
                         'host2': {
 | 
			
		||||
                            'compute': {'ram': 2000}
 | 
			
		||||
                         },
 | 
			
		||||
                         'host3': {
 | 
			
		||||
                            'compute': {'ram': 3000}
 | 
			
		||||
                         }
 | 
			
		||||
                     }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FakeEmptyZoneManager(zone_manager.ZoneManager):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.service_states = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fake_empty_call_zone_method(context, method, specs):
 | 
			
		||||
    return []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fake_call_zone_method(context, method, specs):
 | 
			
		||||
    return [
 | 
			
		||||
        ('zone1', [
 | 
			
		||||
            dict(weight=1, blob='AAAAAAA'),
 | 
			
		||||
            dict(weight=111, blob='BBBBBBB'),
 | 
			
		||||
            dict(weight=112, blob='CCCCCCC'),
 | 
			
		||||
            dict(weight=113, blob='DDDDDDD'),
 | 
			
		||||
        ]),
 | 
			
		||||
        ('zone2', [
 | 
			
		||||
            dict(weight=120, blob='EEEEEEE'),
 | 
			
		||||
            dict(weight=2, blob='FFFFFFF'),
 | 
			
		||||
            dict(weight=122, blob='GGGGGGG'),
 | 
			
		||||
            dict(weight=123, blob='HHHHHHH'),
 | 
			
		||||
        ]),
 | 
			
		||||
        ('zone3', [
 | 
			
		||||
            dict(weight=130, blob='IIIIIII'),
 | 
			
		||||
            dict(weight=131, blob='JJJJJJJ'),
 | 
			
		||||
            dict(weight=132, blob='KKKKKKK'),
 | 
			
		||||
            dict(weight=3, blob='LLLLLLL'),
 | 
			
		||||
        ]),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ZoneAwareSchedulerTestCase(test.TestCase):
 | 
			
		||||
    """Test case for Zone Aware Scheduler."""
 | 
			
		||||
 | 
			
		||||
    def test_zone_aware_scheduler(self):
 | 
			
		||||
        """
 | 
			
		||||
        Create a nested set of FakeZones, ensure that a select call returns the
 | 
			
		||||
        appropriate build plan.
 | 
			
		||||
        """
 | 
			
		||||
        sched = FakeZoneAwareScheduler()
 | 
			
		||||
        self.stubs.Set(sched, '_call_zone_method', fake_call_zone_method)
 | 
			
		||||
 | 
			
		||||
        zm = FakeZoneManager()
 | 
			
		||||
        sched.set_zone_manager(zm)
 | 
			
		||||
 | 
			
		||||
        fake_context = {}
 | 
			
		||||
        build_plan = sched.select(fake_context, {})
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(15, len(build_plan))
 | 
			
		||||
 | 
			
		||||
        hostnames = [plan_item['name']
 | 
			
		||||
                     for plan_item in build_plan if 'name' in plan_item]
 | 
			
		||||
        self.assertEqual(3, len(hostnames))
 | 
			
		||||
 | 
			
		||||
    def test_empty_zone_aware_scheduler(self):
 | 
			
		||||
        """
 | 
			
		||||
        Ensure empty hosts & child_zones result in NoValidHosts exception.
 | 
			
		||||
        """
 | 
			
		||||
        sched = FakeZoneAwareScheduler()
 | 
			
		||||
        self.stubs.Set(sched, '_call_zone_method', fake_empty_call_zone_method)
 | 
			
		||||
 | 
			
		||||
        zm = FakeEmptyZoneManager()
 | 
			
		||||
        sched.set_zone_manager(zm)
 | 
			
		||||
 | 
			
		||||
        fake_context = {}
 | 
			
		||||
        self.assertRaises(driver.NoValidHost, sched.schedule, fake_context, {})
 | 
			
		||||
@@ -17,6 +17,7 @@
 | 
			
		||||
"""Stubouts, mocks and fixtures for the test suite"""
 | 
			
		||||
 | 
			
		||||
import eventlet
 | 
			
		||||
import json
 | 
			
		||||
from nova.virt import xenapi_conn
 | 
			
		||||
from nova.virt.xenapi import fake
 | 
			
		||||
from nova.virt.xenapi import volume_utils
 | 
			
		||||
@@ -37,7 +38,7 @@ def stubout_instance_snapshot(stubs):
 | 
			
		||||
                             sr_ref=sr_ref, sharable=False)
 | 
			
		||||
        vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref)
 | 
			
		||||
        vdi_uuid = vdi_rec['uuid']
 | 
			
		||||
        return vdi_uuid
 | 
			
		||||
        return [dict(vdi_type='os', vdi_uuid=vdi_uuid)]
 | 
			
		||||
 | 
			
		||||
    stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
 | 
			
		||||
 | 
			
		||||
@@ -132,11 +133,30 @@ class FakeSessionForVMTests(fake.SessionBase):
 | 
			
		||||
    def __init__(self, uri):
 | 
			
		||||
        super(FakeSessionForVMTests, self).__init__(uri)
 | 
			
		||||
 | 
			
		||||
    def host_call_plugin(self, _1, _2, _3, _4, _5):
 | 
			
		||||
    def host_call_plugin(self, _1, _2, plugin, method, _5):
 | 
			
		||||
        sr_ref = fake.get_all('SR')[0]
 | 
			
		||||
        vdi_ref = fake.create_vdi('', False, sr_ref, False)
 | 
			
		||||
        vdi_rec = fake.get_record('VDI', vdi_ref)
 | 
			
		||||
        return '<string>%s</string>' % vdi_rec['uuid']
 | 
			
		||||
        if plugin == "glance" and method == "download_vhd":
 | 
			
		||||
            ret_str = json.dumps([dict(vdi_type='os',
 | 
			
		||||
                    vdi_uuid=vdi_rec['uuid'])])
 | 
			
		||||
        else:
 | 
			
		||||
            ret_str = vdi_rec['uuid']
 | 
			
		||||
        return '<string>%s</string>' % ret_str
 | 
			
		||||
 | 
			
		||||
    def host_call_plugin_swap(self, _1, _2, plugin, method, _5):
 | 
			
		||||
        sr_ref = fake.get_all('SR')[0]
 | 
			
		||||
        vdi_ref = fake.create_vdi('', False, sr_ref, False)
 | 
			
		||||
        vdi_rec = fake.get_record('VDI', vdi_ref)
 | 
			
		||||
        if plugin == "glance" and method == "download_vhd":
 | 
			
		||||
            swap_vdi_ref = fake.create_vdi('', False, sr_ref, False)
 | 
			
		||||
            swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref)
 | 
			
		||||
            ret_str = json.dumps(
 | 
			
		||||
                    [dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']),
 | 
			
		||||
                    dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])])
 | 
			
		||||
        else:
 | 
			
		||||
            ret_str = vdi_rec['uuid']
 | 
			
		||||
        return '<string>%s</string>' % ret_str
 | 
			
		||||
 | 
			
		||||
    def VM_start(self, _1, ref, _2, _3):
 | 
			
		||||
        vm = fake.get_record('VM', ref)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user