tests working after merge-3 update

This commit is contained in:
Sandy Walsh
2011-06-09 11:52:53 -07:00
36 changed files with 1610 additions and 1381 deletions

View File

@@ -30,6 +30,7 @@ Gabe Westmaas <gabe.westmaas@rackspace.com>
Hisaharu Ishii <ishii.hisaharu@lab.ntt.co.jp> Hisaharu Ishii <ishii.hisaharu@lab.ntt.co.jp>
Hisaki Ohara <hisaki.ohara@intel.com> Hisaki Ohara <hisaki.ohara@intel.com>
Ilya Alekseyev <ialekseev@griddynamics.com> Ilya Alekseyev <ialekseev@griddynamics.com>
Isaku Yamahata <yamahata@valinux.co.jp>
Jason Koelker <jason@koelker.net> Jason Koelker <jason@koelker.net>
Jay Pipes <jaypipes@gmail.com> Jay Pipes <jaypipes@gmail.com>
Jesse Andrews <anotherjesse@gmail.com> Jesse Andrews <anotherjesse@gmail.com>
@@ -58,6 +59,7 @@ Mark Washenberger <mark.washenberger@rackspace.com>
Masanori Itoh <itoumsn@nttdata.co.jp> Masanori Itoh <itoumsn@nttdata.co.jp>
Matt Dietz <matt.dietz@rackspace.com> Matt Dietz <matt.dietz@rackspace.com>
Michael Gundlach <michael.gundlach@rackspace.com> Michael Gundlach <michael.gundlach@rackspace.com>
Mike Scherbakov <mihgen@gmail.com>
Monsyne Dragon <mdragon@rackspace.com> Monsyne Dragon <mdragon@rackspace.com>
Monty Taylor <mordred@inaugust.com> Monty Taylor <mordred@inaugust.com>
MORITA Kazutaka <morita.kazutaka@gmail.com> MORITA Kazutaka <morita.kazutaka@gmail.com>
@@ -83,6 +85,7 @@ Trey Morris <trey.morris@rackspace.com>
Tushar Patil <tushar.vitthal.patil@gmail.com> Tushar Patil <tushar.vitthal.patil@gmail.com>
Vasiliy Shlykov <vash@vasiliyshlykov.org> Vasiliy Shlykov <vash@vasiliyshlykov.org>
Vishvananda Ishaya <vishvananda@gmail.com> Vishvananda Ishaya <vishvananda@gmail.com>
Vivek Y S <vivek.ys@gmail.com>
William Wolf <throughnothing@gmail.com> William Wolf <throughnothing@gmail.com>
Yoshiaki Tamura <yoshi@midokura.jp> Yoshiaki Tamura <yoshi@midokura.jp>
Youcef Laribi <Youcef.Laribi@eu.citrix.com> Youcef Laribi <Youcef.Laribi@eu.citrix.com>

View File

@@ -53,7 +53,6 @@
CLI interface for nova management. CLI interface for nova management.
""" """
import datetime
import gettext import gettext
import glob import glob
import json import json
@@ -78,6 +77,7 @@ from nova import crypto
from nova import db from nova import db
from nova import exception from nova import exception
from nova import flags from nova import flags
from nova import image
from nova import log as logging from nova import log as logging
from nova import quota from nova import quota
from nova import rpc from nova import rpc
@@ -96,6 +96,7 @@ flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('vlan_start', 'nova.network.manager') flags.DECLARE('vlan_start', 'nova.network.manager')
flags.DECLARE('vpn_start', 'nova.network.manager') flags.DECLARE('vpn_start', 'nova.network.manager')
flags.DECLARE('fixed_range_v6', 'nova.network.manager') flags.DECLARE('fixed_range_v6', 'nova.network.manager')
flags.DECLARE('gateway_v6', 'nova.network.manager')
flags.DECLARE('images_path', 'nova.image.local') flags.DECLARE('images_path', 'nova.image.local')
flags.DECLARE('libvirt_type', 'nova.virt.libvirt.connection') flags.DECLARE('libvirt_type', 'nova.virt.libvirt.connection')
flags.DEFINE_flag(flags.HelpFlag()) flags.DEFINE_flag(flags.HelpFlag())
@@ -536,7 +537,7 @@ class FloatingIpCommands(object):
for floating_ip in floating_ips: for floating_ip in floating_ips:
instance = None instance = None
if floating_ip['fixed_ip']: if floating_ip['fixed_ip']:
instance = floating_ip['fixed_ip']['instance']['ec2_id'] instance = floating_ip['fixed_ip']['instance']['hostname']
print "%s\t%s\t%s" % (floating_ip['host'], print "%s\t%s\t%s" % (floating_ip['host'],
floating_ip['address'], floating_ip['address'],
instance) instance)
@@ -545,13 +546,10 @@ class FloatingIpCommands(object):
class NetworkCommands(object): class NetworkCommands(object):
"""Class for managing networks.""" """Class for managing networks."""
def create(self, fixed_range=None, num_networks=None, def create(self, fixed_range=None, num_networks=None, network_size=None,
network_size=None, vlan_start=None, vlan_start=None, vpn_start=None, fixed_range_v6=None,
vpn_start=None, fixed_range_v6=None, label='public'): gateway_v6=None, label='public'):
"""Creates fixed ips for host by range """Creates fixed ips for host by range"""
arguments: fixed_range=FLAG, [num_networks=FLAG],
[network_size=FLAG], [vlan_start=FLAG],
[vpn_start=FLAG], [fixed_range_v6=FLAG]"""
if not fixed_range: if not fixed_range:
msg = _('Fixed range in the form of 10.0.0.0/8 is ' msg = _('Fixed range in the form of 10.0.0.0/8 is '
'required to create networks.') 'required to create networks.')
@@ -567,6 +565,8 @@ class NetworkCommands(object):
vpn_start = FLAGS.vpn_start vpn_start = FLAGS.vpn_start
if not fixed_range_v6: if not fixed_range_v6:
fixed_range_v6 = FLAGS.fixed_range_v6 fixed_range_v6 = FLAGS.fixed_range_v6
if not gateway_v6:
gateway_v6 = FLAGS.gateway_v6
net_manager = utils.import_object(FLAGS.network_manager) net_manager = utils.import_object(FLAGS.network_manager)
try: try:
net_manager.create_networks(context.get_admin_context(), net_manager.create_networks(context.get_admin_context(),
@@ -576,6 +576,7 @@ class NetworkCommands(object):
vlan_start=int(vlan_start), vlan_start=int(vlan_start),
vpn_start=int(vpn_start), vpn_start=int(vpn_start),
cidr_v6=fixed_range_v6, cidr_v6=fixed_range_v6,
gateway_v6=gateway_v6,
label=label) label=label)
except ValueError, e: except ValueError, e:
print e print e
@@ -689,7 +690,7 @@ class ServiceCommands(object):
"""Show a list of all running services. Filter by host & service name. """Show a list of all running services. Filter by host & service name.
args: [host] [service]""" args: [host] [service]"""
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
now = datetime.datetime.utcnow() now = utils.utcnow()
services = db.service_get_all(ctxt) services = db.service_get_all(ctxt)
if host: if host:
services = [s for s in services if s['host'] == host] services = [s for s in services if s['host'] == host]
@@ -936,7 +937,7 @@ class ImageCommands(object):
"""Methods for dealing with a cloud in an odd state""" """Methods for dealing with a cloud in an odd state"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.image_service = utils.import_object(FLAGS.image_service) self.image_service = image.get_default_image_service()
def _register(self, container_format, disk_format, def _register(self, container_format, disk_format,
path, owner, name=None, is_public='T', path, owner, name=None, is_public='T',
@@ -1081,24 +1082,35 @@ class ImageCommands(object):
self._convert_images(machine_images) self._convert_images(machine_images)
class ConfigCommands(object):
"""Class for exposing the flags defined by flag_file(s)."""
def __init__(self):
pass
def list(self):
print FLAGS.FlagsIntoString()
CATEGORIES = [ CATEGORIES = [
('user', UserCommands),
('account', AccountCommands), ('account', AccountCommands),
('project', ProjectCommands), ('config', ConfigCommands),
('role', RoleCommands),
('shell', ShellCommands),
('vpn', VpnCommands),
('fixed', FixedIpCommands),
('floating', FloatingIpCommands),
('network', NetworkCommands),
('vm', VmCommands),
('service', ServiceCommands),
('db', DbCommands), ('db', DbCommands),
('volume', VolumeCommands), ('fixed', FixedIpCommands),
('flavor', InstanceTypeCommands),
('floating', FloatingIpCommands),
('instance_type', InstanceTypeCommands), ('instance_type', InstanceTypeCommands),
('image', ImageCommands), ('image', ImageCommands),
('flavor', InstanceTypeCommands), ('network', NetworkCommands),
('version', VersionCommands)] ('project', ProjectCommands),
('role', RoleCommands),
('service', ServiceCommands),
('shell', ShellCommands),
('user', UserCommands),
('version', VersionCommands),
('vm', VmCommands),
('volume', VolumeCommands),
('vpn', VpnCommands)]
def lazy_match(name, key_value_tuples): def lazy_match(name, key_value_tuples):

View File

@@ -24,6 +24,7 @@ other backends by creating another class that exposes the same
public methods. public methods.
""" """
import functools
import sys import sys
from nova import exception from nova import exception
@@ -68,6 +69,12 @@ flags.DEFINE_string('ldap_developer',
LOG = logging.getLogger("nova.ldapdriver") LOG = logging.getLogger("nova.ldapdriver")
if FLAGS.memcached_servers:
import memcache
else:
from nova import fakememcache as memcache
# TODO(vish): make an abstract base class with the same public methods # TODO(vish): make an abstract base class with the same public methods
# to define a set interface for AuthDrivers. I'm delaying # to define a set interface for AuthDrivers. I'm delaying
# creating this now because I'm expecting an auth refactor # creating this now because I'm expecting an auth refactor
@@ -85,6 +92,7 @@ def _clean(attr):
def sanitize(fn): def sanitize(fn):
"""Decorator to sanitize all args""" """Decorator to sanitize all args"""
@functools.wraps(fn)
def _wrapped(self, *args, **kwargs): def _wrapped(self, *args, **kwargs):
args = [_clean(x) for x in args] args = [_clean(x) for x in args]
kwargs = dict((k, _clean(v)) for (k, v) in kwargs) kwargs = dict((k, _clean(v)) for (k, v) in kwargs)
@@ -103,29 +111,56 @@ class LdapDriver(object):
isadmin_attribute = 'isNovaAdmin' isadmin_attribute = 'isNovaAdmin'
project_attribute = 'owner' project_attribute = 'owner'
project_objectclass = 'groupOfNames' project_objectclass = 'groupOfNames'
conn = None
mc = None
def __init__(self): def __init__(self):
"""Imports the LDAP module""" """Imports the LDAP module"""
self.ldap = __import__('ldap') self.ldap = __import__('ldap')
self.conn = None
if FLAGS.ldap_schema_version == 1: if FLAGS.ldap_schema_version == 1:
LdapDriver.project_pattern = '(objectclass=novaProject)' LdapDriver.project_pattern = '(objectclass=novaProject)'
LdapDriver.isadmin_attribute = 'isAdmin' LdapDriver.isadmin_attribute = 'isAdmin'
LdapDriver.project_attribute = 'projectManager' LdapDriver.project_attribute = 'projectManager'
LdapDriver.project_objectclass = 'novaProject' LdapDriver.project_objectclass = 'novaProject'
self.__cache = None
if LdapDriver.conn is None:
LdapDriver.conn = self.ldap.initialize(FLAGS.ldap_url)
LdapDriver.conn.simple_bind_s(FLAGS.ldap_user_dn,
FLAGS.ldap_password)
if LdapDriver.mc is None:
LdapDriver.mc = memcache.Client(FLAGS.memcached_servers, debug=0)
def __enter__(self): def __enter__(self):
"""Creates the connection to LDAP""" # TODO(yorik-sar): Should be per-request cache, not per-driver-request
self.conn = self.ldap.initialize(FLAGS.ldap_url) self.__cache = {}
self.conn.simple_bind_s(FLAGS.ldap_user_dn, FLAGS.ldap_password)
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
"""Destroys the connection to LDAP""" self.__cache = None
self.conn.unbind_s()
return False return False
def __local_cache(key_fmt):
"""Wrap function to cache it's result in self.__cache.
Works only with functions with one fixed argument.
"""
def do_wrap(fn):
@functools.wraps(fn)
def inner(self, arg, **kwargs):
cache_key = key_fmt % (arg,)
try:
res = self.__cache[cache_key]
LOG.debug('Local cache hit for %s by key %s' %
(fn.__name__, cache_key))
return res
except KeyError:
res = fn(self, arg, **kwargs)
self.__cache[cache_key] = res
return res
return inner
return do_wrap
@sanitize @sanitize
@__local_cache('uid_user-%s')
def get_user(self, uid): def get_user(self, uid):
"""Retrieve user by id""" """Retrieve user by id"""
attr = self.__get_ldap_user(uid) attr = self.__get_ldap_user(uid)
@@ -134,15 +169,31 @@ class LdapDriver(object):
@sanitize @sanitize
def get_user_from_access_key(self, access): def get_user_from_access_key(self, access):
"""Retrieve user by access key""" """Retrieve user by access key"""
cache_key = 'uak_dn_%s' % (access,)
user_dn = self.mc.get(cache_key)
if user_dn:
user = self.__to_user(
self.__find_object(user_dn, scope=self.ldap.SCOPE_BASE))
if user:
if user['access'] == access:
return user
else:
self.mc.set(cache_key, None)
query = '(accessKey=%s)' % access query = '(accessKey=%s)' % access
dn = FLAGS.ldap_user_subtree dn = FLAGS.ldap_user_subtree
return self.__to_user(self.__find_object(dn, query)) user_obj = self.__find_object(dn, query)
user = self.__to_user(user_obj)
if user:
self.mc.set(cache_key, user_obj['dn'][0])
return user
@sanitize @sanitize
@__local_cache('pid_project-%s')
def get_project(self, pid): def get_project(self, pid):
"""Retrieve project by id""" """Retrieve project by id"""
dn = self.__project_to_dn(pid) dn = self.__project_to_dn(pid, search=False)
attr = self.__find_object(dn, LdapDriver.project_pattern) attr = self.__find_object(dn, LdapDriver.project_pattern,
scope=self.ldap.SCOPE_BASE)
return self.__to_project(attr) return self.__to_project(attr)
@sanitize @sanitize
@@ -395,6 +446,7 @@ class LdapDriver(object):
"""Check if project exists""" """Check if project exists"""
return self.get_project(project_id) is not None return self.get_project(project_id) is not None
@__local_cache('uid_attrs-%s')
def __get_ldap_user(self, uid): def __get_ldap_user(self, uid):
"""Retrieve LDAP user entry by id""" """Retrieve LDAP user entry by id"""
dn = FLAGS.ldap_user_subtree dn = FLAGS.ldap_user_subtree
@@ -426,12 +478,20 @@ class LdapDriver(object):
if scope is None: if scope is None:
# One of the flags is 0! # One of the flags is 0!
scope = self.ldap.SCOPE_SUBTREE scope = self.ldap.SCOPE_SUBTREE
if query is None:
query = "(objectClass=*)"
try: try:
res = self.conn.search_s(dn, scope, query) res = self.conn.search_s(dn, scope, query)
except self.ldap.NO_SUCH_OBJECT: except self.ldap.NO_SUCH_OBJECT:
return [] return []
# Just return the attributes # Just return the attributes
return [attributes for dn, attributes in res] # FIXME(yorik-sar): Whole driver should be refactored to
# prevent this hack
res1 = []
for dn, attrs in res:
attrs['dn'] = [dn]
res1.append(attrs)
return res1
def __find_role_dns(self, tree): def __find_role_dns(self, tree):
"""Find dns of role objects in given tree""" """Find dns of role objects in given tree"""
@@ -564,6 +624,7 @@ class LdapDriver(object):
'description': attr.get('description', [None])[0], 'description': attr.get('description', [None])[0],
'member_ids': [self.__dn_to_uid(x) for x in member_dns]} 'member_ids': [self.__dn_to_uid(x) for x in member_dns]}
@__local_cache('uid_dn-%s')
def __uid_to_dn(self, uid, search=True): def __uid_to_dn(self, uid, search=True):
"""Convert uid to dn""" """Convert uid to dn"""
# By default return a generated DN # By default return a generated DN
@@ -576,6 +637,7 @@ class LdapDriver(object):
userdn = user[0] userdn = user[0]
return userdn return userdn
@__local_cache('pid_dn-%s')
def __project_to_dn(self, pid, search=True): def __project_to_dn(self, pid, search=True):
"""Convert pid to dn""" """Convert pid to dn"""
# By default return a generated DN # By default return a generated DN
@@ -603,16 +665,18 @@ class LdapDriver(object):
else: else:
return None return None
@__local_cache('dn_uid-%s')
def __dn_to_uid(self, dn): def __dn_to_uid(self, dn):
"""Convert user dn to uid""" """Convert user dn to uid"""
query = '(objectclass=novaUser)' query = '(objectclass=novaUser)'
user = self.__find_object(dn, query) user = self.__find_object(dn, query, scope=self.ldap.SCOPE_BASE)
return user[FLAGS.ldap_user_id_attribute][0] return user[FLAGS.ldap_user_id_attribute][0]
class FakeLdapDriver(LdapDriver): class FakeLdapDriver(LdapDriver):
"""Fake Ldap Auth driver""" """Fake Ldap Auth driver"""
def __init__(self): # pylint: disable=W0231 def __init__(self):
__import__('nova.auth.fakeldap') import nova.auth.fakeldap
self.ldap = sys.modules['nova.auth.fakeldap'] sys.modules['ldap'] = nova.auth.fakeldap
super(FakeLdapDriver, self).__init__()

View File

@@ -73,6 +73,12 @@ flags.DEFINE_string('auth_driver', 'nova.auth.dbdriver.DbDriver',
LOG = logging.getLogger('nova.auth.manager') LOG = logging.getLogger('nova.auth.manager')
if FLAGS.memcached_servers:
import memcache
else:
from nova import fakememcache as memcache
class AuthBase(object): class AuthBase(object):
"""Base class for objects relating to auth """Base class for objects relating to auth
@@ -206,6 +212,7 @@ class AuthManager(object):
""" """
_instance = None _instance = None
mc = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
"""Returns the AuthManager singleton""" """Returns the AuthManager singleton"""
@@ -222,13 +229,8 @@ class AuthManager(object):
self.network_manager = utils.import_object(FLAGS.network_manager) self.network_manager = utils.import_object(FLAGS.network_manager)
if driver or not getattr(self, 'driver', None): if driver or not getattr(self, 'driver', None):
self.driver = utils.import_class(driver or FLAGS.auth_driver) self.driver = utils.import_class(driver or FLAGS.auth_driver)
if AuthManager.mc is None:
if FLAGS.memcached_servers: AuthManager.mc = memcache.Client(FLAGS.memcached_servers, debug=0)
import memcache
else:
from nova import fakememcache as memcache
self.mc = memcache.Client(FLAGS.memcached_servers,
debug=0)
def authenticate(self, access, signature, params, verb='GET', def authenticate(self, access, signature, params, verb='GET',
server_string='127.0.0.1:8773', path='/', server_string='127.0.0.1:8773', path='/',

View File

@@ -1,4 +1,6 @@
NOVA_KEY_DIR=$(pushd $(dirname $BASH_SOURCE)>/dev/null; pwd; popd>/dev/null) NOVARC=$(readlink -f "${BASH_SOURCE:-${0}}" 2>/dev/null) ||
NOVARC=$(python -c 'import os,sys; print os.path.abspath(os.path.realpath(sys.argv[1]))' "${BASH_SOURCE:-${0}}")
NOVA_KEY_DIR=${NOVARC%%/*}
export EC2_ACCESS_KEY="%(access)s:%(project)s" export EC2_ACCESS_KEY="%(access)s:%(project)s"
export EC2_SECRET_KEY="%(secret)s" export EC2_SECRET_KEY="%(secret)s"
export EC2_URL="%(ec2)s" export EC2_URL="%(ec2)s"

View File

@@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit")
EXCHANGES = {} EXCHANGES = {}
QUEUES = {} QUEUES = {}
CONSUMERS = {}
class Message(base.BaseMessage): class Message(base.BaseMessage):
@@ -96,17 +97,29 @@ class Backend(base.BaseBackend):
' key %(routing_key)s') % locals()) ' key %(routing_key)s') % locals())
EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key) EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key)
def declare_consumer(self, queue, callback, *args, **kwargs): def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs):
self.current_queue = queue global CONSUMERS
self.current_callback = callback LOG.debug("Adding consumer %s", consumer_tag)
CONSUMERS[consumer_tag] = (queue, callback)
def cancel(self, consumer_tag):
global CONSUMERS
LOG.debug("Removing consumer %s", consumer_tag)
del CONSUMERS[consumer_tag]
def consume(self, limit=None): def consume(self, limit=None):
global CONSUMERS
num = 0
while True: while True:
item = self.get(self.current_queue) for (queue, callback) in CONSUMERS.itervalues():
item = self.get(queue)
if item: if item:
self.current_callback(item) callback(item)
num += 1
yield
if limit and num == limit:
raise StopIteration() raise StopIteration()
greenthread.sleep(0) greenthread.sleep(0.1)
def get(self, queue, no_ack=False): def get(self, queue, no_ack=False):
global QUEUES global QUEUES
@@ -134,5 +147,7 @@ class Backend(base.BaseBackend):
def reset_all(): def reset_all():
global EXCHANGES global EXCHANGES
global QUEUES global QUEUES
global CONSUMERS
EXCHANGES = {} EXCHANGES = {}
QUEUES = {} QUEUES = {}
CONSUMERS = {}

View File

@@ -296,6 +296,7 @@ DEFINE_bool('fake_network', False,
'should we use fake network devices and addresses') 'should we use fake network devices and addresses')
DEFINE_string('rabbit_host', 'localhost', 'rabbit host') DEFINE_string('rabbit_host', 'localhost', 'rabbit host')
DEFINE_integer('rabbit_port', 5672, 'rabbit port') DEFINE_integer('rabbit_port', 5672, 'rabbit port')
DEFINE_bool('rabbit_use_ssl', False, 'connect over SSL')
DEFINE_string('rabbit_userid', 'guest', 'rabbit userid') DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')
DEFINE_string('rabbit_password', 'guest', 'rabbit password') DEFINE_string('rabbit_password', 'guest', 'rabbit password')
DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host') DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host')

View File

@@ -35,6 +35,7 @@ import os
import sys import sys
import traceback import traceback
import nova
from nova import flags from nova import flags
from nova import version from nova import version
@@ -63,6 +64,7 @@ flags.DEFINE_list('default_log_levels',
'eventlet.wsgi.server=WARN'], 'eventlet.wsgi.server=WARN'],
'list of logger=LEVEL pairs') 'list of logger=LEVEL pairs')
flags.DEFINE_bool('use_syslog', False, 'output to syslog') flags.DEFINE_bool('use_syslog', False, 'output to syslog')
flags.DEFINE_bool('publish_errors', False, 'publish error events')
flags.DEFINE_string('logfile', None, 'output to named file') flags.DEFINE_string('logfile', None, 'output to named file')
@@ -258,12 +260,20 @@ class NovaRootLogger(NovaLogger):
else: else:
self.removeHandler(self.filelog) self.removeHandler(self.filelog)
self.addHandler(self.streamlog) self.addHandler(self.streamlog)
if FLAGS.publish_errors:
self.addHandler(PublishErrorsHandler(ERROR))
if FLAGS.verbose: if FLAGS.verbose:
self.setLevel(DEBUG) self.setLevel(DEBUG)
else: else:
self.setLevel(INFO) self.setLevel(INFO)
class PublishErrorsHandler(logging.Handler):
def emit(self, record):
nova.notifier.api.notify('nova.error.publisher', 'error_notification',
nova.notifier.api.ERROR, dict(error=record.msg))
def handle_exception(type, value, tb): def handle_exception(type, value, tb):
extra = {} extra = {}
if FLAGS.verbose: if FLAGS.verbose:

View File

@@ -11,9 +11,8 @@
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License.import datetime # under the License.
import datetime
import uuid import uuid
from nova import flags from nova import flags
@@ -64,7 +63,7 @@ def notify(publisher_id, event_type, priority, payload):
{'message_id': str(uuid.uuid4()), {'message_id': str(uuid.uuid4()),
'publisher_id': 'compute.host1', 'publisher_id': 'compute.host1',
'timestamp': datetime.datetime.utcnow(), 'timestamp': utils.utcnow(),
'priority': 'WARN', 'priority': 'WARN',
'event_type': 'compute.create_instance', 'event_type': 'compute.create_instance',
'payload': {'instance_id': 12, ... }} 'payload': {'instance_id': 12, ... }}
@@ -79,5 +78,5 @@ def notify(publisher_id, event_type, priority, payload):
event_type=event_type, event_type=event_type,
priority=priority, priority=priority,
payload=payload, payload=payload,
timestamp=str(datetime.datetime.utcnow())) timestamp=str(utils.utcnow()))
driver.notify(msg) driver.notify(msg)

View File

@@ -28,12 +28,15 @@ import json
import sys import sys
import time import time
import traceback import traceback
import types
import uuid import uuid
from carrot import connection as carrot_connection from carrot import connection as carrot_connection
from carrot import messaging from carrot import messaging
from eventlet import greenpool from eventlet import greenpool
from eventlet import greenthread from eventlet import pools
from eventlet import queue
import greenlet
from nova import context from nova import context
from nova import exception from nova import exception
@@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool') flags.DEFINE_integer('rpc_thread_pool_size', 1024,
'Size of RPC thread pool')
flags.DEFINE_integer('rpc_conn_pool_size', 30,
'Size of RPC connection pool')
class Connection(carrot_connection.BrokerConnection): class Connection(carrot_connection.BrokerConnection):
@@ -59,6 +65,7 @@ class Connection(carrot_connection.BrokerConnection):
if new or not hasattr(cls, '_instance'): if new or not hasattr(cls, '_instance'):
params = dict(hostname=FLAGS.rabbit_host, params = dict(hostname=FLAGS.rabbit_host,
port=FLAGS.rabbit_port, port=FLAGS.rabbit_port,
ssl=FLAGS.rabbit_use_ssl,
userid=FLAGS.rabbit_userid, userid=FLAGS.rabbit_userid,
password=FLAGS.rabbit_password, password=FLAGS.rabbit_password,
virtual_host=FLAGS.rabbit_virtual_host) virtual_host=FLAGS.rabbit_virtual_host)
@@ -90,6 +97,22 @@ class Connection(carrot_connection.BrokerConnection):
return cls.instance() return cls.instance()
class Pool(pools.Pool):
"""Class that implements a Pool of Connections."""
# TODO(comstud): Timeout connections not used in a while
def create(self):
LOG.debug('Creating new connection')
return Connection.instance(new=True)
# Create a ConnectionPool to use for RPC calls. We'll order the
# pool as a stack (LIFO), so that we can potentially loop through and
# timeout old unused connections at some point
ConnectionPool = Pool(
max_size=FLAGS.rpc_conn_pool_size,
order_as_stack=True)
class Consumer(messaging.Consumer): class Consumer(messaging.Consumer):
"""Consumer base class. """Consumer base class.
@@ -131,7 +154,9 @@ class Consumer(messaging.Consumer):
self.connection = Connection.recreate() self.connection = Connection.recreate()
self.backend = self.connection.create_backend() self.backend = self.connection.create_backend()
self.declare() self.declare()
super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks) return super(Consumer, self).fetch(no_ack,
auto_ack,
enable_callbacks)
if self.failed_connection: if self.failed_connection:
LOG.error(_('Reconnected to queue')) LOG.error(_('Reconnected to queue'))
self.failed_connection = False self.failed_connection = False
@@ -159,13 +184,13 @@ class AdapterConsumer(Consumer):
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
super(AdapterConsumer, self).__init__(connection=connection, super(AdapterConsumer, self).__init__(connection=connection,
topic=topic) topic=topic)
self.register_callback(self.process_data)
def receive(self, *args, **kwargs): def process_data(self, message_data, message):
self.pool.spawn_n(self._receive, *args, **kwargs) """Consumer callback to call a method on a proxy object.
@exception.wrap_exception Parses the message for validity and fires off a thread to call the
def _receive(self, message_data, message): proxy object method.
"""Magically looks for a method on the proxy object and calls it.
Message data should be a dictionary with two keys: Message data should be a dictionary with two keys:
method: string representing the method to call method: string representing the method to call
@@ -175,8 +200,8 @@ class AdapterConsumer(Consumer):
""" """
LOG.debug(_('received %s') % message_data) LOG.debug(_('received %s') % message_data)
msg_id = message_data.pop('_msg_id', None) # This will be popped off in _unpack_context
msg_id = message_data.get('_msg_id', None)
ctxt = _unpack_context(message_data) ctxt = _unpack_context(message_data)
method = message_data.get('method') method = message_data.get('method')
@@ -188,8 +213,17 @@ class AdapterConsumer(Consumer):
# we just log the message and send an error string # we just log the message and send an error string
# back to the caller # back to the caller
LOG.warn(_('no method for message: %s') % message_data) LOG.warn(_('no method for message: %s') % message_data)
msg_reply(msg_id, _('No method for message: %s') % message_data) if msg_id:
msg_reply(msg_id,
_('No method for message: %s') % message_data)
return return
self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args)
@exception.wrap_exception
def _process_data(self, msg_id, ctxt, method, args):
"""Thread that maigcally looks for a method on the proxy
object and calls it.
"""
node_func = getattr(self.proxy, str(method)) node_func = getattr(self.proxy, str(method))
node_args = dict((str(k), v) for k, v in args.iteritems()) node_args = dict((str(k), v) for k, v in args.iteritems())
@@ -197,7 +231,18 @@ class AdapterConsumer(Consumer):
try: try:
rval = node_func(context=ctxt, **node_args) rval = node_func(context=ctxt, **node_args)
if msg_id: if msg_id:
# Check if the result was a generator
if isinstance(rval, types.GeneratorType):
for x in rval:
msg_reply(msg_id, x, None)
else:
msg_reply(msg_id, rval, None) msg_reply(msg_id, rval, None)
# This final None tells multicall that it is done.
msg_reply(msg_id, None, None)
elif isinstance(rval, types.GeneratorType):
# NOTE(vish): this iterates through the generator
list(rval)
except Exception as e: except Exception as e:
logging.exception('Exception during message handling') logging.exception('Exception during message handling')
if msg_id: if msg_id:
@@ -205,11 +250,6 @@ class AdapterConsumer(Consumer):
return return
class Publisher(messaging.Publisher):
"""Publisher base class."""
pass
class TopicAdapterConsumer(AdapterConsumer): class TopicAdapterConsumer(AdapterConsumer):
"""Consumes messages on a specific topic.""" """Consumes messages on a specific topic."""
@@ -242,6 +282,58 @@ class FanoutAdapterConsumer(AdapterConsumer):
topic=topic, proxy=proxy) topic=topic, proxy=proxy)
class ConsumerSet(object):
"""Groups consumers to listen on together on a single connection."""
def __init__(self, connection, consumer_list):
self.consumer_list = set(consumer_list)
self.consumer_set = None
self.enabled = True
self.init(connection)
def init(self, conn):
if not conn:
conn = Connection.instance(new=True)
if self.consumer_set:
self.consumer_set.close()
self.consumer_set = messaging.ConsumerSet(conn)
for consumer in self.consumer_list:
consumer.connection = conn
# consumer.backend is set for us
self.consumer_set.add_consumer(consumer)
def reconnect(self):
self.init(None)
def wait(self, limit=None):
running = True
while running:
it = self.consumer_set.iterconsume(limit=limit)
if not it:
break
while True:
try:
it.next()
except StopIteration:
return
except greenlet.GreenletExit:
running = False
break
except Exception as e:
LOG.exception(_("Exception while processing consumer"))
self.reconnect()
# Break to outer loop
break
def close(self):
self.consumer_set.close()
class Publisher(messaging.Publisher):
"""Publisher base class."""
pass
class TopicPublisher(Publisher): class TopicPublisher(Publisher):
"""Publishes messages on a specific topic.""" """Publishes messages on a specific topic."""
@@ -306,7 +398,8 @@ def msg_reply(msg_id, reply=None, failure=None):
LOG.error(_("Returning exception %s to caller"), message) LOG.error(_("Returning exception %s to caller"), message)
LOG.error(tb) LOG.error(tb)
failure = (failure[0].__name__, str(failure[1]), tb) failure = (failure[0].__name__, str(failure[1]), tb)
conn = Connection.instance()
with ConnectionPool.item() as conn:
publisher = DirectPublisher(connection=conn, msg_id=msg_id) publisher = DirectPublisher(connection=conn, msg_id=msg_id)
try: try:
publisher.send({'result': reply, 'failure': failure}) publisher.send({'result': reply, 'failure': failure})
@@ -315,6 +408,7 @@ def msg_reply(msg_id, reply=None, failure=None):
{'result': dict((k, repr(v)) {'result': dict((k, repr(v))
for k, v in reply.__dict__.iteritems()), for k, v in reply.__dict__.iteritems()),
'failure': failure}) 'failure': failure})
publisher.close() publisher.close()
@@ -347,8 +441,9 @@ def _unpack_context(msg):
if key.startswith('_context_'): if key.startswith('_context_'):
value = msg.pop(key) value = msg.pop(key)
context_dict[key[9:]] = value context_dict[key[9:]] = value
context_dict['msg_id'] = msg.pop('_msg_id', None)
LOG.debug(_('unpacked context: %s'), context_dict) LOG.debug(_('unpacked context: %s'), context_dict)
return context.RequestContext.from_dict(context_dict) return RpcContext.from_dict(context_dict)
def _pack_context(msg, context): def _pack_context(msg, context):
@@ -360,57 +455,99 @@ def _pack_context(msg, context):
for args at some point. for args at some point.
""" """
context = dict([('_context_%s' % key, value) context_d = dict([('_context_%s' % key, value)
for (key, value) in context.to_dict().iteritems()]) for (key, value) in context.to_dict().iteritems()])
msg.update(context) msg.update(context_d)
def call(context, topic, msg): class RpcContext(context.RequestContext):
"""Sends a message on a topic and wait for a response.""" def __init__(self, *args, **kwargs):
msg_id = kwargs.pop('msg_id', None)
self.msg_id = msg_id
super(RpcContext, self).__init__(*args, **kwargs)
def reply(self, *args, **kwargs):
msg_reply(self.msg_id, *args, **kwargs)
def multicall(context, topic, msg):
"""Make a call that returns multiple times."""
LOG.debug(_('Making asynchronous call on %s ...'), topic) LOG.debug(_('Making asynchronous call on %s ...'), topic)
msg_id = uuid.uuid4().hex msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id}) msg.update({'_msg_id': msg_id})
LOG.debug(_('MSG_ID is %s') % (msg_id)) LOG.debug(_('MSG_ID is %s') % (msg_id))
_pack_context(msg, context) _pack_context(msg, context)
class WaitMessage(object): con_conn = ConnectionPool.get()
consumer = DirectConsumer(connection=con_conn, msg_id=msg_id)
wait_msg = MulticallWaiter(consumer)
consumer.register_callback(wait_msg)
publisher = TopicPublisher(connection=con_conn, topic=topic)
publisher.send(msg)
publisher.close()
return wait_msg
class MulticallWaiter(object):
def __init__(self, consumer):
self._consumer = consumer
self._results = queue.Queue()
self._closed = False
def close(self):
self._closed = True
self._consumer.close()
ConnectionPool.put(self._consumer.connection)
def __call__(self, data, message): def __call__(self, data, message):
"""Acks message and sets result.""" """Acks message and sets result."""
message.ack() message.ack()
if data['failure']: if data['failure']:
self.result = RemoteError(*data['failure']) self._results.put(RemoteError(*data['failure']))
else: else:
self.result = data['result'] self._results.put(data['result'])
wait_msg = WaitMessage() def __iter__(self):
conn = Connection.instance() return self.wait()
consumer = DirectConsumer(connection=conn, msg_id=msg_id)
consumer.register_callback(wait_msg)
conn = Connection.instance()
publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg)
publisher.close()
def wait(self):
while True:
rv = None
while rv is None and not self._closed:
try: try:
consumer.wait(limit=1) rv = self._consumer.fetch(enable_callbacks=True)
except StopIteration: except Exception:
pass self.close()
consumer.close() raise
# NOTE(termie): this is a little bit of a change from the original time.sleep(0.01)
# non-eventlet code where returning a Failure
# instance from a deferred call is very similar to result = self._results.get()
# raising an exception if isinstance(result, Exception):
if isinstance(wait_msg.result, Exception): self.close()
raise wait_msg.result raise result
return wait_msg.result if result == None:
self.close()
raise StopIteration
yield result
def call(context, topic, msg):
"""Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg)
# NOTE(vish): return the last result from the multicall
rv = list(rv)
if not rv:
return
return rv[-1]
def cast(context, topic, msg): def cast(context, topic, msg):
"""Sends a message on a topic without waiting for a response.""" """Sends a message on a topic without waiting for a response."""
LOG.debug(_('Making asynchronous cast on %s...'), topic) LOG.debug(_('Making asynchronous cast on %s...'), topic)
_pack_context(msg, context) _pack_context(msg, context)
conn = Connection.instance() with ConnectionPool.item() as conn:
publisher = TopicPublisher(connection=conn, topic=topic) publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg) publisher.send(msg)
publisher.close() publisher.close()
@@ -420,7 +557,7 @@ def fanout_cast(context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response.""" """Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_('Making asynchronous fanout cast...')) LOG.debug(_('Making asynchronous fanout cast...'))
_pack_context(msg, context) _pack_context(msg, context)
conn = Connection.instance() with ConnectionPool.item() as conn:
publisher = FanoutPublisher(topic, connection=conn) publisher = FanoutPublisher(topic, connection=conn)
publisher.send(msg) publisher.send(msg)
publisher.close() publisher.close()
@@ -459,6 +596,7 @@ def send_message(topic, message, wait=True):
if wait: if wait:
consumer.wait() consumer.wait()
consumer.close()
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -41,6 +41,7 @@ import json
from nova import exception from nova import exception
from nova import flags from nova import flags
from nova import log as logging from nova import log as logging
from nova.scheduler import zone_aware_scheduler
from nova import utils from nova import utils
from nova.scheduler import zone_aware_scheduler from nova.scheduler import zone_aware_scheduler
@@ -69,9 +70,10 @@ class HostFilter(object):
class AllHostsFilter(HostFilter): class AllHostsFilter(HostFilter):
"""NOP host filter. Returns all hosts in ZoneManager. """ NOP host filter. Returns all hosts in ZoneManager.
This essentially does what the old Scheduler+Chance used This essentially does what the old Scheduler+Chance used
to give us.""" to give us.
"""
def instance_type_to_filter(self, instance_type): def instance_type_to_filter(self, instance_type):
"""Return anything to prevent base-class from raising """Return anything to prevent base-class from raising
@@ -134,7 +136,8 @@ class InstanceTypeFilter(HostFilter):
class JsonFilter(HostFilter): class JsonFilter(HostFilter):
"""Host Filter to allow simple JSON-based grammar for """Host Filter to allow simple JSON-based grammar for
selecting hosts.""" selecting hosts.
"""
def _equals(self, args): def _equals(self, args):
"""First term is == all the other terms.""" """First term is == all the other terms."""
@@ -224,13 +227,14 @@ class JsonFilter(HostFilter):
required_disk = instance_type['local_gb'] required_disk = instance_type['local_gb']
query = ['and', query = ['and',
['>=', '$compute.host_memory_free', required_ram], ['>=', '$compute.host_memory_free', required_ram],
['>=', '$compute.disk_available', required_disk] ['>=', '$compute.disk_available', required_disk],
] ]
return (self._full_name(), json.dumps(query)) return (self._full_name(), json.dumps(query))
def _parse_string(self, string, host, services): def _parse_string(self, string, host, services):
"""Strings prefixed with $ are capability lookups in the """Strings prefixed with $ are capability lookups in the
form '$service.capability[.subcap*]'""" form '$service.capability[.subcap*]'
"""
if not string: if not string:
return None return None
if string[0] != '$': if string[0] != '$':
@@ -280,13 +284,14 @@ def choose_host_filter(filter_name=None):
"""Since the caller may specify which filter to use we need """Since the caller may specify which filter to use we need
to have an authoritative list of what is permissible. This to have an authoritative list of what is permissible. This
function checks the filter name against a predefined set function checks the filter name against a predefined set
of acceptable filters.""" of acceptable filters.
"""
if not filter_name: if not filter_name:
filter_name = FLAGS.default_host_filter filter_name = FLAGS.default_host_filter
for filter_class in FILTERS: for filter_class in FILTERS:
if "%s.%s" % (filter_class.__module__, filter_class.__name__) == \ host_match = "%s.%s" % (filter_class.__module__, filter_class.__name__)
filter_name: if host_match == filter_name:
return filter_class() return filter_class()
raise exception.SchedulerHostFilterNotFound(filter_name=filter_name) raise exception.SchedulerHostFilterNotFound(filter_name=filter_name)
@@ -314,5 +319,6 @@ class HostFilterScheduler(zone_aware_scheduler.ZoneAwareScheduler):
def weigh_hosts(self, num, request_spec, hosts): def weigh_hosts(self, num, request_spec, hosts):
"""Derived classes must override this method and return """Derived classes must override this method and return
a lists of hosts in [{weight, hostname}] format.""" a lists of hosts in [{weight, hostname}] format.
"""
return [dict(weight=1, hostname=host) for host, caps in hosts] return [dict(weight=1, hostname=host) for host, caps in hosts]

View File

@@ -0,0 +1,156 @@
# Copyright (c) 2011 Openstack, LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Least Cost Scheduler is a mechanism for choosing which host machines to
provision a set of resources to. The input of the least-cost-scheduler is a
set of objective-functions, called the 'cost-functions', a weight for each
cost-function, and a list of candidate hosts (gathered via FilterHosts).
The cost-function and weights are tabulated, and the host with the least cost
is then selected for provisioning.
"""
import collections
from nova import flags
from nova import log as logging
from nova.scheduler import zone_aware_scheduler
from nova import utils
LOG = logging.getLogger('nova.scheduler.least_cost')
FLAGS = flags.FLAGS
flags.DEFINE_list('least_cost_scheduler_cost_functions',
['nova.scheduler.least_cost.noop_cost_fn'],
'Which cost functions the LeastCostScheduler should use.')
# TODO(sirp): Once we have enough of these rules, we can break them out into a
# cost_functions.py file (perhaps in a least_cost_scheduler directory)
flags.DEFINE_integer('noop_cost_fn_weight', 1,
'How much weight to give the noop cost function')
def noop_cost_fn(host):
"""Return a pre-weight cost of 1 for each host"""
return 1
flags.DEFINE_integer('fill_first_cost_fn_weight', 1,
'How much weight to give the fill-first cost function')
def fill_first_cost_fn(host):
"""Prefer hosts that have less ram available, filter_hosts will exclude
hosts that don't have enough ram"""
hostname, caps = host
free_mem = caps['compute']['host_memory_free']
return free_mem
class LeastCostScheduler(zone_aware_scheduler.ZoneAwareScheduler):
def get_cost_fns(self):
"""Returns a list of tuples containing weights and cost functions to
use for weighing hosts
"""
cost_fns = []
for cost_fn_str in FLAGS.least_cost_scheduler_cost_functions:
try:
# NOTE(sirp): import_class is somewhat misnamed since it can
# any callable from a module
cost_fn = utils.import_class(cost_fn_str)
except exception.ClassNotFound:
raise exception.SchedulerCostFunctionNotFound(
cost_fn_str=cost_fn_str)
try:
weight = getattr(FLAGS, "%s_weight" % cost_fn.__name__)
except AttributeError:
raise exception.SchedulerWeightFlagNotFound(
flag_name=flag_name)
cost_fns.append((weight, cost_fn))
return cost_fns
def weigh_hosts(self, num, request_spec, hosts):
"""Returns a list of dictionaries of form:
[ {weight: weight, hostname: hostname} ]"""
# FIXME(sirp): weigh_hosts should handle more than just instances
hostnames = [hostname for hostname, caps in hosts]
cost_fns = self.get_cost_fns()
costs = weighted_sum(domain=hosts, weighted_fns=cost_fns)
weighted = []
weight_log = []
for cost, hostname in zip(costs, hostnames):
weight_log.append("%s: %s" % (hostname, "%.2f" % cost))
weight_dict = dict(weight=cost, hostname=hostname)
weighted.append(weight_dict)
LOG.debug(_("Weighted Costs => %s") % weight_log)
return weighted
def normalize_list(L):
"""Normalize an array of numbers such that each element satisfies:
0 <= e <= 1"""
if not L:
return L
max_ = max(L)
if max_ > 0:
return [(float(e) / max_) for e in L]
return L
def weighted_sum(domain, weighted_fns, normalize=True):
"""Use the weighted-sum method to compute a score for an array of objects.
Normalize the results of the objective-functions so that the weights are
meaningful regardless of objective-function's range.
domain - input to be scored
weighted_fns - list of weights and functions like:
[(weight, objective-functions)]
Returns an unsorted of scores. To pair with hosts do: zip(scores, hosts)
"""
# Table of form:
# { domain1: [score1, score2, ..., scoreM]
# ...
# domainN: [score1, score2, ..., scoreM] }
score_table = collections.defaultdict(list)
for weight, fn in weighted_fns:
scores = [fn(elem) for elem in domain]
if normalize:
norm_scores = normalize_list(scores)
else:
norm_scores = scores
for idx, score in enumerate(norm_scores):
weighted_score = score * weight
score_table[idx].append(weighted_score)
# Sum rows in table to compute score for each element in domain
domain_scores = []
for idx in sorted(score_table):
elem_score = sum(score_table[idx])
elem = domain[idx]
domain_scores.append(elem_score)
return domain_scores

View File

@@ -21,10 +21,9 @@
Simple Scheduler Simple Scheduler
""" """
import datetime
from nova import db from nova import db
from nova import flags from nova import flags
from nova import utils
from nova.scheduler import driver from nova.scheduler import driver
from nova.scheduler import chance from nova.scheduler import chance
@@ -54,7 +53,7 @@ class SimpleScheduler(chance.ChanceScheduler):
# TODO(vish): this probably belongs in the manager, if we # TODO(vish): this probably belongs in the manager, if we
# can generalize this somehow # can generalize this somehow
now = datetime.datetime.utcnow() now = utils.utcnow()
db.instance_update(context, instance_id, {'host': host, db.instance_update(context, instance_id, {'host': host,
'scheduled_at': now}) 'scheduled_at': now})
return host return host
@@ -66,7 +65,7 @@ class SimpleScheduler(chance.ChanceScheduler):
if self.service_is_up(service): if self.service_is_up(service):
# NOTE(vish): this probably belongs in the manager, if we # NOTE(vish): this probably belongs in the manager, if we
# can generalize this somehow # can generalize this somehow
now = datetime.datetime.utcnow() now = utils.utcnow()
db.instance_update(context, db.instance_update(context,
instance_id, instance_id,
{'host': service['host'], {'host': service['host'],
@@ -90,7 +89,7 @@ class SimpleScheduler(chance.ChanceScheduler):
# TODO(vish): this probably belongs in the manager, if we # TODO(vish): this probably belongs in the manager, if we
# can generalize this somehow # can generalize this somehow
now = datetime.datetime.utcnow() now = utils.utcnow()
db.volume_update(context, volume_id, {'host': host, db.volume_update(context, volume_id, {'host': host,
'scheduled_at': now}) 'scheduled_at': now})
return host return host
@@ -103,7 +102,7 @@ class SimpleScheduler(chance.ChanceScheduler):
if self.service_is_up(service): if self.service_is_up(service):
# NOTE(vish): this probably belongs in the manager, if we # NOTE(vish): this probably belongs in the manager, if we
# can generalize this somehow # can generalize this somehow
now = datetime.datetime.utcnow() now = utils.utcnow()
db.volume_update(context, db.volume_update(context,
volume_id, volume_id,
{'host': service['host'], {'host': service['host'],

View File

@@ -22,6 +22,7 @@ across zones. There are two expansion points to this class for:
import operator import operator
import json import json
import M2Crypto import M2Crypto
import novaclient import novaclient
@@ -167,7 +168,8 @@ class ZoneAwareScheduler(driver.Scheduler):
1. Create a Build Plan and then provision, or 1. Create a Build Plan and then provision, or
2. Use the Build Plan information in the request parameters 2. Use the Build Plan information in the request parameters
to simply create the instance (either in this zone or to simply create the instance (either in this zone or
a child zone).""" a child zone).
"""
# TODO(sandy): We'll have to look for richer specs at some point. # TODO(sandy): We'll have to look for richer specs at some point.
@@ -194,7 +196,8 @@ class ZoneAwareScheduler(driver.Scheduler):
"""Select returns a list of weights and zone/host information """Select returns a list of weights and zone/host information
corresponding to the best hosts to service the request. Any corresponding to the best hosts to service the request. Any
child zone information has been encrypted so as not to reveal child zone information has been encrypted so as not to reveal
anything about the children.""" anything about the children.
"""
return self._schedule(context, "compute", request_spec, return self._schedule(context, "compute", request_spec,
*args, **kwargs) *args, **kwargs)
@@ -222,6 +225,9 @@ class ZoneAwareScheduler(driver.Scheduler):
# Filter local hosts based on requirements ... # Filter local hosts based on requirements ...
host_list = self.filter_hosts(num_instances, request_spec) host_list = self.filter_hosts(num_instances, request_spec)
# TODO(sirp): weigh_hosts should also be a function of 'topic' or
# resources, so that we can apply different objective functions to it
# then weigh the selected hosts. # then weigh the selected hosts.
# weighted = [{weight=weight, name=hostname}, ...] # weighted = [{weight=weight, name=hostname}, ...]
weighted = self.weigh_hosts(num_instances, request_spec, host_list) weighted = self.weigh_hosts(num_instances, request_spec, host_list)
@@ -245,10 +251,16 @@ class ZoneAwareScheduler(driver.Scheduler):
def filter_hosts(self, num, request_spec): def filter_hosts(self, num, request_spec):
"""Derived classes must override this method and return """Derived classes must override this method and return
a list of hosts in [(hostname, capability_dict)] format.""" a list of hosts in [(hostname, capability_dict)] format.
raise NotImplemented() """
# NOTE(sirp): The default logic is the equivalent to AllHostsFilter
service_states = self.zone_manager.service_states
return [(host, services)
for host, services in service_states.iteritems()]
def weigh_hosts(self, num, request_spec, hosts): def weigh_hosts(self, num, request_spec, hosts):
"""Derived classes must override this method and return """Derived classes may override this to provide more sophisticated
a lists of hosts in [{weight, hostname}] format.""" scheduling objectives
raise NotImplemented() """
# NOTE(sirp): The default logic is the same as the NoopCostFunction
return [dict(weight=1, hostname=host) for host, caps in hosts]

View File

@@ -17,16 +17,17 @@
ZoneManager oversees all communications with child Zones. ZoneManager oversees all communications with child Zones.
""" """
import datetime
import novaclient import novaclient
import thread import thread
import traceback import traceback
from datetime import datetime
from eventlet import greenpool from eventlet import greenpool
from nova import db from nova import db
from nova import flags from nova import flags
from nova import log as logging from nova import log as logging
from nova import utils
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_integer('zone_db_check_interval', 60, flags.DEFINE_integer('zone_db_check_interval', 60,
@@ -42,7 +43,7 @@ class ZoneState(object):
self.name = None self.name = None
self.capabilities = None self.capabilities = None
self.attempt = 0 self.attempt = 0
self.last_seen = datetime.min self.last_seen = datetime.datetime.min
self.last_exception = None self.last_exception = None
self.last_exception_time = None self.last_exception_time = None
@@ -56,7 +57,7 @@ class ZoneState(object):
def update_metadata(self, zone_metadata): def update_metadata(self, zone_metadata):
"""Update zone metadata after successful communications with """Update zone metadata after successful communications with
child zone.""" child zone."""
self.last_seen = datetime.now() self.last_seen = utils.utcnow()
self.attempt = 0 self.attempt = 0
self.name = zone_metadata.get("name", "n/a") self.name = zone_metadata.get("name", "n/a")
self.capabilities = ", ".join(["%s=%s" % (k, v) self.capabilities = ", ".join(["%s=%s" % (k, v)
@@ -72,7 +73,7 @@ class ZoneState(object):
"""Something went wrong. Check to see if zone should be """Something went wrong. Check to see if zone should be
marked as offline.""" marked as offline."""
self.last_exception = exception self.last_exception = exception
self.last_exception_time = datetime.now() self.last_exception_time = utils.utcnow()
api_url = self.api_url api_url = self.api_url
logging.warning(_("'%(exception)s' error talking to " logging.warning(_("'%(exception)s' error talking to "
"zone %(api_url)s") % locals()) "zone %(api_url)s") % locals())
@@ -104,7 +105,7 @@ def _poll_zone(zone):
class ZoneManager(object): class ZoneManager(object):
"""Keeps the zone states updated.""" """Keeps the zone states updated."""
def __init__(self): def __init__(self):
self.last_zone_db_check = datetime.min self.last_zone_db_check = datetime.datetime.min
self.zone_states = {} # { <zone_id> : ZoneState } self.zone_states = {} # { <zone_id> : ZoneState }
self.service_states = {} # { <host> : { <service> : { cap k : v }}} self.service_states = {} # { <host> : { <service> : { cap k : v }}}
self.green_pool = greenpool.GreenPool() self.green_pool = greenpool.GreenPool()
@@ -158,10 +159,10 @@ class ZoneManager(object):
def ping(self, context=None): def ping(self, context=None):
"""Ping should be called periodically to update zone status.""" """Ping should be called periodically to update zone status."""
diff = datetime.now() - self.last_zone_db_check diff = utils.utcnow() - self.last_zone_db_check
if diff.seconds >= FLAGS.zone_db_check_interval: if diff.seconds >= FLAGS.zone_db_check_interval:
logging.debug(_("Updating zone cache from db.")) logging.debug(_("Updating zone cache from db."))
self.last_zone_db_check = datetime.now() self.last_zone_db_check = utils.utcnow()
self._refresh_from_db(context) self._refresh_from_db(context)
self._poll_zones(context) self._poll_zones(context)

View File

View File

@@ -0,0 +1,206 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Scheduler Host Filters.
"""
import json
from nova import exception
from nova import flags
from nova import test
from nova.scheduler import host_filter
FLAGS = flags.FLAGS
class FakeZoneManager:
pass
class HostFilterTestCase(test.TestCase):
"""Test case for host filters."""
def _host_caps(self, multiplier):
# Returns host capabilities in the following way:
# host1 = memory:free 10 (100max)
# disk:available 100 (1000max)
# hostN = memory:free 10 + 10N
# disk:available 100 + 100N
# in other words: hostN has more resources than host0
# which means ... don't go above 10 hosts.
return {'host_name-description': 'XenServer %s' % multiplier,
'host_hostname': 'xs-%s' % multiplier,
'host_memory_total': 100,
'host_memory_overhead': 10,
'host_memory_free': 10 + multiplier * 10,
'host_memory_free-computed': 10 + multiplier * 10,
'host_other-config': {},
'host_ip_address': '192.168.1.%d' % (100 + multiplier),
'host_cpu_info': {},
'disk_available': 100 + multiplier * 100,
'disk_total': 1000,
'disk_used': 0,
'host_uuid': 'xxx-%d' % multiplier,
'host_name-label': 'xs-%s' % multiplier}
def setUp(self):
self.old_flag = FLAGS.default_host_filter
FLAGS.default_host_filter = \
'nova.scheduler.host_filter.AllHostsFilter'
self.instance_type = dict(name='tiny',
memory_mb=50,
vcpus=10,
local_gb=500,
flavorid=1,
swap=500,
rxtx_quota=30000,
rxtx_cap=200)
self.zone_manager = FakeZoneManager()
states = {}
for x in xrange(10):
states['host%02d' % (x + 1)] = {'compute': self._host_caps(x)}
self.zone_manager.service_states = states
def tearDown(self):
FLAGS.default_host_filter = self.old_flag
def test_choose_filter(self):
# Test default filter ...
hf = host_filter.choose_host_filter()
self.assertEquals(hf._full_name(),
'nova.scheduler.host_filter.AllHostsFilter')
# Test valid filter ...
hf = host_filter.choose_host_filter(
'nova.scheduler.host_filter.InstanceTypeFilter')
self.assertEquals(hf._full_name(),
'nova.scheduler.host_filter.InstanceTypeFilter')
# Test invalid filter ...
try:
host_filter.choose_host_filter('does not exist')
self.fail("Should not find host filter.")
except exception.SchedulerHostFilterNotFound:
pass
def test_all_host_filter(self):
hf = host_filter.AllHostsFilter()
cooked = hf.instance_type_to_filter(self.instance_type)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(10, len(hosts))
for host, capabilities in hosts:
self.assertTrue(host.startswith('host'))
def test_instance_type_filter(self):
hf = host_filter.InstanceTypeFilter()
# filter all hosts that can support 50 ram and 500 disk
name, cooked = hf.instance_type_to_filter(self.instance_type)
self.assertEquals('nova.scheduler.host_filter.InstanceTypeFilter',
name)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(6, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
self.assertEquals('host05', just_hosts[0])
self.assertEquals('host10', just_hosts[5])
def test_json_filter(self):
hf = host_filter.JsonFilter()
# filter all hosts that can support 50 ram and 500 disk
name, cooked = hf.instance_type_to_filter(self.instance_type)
self.assertEquals('nova.scheduler.host_filter.JsonFilter', name)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(6, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
self.assertEquals('host05', just_hosts[0])
self.assertEquals('host10', just_hosts[5])
# Try some custom queries
raw = ['or',
['and',
['<', '$compute.host_memory_free', 30],
['<', '$compute.disk_available', 300]
],
['and',
['>', '$compute.host_memory_free', 70],
['>', '$compute.disk_available', 700]
]
]
cooked = json.dumps(raw)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(5, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
for index, host in zip([1, 2, 8, 9, 10], just_hosts):
self.assertEquals('host%02d' % index, host)
raw = ['not',
['=', '$compute.host_memory_free', 30],
]
cooked = json.dumps(raw)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(9, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
for index, host in zip([1, 2, 4, 5, 6, 7, 8, 9, 10], just_hosts):
self.assertEquals('host%02d' % index, host)
raw = ['in', '$compute.host_memory_free', 20, 40, 60, 80, 100]
cooked = json.dumps(raw)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(5, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
for index, host in zip([2, 4, 6, 8, 10], just_hosts):
self.assertEquals('host%02d' % index, host)
# Try some bogus input ...
raw = ['unknown command', ]
cooked = json.dumps(raw)
try:
hf.filter_hosts(self.zone_manager, cooked)
self.fail("Should give KeyError")
except KeyError, e:
pass
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps([])))
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps({})))
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps(
['not', True, False, True, False]
)))
try:
hf.filter_hosts(self.zone_manager, json.dumps(
'not', True, False, True, False
))
self.fail("Should give KeyError")
except KeyError, e:
pass
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(['=', '$foo', 100])))
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(['=', '$.....', 100])))
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(
['>', ['and', ['or', ['not', ['<', ['>=', ['<=', ['in', ]]]]]]]])))
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(['=', {}, ['>', '$missing....foo']])))

View File

@@ -0,0 +1,144 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Least Cost Scheduler
"""
from nova import flags
from nova import test
from nova.scheduler import least_cost
from nova.tests.scheduler import test_zone_aware_scheduler
MB = 1024 * 1024
FLAGS = flags.FLAGS
class FakeHost(object):
def __init__(self, host_id, free_ram, io):
self.id = host_id
self.free_ram = free_ram
self.io = io
class WeightedSumTestCase(test.TestCase):
def test_empty_domain(self):
domain = []
weighted_fns = []
result = least_cost.weighted_sum(domain, weighted_fns)
expected = []
self.assertEqual(expected, result)
def test_basic_costing(self):
hosts = [
FakeHost(1, 512 * MB, 100),
FakeHost(2, 256 * MB, 400),
FakeHost(3, 512 * MB, 100)
]
weighted_fns = [
(1, lambda h: h.free_ram), # Fill-first, free_ram is a *cost*
(2, lambda h: h.io), # Avoid high I/O
]
costs = least_cost.weighted_sum(
domain=hosts, weighted_fns=weighted_fns)
# Each 256 MB unit of free-ram contributes 0.5 points by way of:
# cost = weight * (score/max_score) = 1 * (256/512) = 0.5
# Each 100 iops of IO adds 0.5 points by way of:
# cost = 2 * (100/400) = 2 * 0.25 = 0.5
expected = [1.5, 2.5, 1.5]
self.assertEqual(expected, costs)
class LeastCostSchedulerTestCase(test.TestCase):
def setUp(self):
super(LeastCostSchedulerTestCase, self).setUp()
class FakeZoneManager:
pass
zone_manager = FakeZoneManager()
states = test_zone_aware_scheduler.fake_zone_manager_service_states(
num_hosts=10)
zone_manager.service_states = states
self.sched = least_cost.LeastCostScheduler()
self.sched.zone_manager = zone_manager
def tearDown(self):
super(LeastCostSchedulerTestCase, self).tearDown()
def assertWeights(self, expected, num, request_spec, hosts):
weighted = self.sched.weigh_hosts(num, request_spec, hosts)
self.assertDictListMatch(weighted, expected, approx_equal=True)
def test_no_hosts(self):
num = 1
request_spec = {}
hosts = []
expected = []
self.assertWeights(expected, num, request_spec, hosts)
def test_noop_cost_fn(self):
FLAGS.least_cost_scheduler_cost_functions = [
'nova.scheduler.least_cost.noop_cost_fn'
]
FLAGS.noop_cost_fn_weight = 1
num = 1
request_spec = {}
hosts = self.sched.filter_hosts(num, request_spec)
expected = [dict(weight=1, hostname=hostname)
for hostname, caps in hosts]
self.assertWeights(expected, num, request_spec, hosts)
def test_cost_fn_weights(self):
FLAGS.least_cost_scheduler_cost_functions = [
'nova.scheduler.least_cost.noop_cost_fn'
]
FLAGS.noop_cost_fn_weight = 2
num = 1
request_spec = {}
hosts = self.sched.filter_hosts(num, request_spec)
expected = [dict(weight=2, hostname=hostname)
for hostname, caps in hosts]
self.assertWeights(expected, num, request_spec, hosts)
def test_fill_first_cost_fn(self):
FLAGS.least_cost_scheduler_cost_functions = [
'nova.scheduler.least_cost.fill_first_cost_fn'
]
FLAGS.fill_first_cost_fn_weight = 1
num = 1
request_spec = {}
hosts = self.sched.filter_hosts(num, request_spec)
expected = []
for idx, (hostname, caps) in enumerate(hosts):
# Costs are normalized so over 10 hosts, each host with increasing
# free ram will cost 1/N more. Since the lowest cost host has some
# free ram, we add in the 1/N for the base_cost
weight = 0.1 + (0.1 * idx)
weight_dict = dict(weight=weight, hostname=hostname)
expected.append(weight_dict)
self.assertWeights(expected, num, request_spec, hosts)

View File

@@ -23,6 +23,37 @@ from nova.scheduler import zone_aware_scheduler
from nova.scheduler import zone_manager from nova.scheduler import zone_manager
def _host_caps(multiplier):
# Returns host capabilities in the following way:
# host1 = memory:free 10 (100max)
# disk:available 100 (1000max)
# hostN = memory:free 10 + 10N
# disk:available 100 + 100N
# in other words: hostN has more resources than host0
# which means ... don't go above 10 hosts.
return {'host_name-description': 'XenServer %s' % multiplier,
'host_hostname': 'xs-%s' % multiplier,
'host_memory_total': 100,
'host_memory_overhead': 10,
'host_memory_free': 10 + multiplier * 10,
'host_memory_free-computed': 10 + multiplier * 10,
'host_other-config': {},
'host_ip_address': '192.168.1.%d' % (100 + multiplier),
'host_cpu_info': {},
'disk_available': 100 + multiplier * 100,
'disk_total': 1000,
'disk_used': 0,
'host_uuid': 'xxx-%d' % multiplier,
'host_name-label': 'xs-%s' % multiplier}
def fake_zone_manager_service_states(num_hosts):
states = {}
for x in xrange(num_hosts):
states['host%02d' % (x + 1)] = {'compute': _host_caps(x)}
return states
class FakeZoneAwareScheduler(zone_aware_scheduler.ZoneAwareScheduler): class FakeZoneAwareScheduler(zone_aware_scheduler.ZoneAwareScheduler):
def filter_hosts(self, num, specs): def filter_hosts(self, num, specs):
# NOTE(sirp): this is returning [(hostname, services)] # NOTE(sirp): this is returning [(hostname, services)]
@@ -40,14 +71,14 @@ class FakeZoneManager(zone_manager.ZoneManager):
def __init__(self): def __init__(self):
self.service_states = { self.service_states = {
'host1': { 'host1': {
'compute': {'ram': 1000} 'compute': {'ram': 1000},
}, },
'host2': { 'host2': {
'compute': {'ram': 2000} 'compute': {'ram': 2000},
}, },
'host3': { 'host3': {
'compute': {'ram': 3000} 'compute': {'ram': 3000},
} },
} }

View File

@@ -86,6 +86,7 @@ class _AuthManagerBaseTestCase(test.TestCase):
super(_AuthManagerBaseTestCase, self).setUp() super(_AuthManagerBaseTestCase, self).setUp()
self.flags(connection_type='fake') self.flags(connection_type='fake')
self.manager = manager.AuthManager(new=True) self.manager = manager.AuthManager(new=True)
self.manager.mc.cache = {}
def test_create_and_find_user(self): def test_create_and_find_user(self):
with user_generator(self.manager): with user_generator(self.manager):

View File

@@ -17,32 +17,25 @@
# under the License. # under the License.
from base64 import b64decode from base64 import b64decode
import json
from M2Crypto import BIO from M2Crypto import BIO
from M2Crypto import RSA from M2Crypto import RSA
import os import os
import shutil
import tempfile
import time
from eventlet import greenthread from eventlet import greenthread
from nova import context from nova import context
from nova import crypto from nova import crypto
from nova import db from nova import db
from nova import exception
from nova import flags from nova import flags
from nova import log as logging from nova import log as logging
from nova import rpc from nova import rpc
from nova import service
from nova import test from nova import test
from nova import utils from nova import utils
from nova import exception
from nova.auth import manager from nova.auth import manager
from nova.compute import power_state
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils from nova.api.ec2 import ec2utils
from nova.image import local from nova.image import local
from nova.exception import NotFound
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@@ -74,19 +67,26 @@ class CloudTestCase(test.TestCase):
def fake_show(meh, context, id): def fake_show(meh, context, id):
return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1, return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1,
'type': 'machine'}} 'type': 'machine', 'image_state': 'available'}}
self.stubs.Set(local.LocalImageService, 'show', fake_show) self.stubs.Set(local.LocalImageService, 'show', fake_show)
self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show) self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show)
# NOTE(vish): set up a manual wait so rpc.cast has a chance to finish
rpc_cast = rpc.cast
def finish_cast(*args, **kwargs):
rpc_cast(*args, **kwargs)
greenthread.sleep(0.2)
self.stubs.Set(rpc, 'cast', finish_cast)
def tearDown(self): def tearDown(self):
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context,
self.project.id) self.project.id)
db.network_disassociate(self.context, network_ref['id']) db.network_disassociate(self.context, network_ref['id'])
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
self.compute.kill()
self.network.kill()
super(CloudTestCase, self).tearDown() super(CloudTestCase, self).tearDown()
def _create_key(self, name): def _create_key(self, name):
@@ -113,7 +113,6 @@ class CloudTestCase(test.TestCase):
self.cloud.describe_addresses(self.context) self.cloud.describe_addresses(self.context)
self.cloud.release_address(self.context, self.cloud.release_address(self.context,
public_ip=address) public_ip=address)
greenthread.sleep(0.3)
db.floating_ip_destroy(self.context, address) db.floating_ip_destroy(self.context, address)
def test_associate_disassociate_address(self): def test_associate_disassociate_address(self):
@@ -129,12 +128,10 @@ class CloudTestCase(test.TestCase):
self.cloud.associate_address(self.context, self.cloud.associate_address(self.context,
instance_id=ec2_id, instance_id=ec2_id,
public_ip=address) public_ip=address)
greenthread.sleep(0.3)
self.cloud.disassociate_address(self.context, self.cloud.disassociate_address(self.context,
public_ip=address) public_ip=address)
self.cloud.release_address(self.context, self.cloud.release_address(self.context,
public_ip=address) public_ip=address)
greenthread.sleep(0.3)
self.network.deallocate_fixed_ip(self.context, fixed) self.network.deallocate_fixed_ip(self.context, fixed)
db.instance_destroy(self.context, inst['id']) db.instance_destroy(self.context, inst['id'])
db.floating_ip_destroy(self.context, address) db.floating_ip_destroy(self.context, address)
@@ -171,6 +168,25 @@ class CloudTestCase(test.TestCase):
db.volume_destroy(self.context, vol1['id']) db.volume_destroy(self.context, vol1['id'])
db.volume_destroy(self.context, vol2['id']) db.volume_destroy(self.context, vol2['id'])
def test_create_volume_from_snapshot(self):
"""Makes sure create_volume works when we specify a snapshot."""
vol = db.volume_create(self.context, {'size': 1})
snap = db.snapshot_create(self.context, {'volume_id': vol['id'],
'volume_size': vol['size'],
'status': "available"})
snapshot_id = ec2utils.id_to_ec2_id(snap['id'], 'snap-%08x')
result = self.cloud.create_volume(self.context,
snapshot_id=snapshot_id)
volume_id = result['volumeId']
result = self.cloud.describe_volumes(self.context)
self.assertEqual(len(result['volumeSet']), 2)
self.assertEqual(result['volumeSet'][1]['volumeId'], volume_id)
db.volume_destroy(self.context, ec2utils.ec2_id_to_id(volume_id))
db.snapshot_destroy(self.context, snap['id'])
db.volume_destroy(self.context, vol['id'])
def test_describe_availability_zones(self): def test_describe_availability_zones(self):
"""Makes sure describe_availability_zones works and filters results.""" """Makes sure describe_availability_zones works and filters results."""
service1 = db.service_create(self.context, {'host': 'host1_zones', service1 = db.service_create(self.context, {'host': 'host1_zones',
@@ -188,13 +204,59 @@ class CloudTestCase(test.TestCase):
db.service_destroy(self.context, service1['id']) db.service_destroy(self.context, service1['id'])
db.service_destroy(self.context, service2['id']) db.service_destroy(self.context, service2['id'])
def test_describe_snapshots(self):
"""Makes sure describe_snapshots works and filters results."""
vol = db.volume_create(self.context, {})
snap1 = db.snapshot_create(self.context, {'volume_id': vol['id']})
snap2 = db.snapshot_create(self.context, {'volume_id': vol['id']})
result = self.cloud.describe_snapshots(self.context)
self.assertEqual(len(result['snapshotSet']), 2)
snapshot_id = ec2utils.id_to_ec2_id(snap2['id'], 'snap-%08x')
result = self.cloud.describe_snapshots(self.context,
snapshot_id=[snapshot_id])
self.assertEqual(len(result['snapshotSet']), 1)
self.assertEqual(
ec2utils.ec2_id_to_id(result['snapshotSet'][0]['snapshotId']),
snap2['id'])
db.snapshot_destroy(self.context, snap1['id'])
db.snapshot_destroy(self.context, snap2['id'])
db.volume_destroy(self.context, vol['id'])
def test_create_snapshot(self):
"""Makes sure create_snapshot works."""
vol = db.volume_create(self.context, {'status': "available"})
volume_id = ec2utils.id_to_ec2_id(vol['id'], 'vol-%08x')
result = self.cloud.create_snapshot(self.context,
volume_id=volume_id)
snapshot_id = result['snapshotId']
result = self.cloud.describe_snapshots(self.context)
self.assertEqual(len(result['snapshotSet']), 1)
self.assertEqual(result['snapshotSet'][0]['snapshotId'], snapshot_id)
db.snapshot_destroy(self.context, ec2utils.ec2_id_to_id(snapshot_id))
db.volume_destroy(self.context, vol['id'])
def test_delete_snapshot(self):
"""Makes sure delete_snapshot works."""
vol = db.volume_create(self.context, {'status': "available"})
snap = db.snapshot_create(self.context, {'volume_id': vol['id'],
'status': "available"})
snapshot_id = ec2utils.id_to_ec2_id(snap['id'], 'snap-%08x')
result = self.cloud.delete_snapshot(self.context,
snapshot_id=snapshot_id)
self.assertTrue(result)
db.volume_destroy(self.context, vol['id'])
def test_describe_instances(self): def test_describe_instances(self):
"""Makes sure describe_instances works and filters results.""" """Makes sure describe_instances works and filters results."""
inst1 = db.instance_create(self.context, {'reservation_id': 'a', inst1 = db.instance_create(self.context, {'reservation_id': 'a',
'image_id': 1, 'image_ref': 1,
'host': 'host1'}) 'host': 'host1'})
inst2 = db.instance_create(self.context, {'reservation_id': 'a', inst2 = db.instance_create(self.context, {'reservation_id': 'a',
'image_id': 1, 'image_ref': 1,
'host': 'host2'}) 'host': 'host2'})
comp1 = db.service_create(self.context, {'host': 'host1', comp1 = db.service_create(self.context, {'host': 'host1',
'availability_zone': 'zone1', 'availability_zone': 'zone1',
@@ -227,7 +289,7 @@ class CloudTestCase(test.TestCase):
'type': 'machine'}}] 'type': 'machine'}}]
def fake_show_none(meh, context, id): def fake_show_none(meh, context, id):
raise NotFound raise exception.ImageNotFound(image_id='bad_image_id')
self.stubs.Set(local.LocalImageService, 'detail', fake_detail) self.stubs.Set(local.LocalImageService, 'detail', fake_detail)
# list all # list all
@@ -245,7 +307,7 @@ class CloudTestCase(test.TestCase):
self.stubs.UnsetAll() self.stubs.UnsetAll()
self.stubs.Set(local.LocalImageService, 'show', fake_show_none) self.stubs.Set(local.LocalImageService, 'show', fake_show_none)
self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show_none) self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show_none)
self.assertRaises(NotFound, describe_images, self.assertRaises(exception.ImageNotFound, describe_images,
self.context, ['ami-fake']) self.context, ['ami-fake'])
def test_describe_image_attribute(self): def test_describe_image_attribute(self):
@@ -306,31 +368,25 @@ class CloudTestCase(test.TestCase):
'instance_type': instance_type, 'instance_type': instance_type,
'max_count': max_count} 'max_count': max_count}
rv = self.cloud.run_instances(self.context, **kwargs) rv = self.cloud.run_instances(self.context, **kwargs)
greenthread.sleep(0.3)
instance_id = rv['instancesSet'][0]['instanceId'] instance_id = rv['instancesSet'][0]['instanceId']
output = self.cloud.get_console_output(context=self.context, output = self.cloud.get_console_output(context=self.context,
instance_id=[instance_id]) instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT') self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code # TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests. # for unit tests.
greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id]) rv = self.cloud.terminate_instances(self.context, [instance_id])
greenthread.sleep(0.3)
def test_ajax_console(self): def test_ajax_console(self):
kwargs = {'image_id': 'ami-1'} kwargs = {'image_id': 'ami-1'}
rv = self.cloud.run_instances(self.context, **kwargs) rv = self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId'] instance_id = rv['instancesSet'][0]['instanceId']
greenthread.sleep(0.3)
output = self.cloud.get_ajax_console(context=self.context, output = self.cloud.get_ajax_console(context=self.context,
instance_id=[instance_id]) instance_id=[instance_id])
self.assertEquals(output['url'], self.assertEquals(output['url'],
'%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url) '%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url)
# TODO(soren): We need this until we can stop polling in the rpc code # TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests. # for unit tests.
greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id]) rv = self.cloud.terminate_instances(self.context, [instance_id])
greenthread.sleep(0.3)
def test_key_generation(self): def test_key_generation(self):
result = self._create_key('test') result = self._create_key('test')
@@ -388,9 +444,67 @@ class CloudTestCase(test.TestCase):
self._create_key('test') self._create_key('test')
self.cloud.delete_key_pair(self.context, 'test') self.cloud.delete_key_pair(self.context, 'test')
def test_run_instances(self):
kwargs = {'image_id': FLAGS.default_image,
'instance_type': FLAGS.default_instance_type,
'max_count': 1}
run_instances = self.cloud.run_instances
result = run_instances(self.context, **kwargs)
instance = result['instancesSet'][0]
self.assertEqual(instance['imageId'], 'ami-00000001')
self.assertEqual(instance['displayName'], 'Server 1')
self.assertEqual(instance['instanceId'], 'i-00000001')
self.assertEqual(instance['instanceState']['name'], 'networking')
self.assertEqual(instance['instanceType'], 'm1.small')
def test_run_instances_image_state_none(self):
kwargs = {'image_id': FLAGS.default_image,
'instance_type': FLAGS.default_instance_type,
'max_count': 1}
run_instances = self.cloud.run_instances
def fake_show_no_state(self, context, id):
return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1,
'type': 'machine'}}
self.stubs.UnsetAll()
self.stubs.Set(local.LocalImageService, 'show', fake_show_no_state)
self.assertRaises(exception.ApiError, run_instances,
self.context, **kwargs)
def test_run_instances_image_state_invalid(self):
kwargs = {'image_id': FLAGS.default_image,
'instance_type': FLAGS.default_instance_type,
'max_count': 1}
run_instances = self.cloud.run_instances
def fake_show_decrypt(self, context, id):
return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1,
'type': 'machine', 'image_state': 'decrypting'}}
self.stubs.UnsetAll()
self.stubs.Set(local.LocalImageService, 'show', fake_show_decrypt)
self.assertRaises(exception.ApiError, run_instances,
self.context, **kwargs)
def test_run_instances_image_status_active(self):
kwargs = {'image_id': FLAGS.default_image,
'instance_type': FLAGS.default_instance_type,
'max_count': 1}
run_instances = self.cloud.run_instances
def fake_show_stat_active(self, context, id):
return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1,
'type': 'machine'}, 'status': 'active'}
self.stubs.Set(local.LocalImageService, 'show', fake_show_stat_active)
result = run_instances(self.context, **kwargs)
self.assertEqual(len(result['instancesSet']), 1)
def test_terminate_instances(self): def test_terminate_instances(self):
inst1 = db.instance_create(self.context, {'reservation_id': 'a', inst1 = db.instance_create(self.context, {'reservation_id': 'a',
'image_id': 1, 'image_ref': 1,
'host': 'host1'}) 'host': 'host1'})
terminate_instances = self.cloud.terminate_instances terminate_instances = self.cloud.terminate_instances
# valid instance_id # valid instance_id

View File

@@ -19,7 +19,6 @@
Tests For Compute Tests For Compute
""" """
import datetime
import mox import mox
import stubout import stubout
@@ -84,7 +83,7 @@ class ComputeTestCase(test.TestCase):
def _create_instance(self, params={}): def _create_instance(self, params={}):
"""Create a test instance""" """Create a test instance"""
inst = {} inst = {}
inst['image_id'] = 1 inst['image_ref'] = 1
inst['reservation_id'] = 'r-fakeres' inst['reservation_id'] = 'r-fakeres'
inst['launch_time'] = '10' inst['launch_time'] = '10'
inst['user_id'] = self.user.id inst['user_id'] = self.user.id
@@ -150,7 +149,7 @@ class ComputeTestCase(test.TestCase):
ref = self.compute_api.create( ref = self.compute_api.create(
self.context, self.context,
instance_type=instance_types.get_default_instance_type(), instance_type=instance_types.get_default_instance_type(),
image_id=None, image_href=None,
security_group=['testgroup']) security_group=['testgroup'])
try: try:
self.assertEqual(len(db.security_group_get_by_instance( self.assertEqual(len(db.security_group_get_by_instance(
@@ -168,7 +167,7 @@ class ComputeTestCase(test.TestCase):
ref = self.compute_api.create( ref = self.compute_api.create(
self.context, self.context,
instance_type=instance_types.get_default_instance_type(), instance_type=instance_types.get_default_instance_type(),
image_id=None, image_href=None,
security_group=['testgroup']) security_group=['testgroup'])
try: try:
db.instance_destroy(self.context, ref[0]['id']) db.instance_destroy(self.context, ref[0]['id'])
@@ -184,7 +183,7 @@ class ComputeTestCase(test.TestCase):
ref = self.compute_api.create( ref = self.compute_api.create(
self.context, self.context,
instance_type=instance_types.get_default_instance_type(), instance_type=instance_types.get_default_instance_type(),
image_id=None, image_href=None,
security_group=['testgroup']) security_group=['testgroup'])
try: try:
@@ -217,12 +216,12 @@ class ComputeTestCase(test.TestCase):
instance_ref = db.instance_get(self.context, instance_id) instance_ref = db.instance_get(self.context, instance_id)
self.assertEqual(instance_ref['launched_at'], None) self.assertEqual(instance_ref['launched_at'], None)
self.assertEqual(instance_ref['deleted_at'], None) self.assertEqual(instance_ref['deleted_at'], None)
launch = datetime.datetime.utcnow() launch = utils.utcnow()
self.compute.run_instance(self.context, instance_id) self.compute.run_instance(self.context, instance_id)
instance_ref = db.instance_get(self.context, instance_id) instance_ref = db.instance_get(self.context, instance_id)
self.assert_(instance_ref['launched_at'] > launch) self.assert_(instance_ref['launched_at'] > launch)
self.assertEqual(instance_ref['deleted_at'], None) self.assertEqual(instance_ref['deleted_at'], None)
terminate = datetime.datetime.utcnow() terminate = utils.utcnow()
self.compute.terminate_instance(self.context, instance_id) self.compute.terminate_instance(self.context, instance_id)
self.context = self.context.elevated(True) self.context = self.context.elevated(True)
instance_ref = db.instance_get(self.context, instance_id) instance_ref = db.instance_get(self.context, instance_id)

View File

@@ -20,8 +20,6 @@
Tests For Console proxy. Tests For Console proxy.
""" """
import datetime
from nova import context from nova import context
from nova import db from nova import db
from nova import exception from nova import exception

View File

@@ -133,13 +133,14 @@ class HostFilterTestCase(test.TestCase):
raw = ['or', raw = ['or',
['and', ['and',
['<', '$compute.host_memory_free', 30], ['<', '$compute.host_memory_free', 30],
['<', '$compute.disk_available', 300] ['<', '$compute.disk_available', 300],
], ],
['and', ['and',
['>', '$compute.host_memory_free', 70], ['>', '$compute.host_memory_free', 70],
['>', '$compute.disk_available', 700] ['>', '$compute.disk_available', 700],
] ],
] ]
cooked = json.dumps(raw) cooked = json.dumps(raw)
hosts = hf.filter_hosts(self.zone_manager, cooked) hosts = hf.filter_hosts(self.zone_manager, cooked)
@@ -183,13 +184,11 @@ class HostFilterTestCase(test.TestCase):
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps([]))) self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps([])))
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps({}))) self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps({})))
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps( self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps(
['not', True, False, True, False] ['not', True, False, True, False])))
)))
try: try:
hf.filter_hosts(self.zone_manager, json.dumps( hf.filter_hosts(self.zone_manager, json.dumps(
'not', True, False, True, False 'not', True, False, True, False))
))
self.fail("Should give KeyError") self.fail("Should give KeyError")
except KeyError, e: except KeyError, e:
pass pass

View File

@@ -14,10 +14,12 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import copy
import eventlet import eventlet
import mox import mox
import os import os
import re import re
import shutil
import sys import sys
from xml.etree.ElementTree import fromstring as xml_to_tree from xml.etree.ElementTree import fromstring as xml_to_tree
@@ -124,6 +126,7 @@ class CacheConcurrencyTestCase(test.TestCase):
class LibvirtConnTestCase(test.TestCase): class LibvirtConnTestCase(test.TestCase):
def setUp(self): def setUp(self):
super(LibvirtConnTestCase, self).setUp() super(LibvirtConnTestCase, self).setUp()
connection._late_load_cheetah() connection._late_load_cheetah()
@@ -160,6 +163,7 @@ class LibvirtConnTestCase(test.TestCase):
'vcpus': 2, 'vcpus': 2,
'project_id': 'fake', 'project_id': 'fake',
'bridge': 'br101', 'bridge': 'br101',
'image_ref': '123456',
'instance_type_id': '5'} # m1.small 'instance_type_id': '5'} # m1.small
def lazy_load_library_exists(self): def lazy_load_library_exists(self):
@@ -205,6 +209,29 @@ class LibvirtConnTestCase(test.TestCase):
self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn') self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
connection.LibvirtConnection._conn = fake connection.LibvirtConnection._conn = fake
def fake_lookup(self, instance_name):
class FakeVirtDomain(object):
def snapshotCreateXML(self, *args):
return None
def XMLDesc(self, *args):
return """
<domain type='kvm'>
<devices>
<disk type='file'>
<source file='filename'/>
</disk>
</devices>
</domain>
"""
return FakeVirtDomain()
def fake_execute(self, *args):
open(args[-1], "a").close()
def create_service(self, **kwargs): def create_service(self, **kwargs):
service_ref = {'host': kwargs.get('host', 'dummy'), service_ref = {'host': kwargs.get('host', 'dummy'),
'binary': 'nova-compute', 'binary': 'nova-compute',
@@ -280,6 +307,81 @@ class LibvirtConnTestCase(test.TestCase):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
self._check_xml_and_container(instance_data) self._check_xml_and_container(instance_data)
def test_snapshot(self):
if not self.lazy_load_library_exists():
return
FLAGS.image_service = 'nova.image.fake.FakeImageService'
# Start test
image_service = utils.import_object(FLAGS.image_service)
# Assuming that base image already exists in image_service
instance_ref = db.instance_create(self.context, self.test_instance)
properties = {'instance_id': instance_ref['id'],
'user_id': str(self.context.user_id)}
snapshot_name = 'test-snap'
sent_meta = {'name': snapshot_name, 'is_public': False,
'status': 'creating', 'properties': properties}
# Create new image. It will be updated in snapshot method
# To work with it from snapshot, the single image_service is needed
recv_meta = image_service.create(context, sent_meta)
self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
self.mox.StubOutWithMock(connection.utils, 'execute')
connection.utils.execute = self.fake_execute
self.mox.ReplayAll()
conn = connection.LibvirtConnection(False)
conn.snapshot(instance_ref, recv_meta['id'])
snapshot = image_service.show(context, recv_meta['id'])
self.assertEquals(snapshot['properties']['image_state'], 'available')
self.assertEquals(snapshot['status'], 'active')
self.assertEquals(snapshot['name'], snapshot_name)
def test_snapshot_no_image_architecture(self):
if not self.lazy_load_library_exists():
return
FLAGS.image_service = 'nova.image.fake.FakeImageService'
# Start test
image_service = utils.import_object(FLAGS.image_service)
# Assign image_ref = 2 from nova/images/fakes for testing different
# base image
test_instance = copy.deepcopy(self.test_instance)
test_instance["image_ref"] = "2"
# Assuming that base image already exists in image_service
instance_ref = db.instance_create(self.context, test_instance)
properties = {'instance_id': instance_ref['id'],
'user_id': str(self.context.user_id)}
snapshot_name = 'test-snap'
sent_meta = {'name': snapshot_name, 'is_public': False,
'status': 'creating', 'properties': properties}
# Create new image. It will be updated in snapshot method
# To work with it from snapshot, the single image_service is needed
recv_meta = image_service.create(context, sent_meta)
self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
connection.LibvirtConnection._conn.lookupByName = self.fake_lookup
self.mox.StubOutWithMock(connection.utils, 'execute')
connection.utils.execute = self.fake_execute
self.mox.ReplayAll()
conn = connection.LibvirtConnection(False)
conn.snapshot(instance_ref, recv_meta['id'])
snapshot = image_service.show(context, recv_meta['id'])
self.assertEquals(snapshot['properties']['image_state'], 'available')
self.assertEquals(snapshot['status'], 'active')
self.assertEquals(snapshot['name'], snapshot_name)
def test_multi_nic(self): def test_multi_nic(self):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
network_info = _create_network_info(2) network_info = _create_network_info(2)
@@ -645,6 +747,8 @@ class LibvirtConnTestCase(test.TestCase):
except Exception, e: except Exception, e:
count = (0 <= str(e.message).find('Unexpected method call')) count = (0 <= str(e.message).find('Unexpected method call'))
shutil.rmtree(os.path.join(FLAGS.instances_path, instance.name))
self.assertTrue(count) self.assertTrue(count)
def test_get_host_ip_addr(self): def test_get_host_ip_addr(self):
@@ -658,6 +762,31 @@ class LibvirtConnTestCase(test.TestCase):
super(LibvirtConnTestCase, self).tearDown() super(LibvirtConnTestCase, self).tearDown()
class NWFilterFakes:
def __init__(self):
self.filters = {}
def nwfilterLookupByName(self, name):
if name in self.filters:
return self.filters[name]
raise libvirt.libvirtError('Filter Not Found')
def filterDefineXMLMock(self, xml):
class FakeNWFilterInternal:
def __init__(self, parent, name):
self.name = name
self.parent = parent
def undefine(self):
del self.parent.filters[self.name]
pass
tree = xml_to_tree(xml)
name = tree.get('name')
if name not in self.filters:
self.filters[name] = FakeNWFilterInternal(self, name)
return True
class IptablesFirewallTestCase(test.TestCase): class IptablesFirewallTestCase(test.TestCase):
def setUp(self): def setUp(self):
super(IptablesFirewallTestCase, self).setUp() super(IptablesFirewallTestCase, self).setUp()
@@ -675,6 +804,20 @@ class IptablesFirewallTestCase(test.TestCase):
self.fw = firewall.IptablesFirewallDriver( self.fw = firewall.IptablesFirewallDriver(
get_connection=lambda: self.fake_libvirt_connection) get_connection=lambda: self.fake_libvirt_connection)
def lazy_load_library_exists(self):
"""check if libvirt is available."""
# try to connect libvirt. if fail, skip test.
try:
import libvirt
import libxml2
except ImportError:
return False
global libvirt
libvirt = __import__('libvirt')
connection.libvirt = __import__('libvirt')
connection.libxml2 = __import__('libxml2')
return True
def tearDown(self): def tearDown(self):
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
@@ -880,6 +1023,40 @@ class IptablesFirewallTestCase(test.TestCase):
self.mox.ReplayAll() self.mox.ReplayAll()
self.fw.do_refresh_security_group_rules("fake") self.fw.do_refresh_security_group_rules("fake")
def test_unfilter_instance_undefines_nwfilter(self):
# Skip if non-libvirt environment
if not self.lazy_load_library_exists():
return
admin_ctxt = context.get_admin_context()
fakefilter = NWFilterFakes()
self.fw.nwfilter._conn.nwfilterDefineXML =\
fakefilter.filterDefineXMLMock
self.fw.nwfilter._conn.nwfilterLookupByName =\
fakefilter.nwfilterLookupByName
instance_ref = self._create_instance_ref()
inst_id = instance_ref['id']
instance = db.instance_get(self.context, inst_id)
ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, 'fake')
fixed_ip = {'address': ip, 'network_id': network_ref['id']}
db.fixed_ip_create(admin_ctxt, fixed_ip)
db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
'instance_id': inst_id})
self.fw.setup_basic_filtering(instance)
self.fw.prepare_instance_filter(instance)
self.fw.apply_instance_filter(instance)
original_filter_count = len(fakefilter.filters)
self.fw.unfilter_instance(instance)
# should undefine just the instance filter
self.assertEqual(original_filter_count - len(fakefilter.filters), 1)
db.instance_destroy(admin_ctxt, instance_ref['id'])
class NWFilterTestCase(test.TestCase): class NWFilterTestCase(test.TestCase):
def setUp(self): def setUp(self):
@@ -1056,3 +1233,37 @@ class NWFilterTestCase(test.TestCase):
network_info, network_info,
"fake") "fake")
self.assertEquals(len(result), 3) self.assertEquals(len(result), 3)
def test_unfilter_instance_undefines_nwfilters(self):
admin_ctxt = context.get_admin_context()
fakefilter = NWFilterFakes()
self.fw._conn.nwfilterDefineXML = fakefilter.filterDefineXMLMock
self.fw._conn.nwfilterLookupByName = fakefilter.nwfilterLookupByName
instance_ref = self._create_instance()
inst_id = instance_ref['id']
self.security_group = self.setup_and_return_security_group()
db.instance_add_security_group(self.context, inst_id,
self.security_group.id)
instance = db.instance_get(self.context, inst_id)
ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, 'fake')
fixed_ip = {'address': ip, 'network_id': network_ref['id']}
db.fixed_ip_create(admin_ctxt, fixed_ip)
db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
'instance_id': inst_id})
self.fw.setup_basic_filtering(instance)
self.fw.prepare_instance_filter(instance)
self.fw.apply_instance_filter(instance)
original_filter_count = len(fakefilter.filters)
self.fw.unfilter_instance(instance)
# should undefine 2 filters: instance and instance-secgroup
self.assertEqual(original_filter_count - len(fakefilter.filters), 2)
db.instance_destroy(admin_ctxt, instance_ref['id'])

View File

@@ -16,7 +16,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import datetime
import webob import webob
import webob.dec import webob.dec
import webob.exc import webob.exc

View File

@@ -21,11 +21,24 @@ import select
from eventlet import greenpool from eventlet import greenpool
from eventlet import greenthread from eventlet import greenthread
from nova import exception
from nova import test from nova import test
from nova import utils from nova import utils
from nova.utils import parse_mailmap, str_dict_replace from nova.utils import parse_mailmap, str_dict_replace
class ExceptionTestCase(test.TestCase):
@staticmethod
def _raise_exc(exc):
raise exc()
def test_exceptions_raise(self):
for name in dir(exception):
exc = getattr(exception, name)
if isinstance(exc, type):
self.assertRaises(exc, self._raise_exc, exc)
class ProjectTestCase(test.TestCase): class ProjectTestCase(test.TestCase):
def test_authors_up_to_date(self): def test_authors_up_to_date(self):
topdir = os.path.normpath(os.path.dirname(__file__) + '/../../') topdir = os.path.normpath(os.path.dirname(__file__) + '/../../')

View File

@@ -13,10 +13,12 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import nova import stubout
import nova
from nova import context from nova import context
from nova import flags from nova import flags
from nova import log
from nova import rpc from nova import rpc
import nova.notifier.api import nova.notifier.api
from nova.notifier.api import notify from nova.notifier.api import notify
@@ -24,8 +26,6 @@ from nova.notifier import no_op_notifier
from nova.notifier import rabbit_notifier from nova.notifier import rabbit_notifier
from nova import test from nova import test
import stubout
class NotifierTestCase(test.TestCase): class NotifierTestCase(test.TestCase):
"""Test case for notifications""" """Test case for notifications"""
@@ -115,3 +115,22 @@ class NotifierTestCase(test.TestCase):
notify('publisher_id', notify('publisher_id',
'event_type', 'DEBUG', dict(a=3)) 'event_type', 'DEBUG', dict(a=3))
self.assertEqual(self.test_topic, 'testnotify.debug') self.assertEqual(self.test_topic, 'testnotify.debug')
def test_error_notification(self):
self.stubs.Set(nova.flags.FLAGS, 'notification_driver',
'nova.notifier.rabbit_notifier')
self.stubs.Set(nova.flags.FLAGS, 'publish_errors', True)
LOG = log.getLogger('nova')
LOG.setup_from_flags()
msgs = []
def mock_cast(context, topic, data):
msgs.append(data)
self.stubs.Set(nova.rpc, 'cast', mock_cast)
LOG.error('foo')
self.assertEqual(1, len(msgs))
msg = msgs[0]
self.assertEqual(msg['event_type'], 'error_notification')
self.assertEqual(msg['priority'], 'ERROR')
self.assertEqual(msg['payload']['error'], 'foo')

View File

@@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc')
class RpcTestCase(test.TestCase): class RpcTestCase(test.TestCase):
"""Test cases for rpc"""
def setUp(self): def setUp(self):
super(RpcTestCase, self).setUp() super(RpcTestCase, self).setUp()
self.conn = rpc.Connection.instance(True) self.conn = rpc.Connection.instance(True)
@@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase):
self.context = context.get_admin_context() self.context = context.get_admin_context()
def test_call_succeed(self): def test_call_succeed(self):
"""Get a value through rpc call"""
value = 42 value = 42
result = rpc.call(self.context, 'test', {"method": "echo", result = rpc.call(self.context, 'test', {"method": "echo",
"args": {"value": value}}) "args": {"value": value}})
self.assertEqual(value, result) self.assertEqual(value, result)
def test_call_succeed_despite_multiple_returns(self):
value = 42
result = rpc.call(self.context, 'test', {"method": "echo_three_times",
"args": {"value": value}})
self.assertEqual(value + 2, result)
def test_call_succeed_despite_multiple_returns_yield(self):
value = 42
result = rpc.call(self.context, 'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
self.assertEqual(value + 2, result)
def test_multicall_succeed_once(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo",
"args": {"value": value}})
for i, x in enumerate(result):
if i > 0:
self.fail('should only receive one response')
self.assertEqual(value + i, x)
def test_multicall_succeed_three_times(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo_three_times",
"args": {"value": value}})
for i, x in enumerate(result):
self.assertEqual(value + i, x)
def test_multicall_succeed_three_times_yield(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
for i, x in enumerate(result):
self.assertEqual(value + i, x)
def test_context_passed(self): def test_context_passed(self):
"""Makes sure a context is passed through rpc call""" """Makes sure a context is passed through rpc call."""
value = 42 value = 42
result = rpc.call(self.context, result = rpc.call(self.context,
'test', {"method": "context", 'test', {"method": "context",
@@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase):
self.assertEqual(self.context.to_dict(), result) self.assertEqual(self.context.to_dict(), result)
def test_call_exception(self): def test_call_exception(self):
"""Test that exception gets passed back properly """Test that exception gets passed back properly.
rpc.call returns a RemoteError object. The value of the rpc.call returns a RemoteError object. The value of the
exception is converted to a string, so we convert it back exception is converted to a string, so we convert it back
to an int in the test. to an int in the test.
""" """
value = 42 value = 42
self.assertRaises(rpc.RemoteError, self.assertRaises(rpc.RemoteError,
@@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase):
self.assertEqual(int(exc.value), value) self.assertEqual(int(exc.value), value)
def test_nested_calls(self): def test_nested_calls(self):
"""Test that we can do an rpc.call inside another call""" """Test that we can do an rpc.call inside another call."""
class Nested(object): class Nested(object):
@staticmethod @staticmethod
def echo(context, queue, value): def echo(context, queue, value):
@@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase):
"value": value}}) "value": value}})
self.assertEqual(value, result) self.assertEqual(value, result)
def test_connectionpool_single(self):
"""Test that ConnectionPool recycles a single connection."""
conn1 = rpc.ConnectionPool.get()
rpc.ConnectionPool.put(conn1)
conn2 = rpc.ConnectionPool.get()
rpc.ConnectionPool.put(conn2)
self.assertEqual(conn1, conn2)
def test_connectionpool_double(self):
"""Test that ConnectionPool returns and reuses separate connections.
When called consecutively we should get separate connections and upon
returning them those connections should be reused for future calls
before generating a new connection.
"""
conn1 = rpc.ConnectionPool.get()
conn2 = rpc.ConnectionPool.get()
self.assertNotEqual(conn1, conn2)
rpc.ConnectionPool.put(conn1)
rpc.ConnectionPool.put(conn2)
conn3 = rpc.ConnectionPool.get()
conn4 = rpc.ConnectionPool.get()
self.assertEqual(conn1, conn3)
self.assertEqual(conn2, conn4)
def test_connectionpool_limit(self):
"""Test connection pool limit and connection uniqueness."""
max_size = FLAGS.rpc_conn_pool_size
conns = []
for i in xrange(max_size):
conns.append(rpc.ConnectionPool.get())
self.assertFalse(rpc.ConnectionPool.free_items)
self.assertEqual(rpc.ConnectionPool.current_size,
rpc.ConnectionPool.max_size)
self.assertEqual(len(set(conns)), max_size)
class TestReceiver(object): class TestReceiver(object):
"""Simple Proxy class so the consumer has methods to call """Simple Proxy class so the consumer has methods to call.
Uses static methods because we aren't actually storing any state""" Uses static methods because we aren't actually storing any state.
"""
@staticmethod @staticmethod
def echo(context, value): def echo(context, value):
"""Simply returns whatever value is sent in""" """Simply returns whatever value is sent in."""
LOG.debug(_("Received %s"), value) LOG.debug(_("Received %s"), value)
return value return value
@staticmethod @staticmethod
def context(context, value): def context(context, value):
"""Returns dictionary version of context""" """Returns dictionary version of context."""
LOG.debug(_("Received %s"), context) LOG.debug(_("Received %s"), context)
return context.to_dict() return context.to_dict()
@staticmethod
def echo_three_times(context, value):
context.reply(value)
context.reply(value + 1)
context.reply(value + 2)
@staticmethod
def echo_three_times_yield(context, value):
yield value
yield value + 1
yield value + 2
@staticmethod @staticmethod
def fail(context, value): def fail(context, value):
"""Raises an exception with the value sent in""" """Raises an exception with the value sent in."""
raise Exception(value) raise Exception(value)

File diff suppressed because it is too large Load Diff

View File

@@ -55,8 +55,7 @@ class VMWareAPIVMTestCase(test.TestCase):
vmwareapi_fake.reset() vmwareapi_fake.reset()
db_fakes.stub_out_db_instance_api(self.stubs) db_fakes.stub_out_db_instance_api(self.stubs)
stubs.set_stubs(self.stubs) stubs.set_stubs(self.stubs)
glance_stubs.stubout_glance_client(self.stubs, glance_stubs.stubout_glance_client(self.stubs)
glance_stubs.FakeGlance)
self.conn = vmwareapi_conn.get_connection(False) self.conn = vmwareapi_conn.get_connection(False)
def _create_instance_in_the_db(self): def _create_instance_in_the_db(self):
@@ -64,7 +63,7 @@ class VMWareAPIVMTestCase(test.TestCase):
'id': 1, 'id': 1,
'project_id': self.project.id, 'project_id': self.project.id,
'user_id': self.user.id, 'user_id': self.user.id,
'image_id': "1", 'image_ref': "1",
'kernel_id': "1", 'kernel_id': "1",
'ramdisk_id': "1", 'ramdisk_id': "1",
'instance_type': 'm1.large', 'instance_type': 'm1.large',

View File

@@ -45,10 +45,11 @@ class VolumeTestCase(test.TestCase):
self.context = context.get_admin_context() self.context = context.get_admin_context()
@staticmethod @staticmethod
def _create_volume(size='0'): def _create_volume(size='0', snapshot_id=None):
"""Create a volume object.""" """Create a volume object."""
vol = {} vol = {}
vol['size'] = size vol['size'] = size
vol['snapshot_id'] = snapshot_id
vol['user_id'] = 'fake' vol['user_id'] = 'fake'
vol['project_id'] = 'fake' vol['project_id'] = 'fake'
vol['availability_zone'] = FLAGS.storage_availability_zone vol['availability_zone'] = FLAGS.storage_availability_zone
@@ -69,6 +70,25 @@ class VolumeTestCase(test.TestCase):
self.context, self.context,
volume_id) volume_id)
def test_create_volume_from_snapshot(self):
"""Test volume can be created from a snapshot."""
volume_src_id = self._create_volume()
self.volume.create_volume(self.context, volume_src_id)
snapshot_id = self._create_snapshot(volume_src_id)
self.volume.create_snapshot(self.context, volume_src_id, snapshot_id)
volume_dst_id = self._create_volume(0, snapshot_id)
self.volume.create_volume(self.context, volume_dst_id, snapshot_id)
self.assertEqual(volume_dst_id, db.volume_get(
context.get_admin_context(),
volume_dst_id).id)
self.assertEqual(snapshot_id, db.volume_get(
context.get_admin_context(),
volume_dst_id).snapshot_id)
self.volume.delete_volume(self.context, volume_dst_id)
self.volume.delete_snapshot(self.context, snapshot_id)
self.volume.delete_volume(self.context, volume_src_id)
def test_too_big_volume(self): def test_too_big_volume(self):
"""Ensure failure if a too large of a volume is requested.""" """Ensure failure if a too large of a volume is requested."""
# FIXME(vish): validation needs to move into the data layer in # FIXME(vish): validation needs to move into the data layer in
@@ -176,6 +196,34 @@ class VolumeTestCase(test.TestCase):
# This will allow us to test cross-node interactions # This will allow us to test cross-node interactions
pass pass
@staticmethod
def _create_snapshot(volume_id, size='0'):
"""Create a snapshot object."""
snap = {}
snap['volume_size'] = size
snap['user_id'] = 'fake'
snap['project_id'] = 'fake'
snap['volume_id'] = volume_id
snap['status'] = "creating"
return db.snapshot_create(context.get_admin_context(), snap)['id']
def test_create_delete_snapshot(self):
"""Test snapshot can be created and deleted."""
volume_id = self._create_volume()
self.volume.create_volume(self.context, volume_id)
snapshot_id = self._create_snapshot(volume_id)
self.volume.create_snapshot(self.context, volume_id, snapshot_id)
self.assertEqual(snapshot_id,
db.snapshot_get(context.get_admin_context(),
snapshot_id).id)
self.volume.delete_snapshot(self.context, snapshot_id)
self.assertRaises(exception.NotFound,
db.snapshot_get,
self.context,
snapshot_id)
self.volume.delete_volume(self.context, volume_id)
class DriverTestCase(test.TestCase): class DriverTestCase(test.TestCase):
"""Base Test class for Drivers.""" """Base Test class for Drivers."""

View File

@@ -79,7 +79,7 @@ class XenAPIVolumeTestCase(test.TestCase):
self.values = {'id': 1, self.values = {'id': 1,
'project_id': 'fake', 'project_id': 'fake',
'user_id': 'fake', 'user_id': 'fake',
'image_id': 1, 'image_ref': 1,
'kernel_id': 2, 'kernel_id': 2,
'ramdisk_id': 3, 'ramdisk_id': 3,
'instance_type_id': '3', # m1.large 'instance_type_id': '3', # m1.large
@@ -193,8 +193,7 @@ class XenAPIVMTestCase(test.TestCase):
stubs.stubout_is_vdi_pv(self.stubs) stubs.stubout_is_vdi_pv(self.stubs)
self.stubs.Set(VMOps, 'reset_network', reset_network) self.stubs.Set(VMOps, 'reset_network', reset_network)
stubs.stub_out_vm_methods(self.stubs) stubs.stub_out_vm_methods(self.stubs)
glance_stubs.stubout_glance_client(self.stubs, glance_stubs.stubout_glance_client(self.stubs)
glance_stubs.FakeGlance)
fake_utils.stub_out_utils_execute(self.stubs) fake_utils.stub_out_utils_execute(self.stubs)
self.context = context.RequestContext('fake', 'fake', False) self.context = context.RequestContext('fake', 'fake', False)
self.conn = xenapi_conn.get_connection(False) self.conn = xenapi_conn.get_connection(False)
@@ -207,7 +206,7 @@ class XenAPIVMTestCase(test.TestCase):
'id': id, 'id': id,
'project_id': proj, 'project_id': proj,
'user_id': user, 'user_id': user,
'image_id': 1, 'image_ref': 1,
'kernel_id': 2, 'kernel_id': 2,
'ramdisk_id': 3, 'ramdisk_id': 3,
'instance_type_id': '3', # m1.large 'instance_type_id': '3', # m1.large
@@ -351,14 +350,14 @@ class XenAPIVMTestCase(test.TestCase):
self.assertEquals(self.vm['HVM_boot_params'], {}) self.assertEquals(self.vm['HVM_boot_params'], {})
self.assertEquals(self.vm['HVM_boot_policy'], '') self.assertEquals(self.vm['HVM_boot_policy'], '')
def _test_spawn(self, image_id, kernel_id, ramdisk_id, def _test_spawn(self, image_ref, kernel_id, ramdisk_id,
instance_type_id="3", os_type="linux", instance_type_id="3", os_type="linux",
instance_id=1, check_injection=False): instance_id=1, check_injection=False):
stubs.stubout_loopingcall_start(self.stubs) stubs.stubout_loopingcall_start(self.stubs)
values = {'id': instance_id, values = {'id': instance_id,
'project_id': self.project.id, 'project_id': self.project.id,
'user_id': self.user.id, 'user_id': self.user.id,
'image_id': image_id, 'image_ref': image_ref,
'kernel_id': kernel_id, 'kernel_id': kernel_id,
'ramdisk_id': ramdisk_id, 'ramdisk_id': ramdisk_id,
'instance_type_id': instance_type_id, 'instance_type_id': instance_type_id,
@@ -395,6 +394,29 @@ class XenAPIVMTestCase(test.TestCase):
os_type="linux") os_type="linux")
self.check_vm_params_for_linux() self.check_vm_params_for_linux()
def test_spawn_vhd_glance_swapdisk(self):
# Change the default host_call_plugin to one that'll return
# a swap disk
orig_func = stubs.FakeSessionForVMTests.host_call_plugin
stubs.FakeSessionForVMTests.host_call_plugin = \
stubs.FakeSessionForVMTests.host_call_plugin_swap
try:
# We'll steal the above glance linux test
self.test_spawn_vhd_glance_linux()
finally:
# Make sure to put this back
stubs.FakeSessionForVMTests.host_call_plugin = orig_func
# We should have 2 VBDs.
self.assertEqual(len(self.vm['VBDs']), 2)
# Now test that we have 1.
self.tearDown()
self.setUp()
self.test_spawn_vhd_glance_linux()
self.assertEqual(len(self.vm['VBDs']), 1)
def test_spawn_vhd_glance_windows(self): def test_spawn_vhd_glance_windows(self):
FLAGS.xenapi_image_service = 'glance' FLAGS.xenapi_image_service = 'glance'
self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None, self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None,
@@ -544,7 +566,7 @@ class XenAPIVMTestCase(test.TestCase):
'id': 1, 'id': 1,
'project_id': self.project.id, 'project_id': self.project.id,
'user_id': self.user.id, 'user_id': self.user.id,
'image_id': 1, 'image_ref': 1,
'kernel_id': 2, 'kernel_id': 2,
'ramdisk_id': 3, 'ramdisk_id': 3,
'instance_type_id': '3', # m1.large 'instance_type_id': '3', # m1.large
@@ -569,11 +591,29 @@ class XenAPIDiffieHellmanTestCase(test.TestCase):
bob_shared = self.bob.compute_shared(alice_pub) bob_shared = self.bob.compute_shared(alice_pub)
self.assertEquals(alice_shared, bob_shared) self.assertEquals(alice_shared, bob_shared)
def test_encryption(self): def _test_encryption(self, message):
msg = "This is a top-secret message" enc = self.alice.encrypt(message)
enc = self.alice.encrypt(msg) self.assertFalse(enc.endswith('\n'))
dec = self.bob.decrypt(enc) dec = self.bob.decrypt(enc)
self.assertEquals(dec, msg) self.assertEquals(dec, message)
def test_encrypt_simple_message(self):
self._test_encryption('This is a simple message.')
def test_encrypt_message_with_newlines_at_end(self):
self._test_encryption('This message has a newline at the end.\n')
def test_encrypt_many_newlines_at_end(self):
self._test_encryption('Message with lotsa newlines.\n\n\n')
def test_encrypt_newlines_inside_message(self):
self._test_encryption('Message\nwith\ninterior\nnewlines.')
def test_encrypt_with_leading_newlines(self):
self._test_encryption('\n\nMessage with leading newlines.')
def test_encrypt_really_long_message(self):
self._test_encryption(''.join(['abcd' for i in xrange(1024)]))
def tearDown(self): def tearDown(self):
super(XenAPIDiffieHellmanTestCase, self).tearDown() super(XenAPIDiffieHellmanTestCase, self).tearDown()
@@ -600,7 +640,7 @@ class XenAPIMigrateInstance(test.TestCase):
self.values = {'id': 1, self.values = {'id': 1,
'project_id': self.project.id, 'project_id': self.project.id,
'user_id': self.user.id, 'user_id': self.user.id,
'image_id': 1, 'image_ref': 1,
'kernel_id': None, 'kernel_id': None,
'ramdisk_id': None, 'ramdisk_id': None,
'local_gb': 5, 'local_gb': 5,
@@ -611,8 +651,7 @@ class XenAPIMigrateInstance(test.TestCase):
fake_utils.stub_out_utils_execute(self.stubs) fake_utils.stub_out_utils_execute(self.stubs)
stubs.stub_out_migration_methods(self.stubs) stubs.stub_out_migration_methods(self.stubs)
stubs.stubout_get_this_vm_uuid(self.stubs) stubs.stubout_get_this_vm_uuid(self.stubs)
glance_stubs.stubout_glance_client(self.stubs, glance_stubs.stubout_glance_client(self.stubs)
glance_stubs.FakeGlance)
def tearDown(self): def tearDown(self):
super(XenAPIMigrateInstance, self).tearDown() super(XenAPIMigrateInstance, self).tearDown()
@@ -638,8 +677,7 @@ class XenAPIDetermineDiskImageTestCase(test.TestCase):
"""Unit tests for code that detects the ImageType.""" """Unit tests for code that detects the ImageType."""
def setUp(self): def setUp(self):
super(XenAPIDetermineDiskImageTestCase, self).setUp() super(XenAPIDetermineDiskImageTestCase, self).setUp()
glance_stubs.stubout_glance_client(self.stubs, glance_stubs.stubout_glance_client(self.stubs)
glance_stubs.FakeGlance)
class FakeInstance(object): class FakeInstance(object):
pass pass
@@ -656,7 +694,7 @@ class XenAPIDetermineDiskImageTestCase(test.TestCase):
def test_instance_disk(self): def test_instance_disk(self):
"""If a kernel is specified, the image type is DISK (aka machine).""" """If a kernel is specified, the image type is DISK (aka machine)."""
FLAGS.xenapi_image_service = 'objectstore' FLAGS.xenapi_image_service = 'objectstore'
self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_MACHINE self.fake_instance.image_ref = glance_stubs.FakeGlance.IMAGE_MACHINE
self.fake_instance.kernel_id = glance_stubs.FakeGlance.IMAGE_KERNEL self.fake_instance.kernel_id = glance_stubs.FakeGlance.IMAGE_KERNEL
self.assert_disk_type(vm_utils.ImageType.DISK) self.assert_disk_type(vm_utils.ImageType.DISK)
@@ -666,7 +704,7 @@ class XenAPIDetermineDiskImageTestCase(test.TestCase):
DISK_RAW is assumed. DISK_RAW is assumed.
""" """
FLAGS.xenapi_image_service = 'objectstore' FLAGS.xenapi_image_service = 'objectstore'
self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_RAW self.fake_instance.image_ref = glance_stubs.FakeGlance.IMAGE_RAW
self.fake_instance.kernel_id = None self.fake_instance.kernel_id = None
self.assert_disk_type(vm_utils.ImageType.DISK_RAW) self.assert_disk_type(vm_utils.ImageType.DISK_RAW)
@@ -676,7 +714,7 @@ class XenAPIDetermineDiskImageTestCase(test.TestCase):
this case will be 'raw'. this case will be 'raw'.
""" """
FLAGS.xenapi_image_service = 'glance' FLAGS.xenapi_image_service = 'glance'
self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_RAW self.fake_instance.image_ref = glance_stubs.FakeGlance.IMAGE_RAW
self.fake_instance.kernel_id = None self.fake_instance.kernel_id = None
self.assert_disk_type(vm_utils.ImageType.DISK_RAW) self.assert_disk_type(vm_utils.ImageType.DISK_RAW)
@@ -686,7 +724,7 @@ class XenAPIDetermineDiskImageTestCase(test.TestCase):
this case will be 'vhd'. this case will be 'vhd'.
""" """
FLAGS.xenapi_image_service = 'glance' FLAGS.xenapi_image_service = 'glance'
self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_VHD self.fake_instance.image_ref = glance_stubs.FakeGlance.IMAGE_VHD
self.fake_instance.kernel_id = None self.fake_instance.kernel_id = None
self.assert_disk_type(vm_utils.ImageType.DISK_VHD) self.assert_disk_type(vm_utils.ImageType.DISK_VHD)

View File

@@ -61,7 +61,7 @@ def stub_out_db_instance_api(stubs):
'name': values['name'], 'name': values['name'],
'id': values['id'], 'id': values['id'],
'reservation_id': utils.generate_uid('r'), 'reservation_id': utils.generate_uid('r'),
'image_id': values['image_id'], 'image_ref': values['image_ref'],
'kernel_id': values['kernel_id'], 'kernel_id': values['kernel_id'],
'ramdisk_id': values['ramdisk_id'], 'ramdisk_id': values['ramdisk_id'],
'state_description': 'scheduling', 'state_description': 'scheduling',

View File

@@ -17,6 +17,7 @@
"""Stubouts, mocks and fixtures for the test suite""" """Stubouts, mocks and fixtures for the test suite"""
import eventlet import eventlet
import json
from nova.virt import xenapi_conn from nova.virt import xenapi_conn
from nova.virt.xenapi import fake from nova.virt.xenapi import fake
from nova.virt.xenapi import volume_utils from nova.virt.xenapi import volume_utils
@@ -37,21 +38,7 @@ def stubout_instance_snapshot(stubs):
sr_ref=sr_ref, sharable=False) sr_ref=sr_ref, sharable=False)
vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref) vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref)
vdi_uuid = vdi_rec['uuid'] vdi_uuid = vdi_rec['uuid']
return vdi_uuid return [dict(vdi_type='os', vdi_uuid=vdi_uuid)]
stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
def fake_wait_for_vhd_coalesce(session, instance_id, sr_ref, vdi_ref,
original_parent_uuid):
from nova.virt.xenapi.fake import create_vdi
name_label = "instance-%s" % instance_id
#TODO: create fake SR record
sr_ref = "fakesr"
vdi_ref = create_vdi(name_label=name_label, read_only=False,
sr_ref=sr_ref, sharable=False)
vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref)
vdi_uuid = vdi_rec['uuid']
return vdi_uuid
stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image) stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
@@ -132,11 +119,30 @@ class FakeSessionForVMTests(fake.SessionBase):
def __init__(self, uri): def __init__(self, uri):
super(FakeSessionForVMTests, self).__init__(uri) super(FakeSessionForVMTests, self).__init__(uri)
def host_call_plugin(self, _1, _2, _3, _4, _5): def host_call_plugin(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0] sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False) vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref) vdi_rec = fake.get_record('VDI', vdi_ref)
return '<string>%s</string>' % vdi_rec['uuid'] if plugin == "glance" and method == "download_vhd":
ret_str = json.dumps([dict(vdi_type='os',
vdi_uuid=vdi_rec['uuid'])])
else:
ret_str = vdi_rec['uuid']
return '<string>%s</string>' % ret_str
def host_call_plugin_swap(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref)
if plugin == "glance" and method == "download_vhd":
swap_vdi_ref = fake.create_vdi('', False, sr_ref, False)
swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref)
ret_str = json.dumps(
[dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']),
dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])])
else:
ret_str = vdi_rec['uuid']
return '<string>%s</string>' % ret_str
def VM_start(self, _1, ref, _2, _3): def VM_start(self, _1, ref, _2, _3):
vm = fake.get_record('VM', ref) vm = fake.get_record('VM', ref)
@@ -231,10 +237,10 @@ class FakeSessionForMigrationTests(fake.SessionBase):
def __init__(self, uri): def __init__(self, uri):
super(FakeSessionForMigrationTests, self).__init__(uri) super(FakeSessionForMigrationTests, self).__init__(uri)
def VDI_get_by_uuid(*args): def VDI_get_by_uuid(self, *args):
return 'hurr' return 'hurr'
def VDI_resize_online(*args): def VDI_resize_online(self, *args):
pass pass
def VM_start(self, _1, ref, _2, _3): def VM_start(self, _1, ref, _2, _3):

View File

@@ -78,7 +78,7 @@ def WrapTwistedOptions(wrapped):
self._absorbParameters() self._absorbParameters()
self._absorbHandlers() self._absorbHandlers()
super(TwistedOptionsToFlags, self).__init__() wrapped.__init__(self)
def _absorbFlags(self): def _absorbFlags(self):
twistd_flags = [] twistd_flags = []
@@ -163,12 +163,12 @@ def WrapTwistedOptions(wrapped):
def parseArgs(self, *args): def parseArgs(self, *args):
# TODO(termie): figure out a decent way of dealing with args # TODO(termie): figure out a decent way of dealing with args
#return #return
super(TwistedOptionsToFlags, self).parseArgs(*args) wrapped.parseArgs(self, *args)
def postOptions(self): def postOptions(self):
self._doHandlers() self._doHandlers()
super(TwistedOptionsToFlags, self).postOptions() wrapped.postOptions(self)
def __getitem__(self, key): def __getitem__(self, key):
key = key.replace('-', '_') key = key.replace('-', '_')