Merge w/ trunk.

This commit is contained in:
Dan Prince
2011-05-26 09:05:14 -04:00
43 changed files with 2111 additions and 511 deletions

View File

@@ -4,6 +4,7 @@
<anotherjesse@gmail.com> <jesse@dancelamb> <anotherjesse@gmail.com> <jesse@dancelamb>
<anotherjesse@gmail.com> <jesse@gigantor.local> <anotherjesse@gmail.com> <jesse@gigantor.local>
<anotherjesse@gmail.com> <jesse@ubuntu> <anotherjesse@gmail.com> <jesse@ubuntu>
<anotherjesse@gmail.com> <jesse@aire.local>
<ant@openstack.org> <amesserl@rackspace.com> <ant@openstack.org> <amesserl@rackspace.com>
<Armando.Migliaccio@eu.citrix.com> <armando.migliaccio@citrix.com> <Armando.Migliaccio@eu.citrix.com> <armando.migliaccio@citrix.com>
<brian.lamar@rackspace.com> <brian.lamar@gmail.com> <brian.lamar@rackspace.com> <brian.lamar@gmail.com>
@@ -28,6 +29,7 @@
<matt.dietz@rackspace.com> <matthewdietz@Matthew-Dietzs-MacBook-Pro.local> <matt.dietz@rackspace.com> <matthewdietz@Matthew-Dietzs-MacBook-Pro.local>
<matt.dietz@rackspace.com> <mdietz@openstack> <matt.dietz@rackspace.com> <mdietz@openstack>
<mordred@inaugust.com> <mordred@hudson> <mordred@inaugust.com> <mordred@hudson>
<naveedm9@gmail.com> <naveed.massjouni@rackspace.com>
<nirmal.ranganathan@rackspace.com> <nirmal.ranganathan@rackspace.coom> <nirmal.ranganathan@rackspace.com> <nirmal.ranganathan@rackspace.coom>
<paul@openstack.org> <paul.voccio@rackspace.com> <paul@openstack.org> <paul.voccio@rackspace.com>
<paul@openstack.org> <pvoccio@castor.local> <paul@openstack.org> <pvoccio@castor.local>
@@ -35,6 +37,7 @@
<rlane@wikimedia.org> <laner@controller> <rlane@wikimedia.org> <laner@controller>
<sleepsonthefloor@gmail.com> <root@tonbuntu> <sleepsonthefloor@gmail.com> <root@tonbuntu>
<soren.hansen@rackspace.com> <soren@linux2go.dk> <soren.hansen@rackspace.com> <soren@linux2go.dk>
<throughnothing@gmail.com> <will.wolf@rackspace.com>
<todd@ansolabs.com> <todd@lapex> <todd@ansolabs.com> <todd@lapex>
<todd@ansolabs.com> <todd@rubidine.com> <todd@ansolabs.com> <todd@rubidine.com>
<tushar.vitthal.patil@gmail.com> <tpatil@vertex.co.in> <tushar.vitthal.patil@gmail.com> <tpatil@vertex.co.in>
@@ -43,5 +46,4 @@
<ueno.nachi@lab.ntt.co.jp> <openstack@lab.ntt.co.jp> <ueno.nachi@lab.ntt.co.jp> <openstack@lab.ntt.co.jp>
<vishvananda@gmail.com> <root@mirror.nasanebula.net> <vishvananda@gmail.com> <root@mirror.nasanebula.net>
<vishvananda@gmail.com> <root@ubuntu> <vishvananda@gmail.com> <root@ubuntu>
<naveedm9@gmail.com> <naveed.massjouni@rackspace.com>
<vishvananda@gmail.com> <vishvananda@yahoo.com> <vishvananda@gmail.com> <vishvananda@yahoo.com>

12
Authors
View File

@@ -1,3 +1,5 @@
Alex Meade <alex.meade@rackspace.com>
Andrey Brindeyev <abrindeyev@griddynamics.com>
Andy Smith <code@term.ie> Andy Smith <code@term.ie>
Andy Southgate <andy.southgate@citrix.com> Andy Southgate <andy.southgate@citrix.com>
Anne Gentle <anne@openstack.org> Anne Gentle <anne@openstack.org>
@@ -15,6 +17,7 @@ Christian Berendt <berendt@b1-systems.de>
Chuck Short <zulcss@ubuntu.com> Chuck Short <zulcss@ubuntu.com>
Cory Wright <corywright@gmail.com> Cory Wright <corywright@gmail.com>
Dan Prince <dan.prince@rackspace.com> Dan Prince <dan.prince@rackspace.com>
Dave Walker <DaveWalker@ubuntu.com>
David Pravec <David.Pravec@danix.org> David Pravec <David.Pravec@danix.org>
Dean Troyer <dtroyer@gmail.com> Dean Troyer <dtroyer@gmail.com>
Devin Carlen <devin.carlen@gmail.com> Devin Carlen <devin.carlen@gmail.com>
@@ -27,8 +30,10 @@ Gabe Westmaas <gabe.westmaas@rackspace.com>
Hisaharu Ishii <ishii.hisaharu@lab.ntt.co.jp> Hisaharu Ishii <ishii.hisaharu@lab.ntt.co.jp>
Hisaki Ohara <hisaki.ohara@intel.com> Hisaki Ohara <hisaki.ohara@intel.com>
Ilya Alekseyev <ialekseev@griddynamics.com> Ilya Alekseyev <ialekseev@griddynamics.com>
Jason Koelker <jason@koelker.net>
Jay Pipes <jaypipes@gmail.com> Jay Pipes <jaypipes@gmail.com>
Jesse Andrews <anotherjesse@gmail.com> Jesse Andrews <anotherjesse@gmail.com>
Jimmy Bergman <jimmy@sigint.se>
Joe Heck <heckj@mac.com> Joe Heck <heckj@mac.com>
Joel Moore <joelbm24@gmail.com> Joel Moore <joelbm24@gmail.com>
Johannes Erdfelt <johannes.erdfelt@rackspace.com> Johannes Erdfelt <johannes.erdfelt@rackspace.com>
@@ -41,11 +46,14 @@ Josh Kearney <josh@jk0.org>
Josh Kleinpeter <josh@kleinpeter.org> Josh Kleinpeter <josh@kleinpeter.org>
Joshua McKenty <jmckenty@gmail.com> Joshua McKenty <jmckenty@gmail.com>
Justin Santa Barbara <justin@fathomdb.com> Justin Santa Barbara <justin@fathomdb.com>
Justin Shepherd <jshepher@rackspace.com>
Kei Masumoto <masumotok@nttdata.co.jp> Kei Masumoto <masumotok@nttdata.co.jp>
Ken Pepple <ken.pepple@gmail.com> Ken Pepple <ken.pepple@gmail.com>
Kevin Bringard <kbringard@attinteractive.com>
Kevin L. Mitchell <kevin.mitchell@rackspace.com> Kevin L. Mitchell <kevin.mitchell@rackspace.com>
Koji Iida <iida.koji@lab.ntt.co.jp> Koji Iida <iida.koji@lab.ntt.co.jp>
Lorin Hochstein <lorin@isi.edu> Lorin Hochstein <lorin@isi.edu>
Lvov Maxim <usrleon@gmail.com>
Mark Washenberger <mark.washenberger@rackspace.com> Mark Washenberger <mark.washenberger@rackspace.com>
Masanori Itoh <itoumsn@nttdata.co.jp> Masanori Itoh <itoumsn@nttdata.co.jp>
Matt Dietz <matt.dietz@rackspace.com> Matt Dietz <matt.dietz@rackspace.com>
@@ -58,6 +66,7 @@ Nachi Ueno <ueno.nachi@lab.ntt.co.jp>
Naveed Massjouni <naveedm9@gmail.com> Naveed Massjouni <naveedm9@gmail.com>
Nirmal Ranganathan <nirmal.ranganathan@rackspace.com> Nirmal Ranganathan <nirmal.ranganathan@rackspace.com>
Paul Voccio <paul@openstack.org> Paul Voccio <paul@openstack.org>
Renuka Apte <renuka.apte@citrix.com>
Ricardo Carrillo Cruz <emaildericky@gmail.com> Ricardo Carrillo Cruz <emaildericky@gmail.com>
Rick Clark <rick@openstack.org> Rick Clark <rick@openstack.org>
Rick Harris <rconradharris@gmail.com> Rick Harris <rconradharris@gmail.com>
@@ -74,5 +83,8 @@ Trey Morris <trey.morris@rackspace.com>
Tushar Patil <tushar.vitthal.patil@gmail.com> Tushar Patil <tushar.vitthal.patil@gmail.com>
Vasiliy Shlykov <vash@vasiliyshlykov.org> Vasiliy Shlykov <vash@vasiliyshlykov.org>
Vishvananda Ishaya <vishvananda@gmail.com> Vishvananda Ishaya <vishvananda@gmail.com>
William Wolf <throughnothing@gmail.com>
Yoshiaki Tamura <yoshi@midokura.jp>
Youcef Laribi <Youcef.Laribi@eu.citrix.com> Youcef Laribi <Youcef.Laribi@eu.citrix.com>
Yuriy Taraday <yorik.sar@gmail.com>
Zhixue Wu <Zhixue.Wu@citrix.com> Zhixue Wu <Zhixue.Wu@citrix.com>

17
HACKING
View File

@@ -50,17 +50,24 @@ Human Alphabetical Order Examples
Docstrings Docstrings
---------- ----------
"""Summary of the function, class or method, less than 80 characters. """A one line docstring looks like this and ends in a period."""
New paragraph after newline that explains in more detail any general
information about the function, class or method. After this, if defining """A multiline docstring has a one-line summary, less than 80 characters.
parameters and return types use the Sphinx format. After that an extra
newline then close the quotations. Then a new paragraph after a newline that explains in more detail any
general information about the function, class or method. Example usages
are also great to have here if it is a complex class for function. After
you have finished your descriptions add an extra newline and close the
quotations.
When writing the docstring for a class, an extra line should be placed When writing the docstring for a class, an extra line should be placed
after the closing quotations. For more in-depth explanations for these after the closing quotations. For more in-depth explanations for these
decisions see http://www.python.org/dev/peps/pep-0257/ decisions see http://www.python.org/dev/peps/pep-0257/
If you are going to describe parameters and return values, use Sphinx, the
appropriate syntax is as follows.
:param foo: the foo parameter :param foo: the foo parameter
:param bar: the bar parameter :param bar: the bar parameter
:returns: description of the return value :returns: description of the return value

View File

@@ -35,6 +35,7 @@ include nova/tests/bundle/1mb.manifest.xml
include nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml include nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml
include nova/tests/bundle/1mb.part.0 include nova/tests/bundle/1mb.part.0
include nova/tests/bundle/1mb.part.1 include nova/tests/bundle/1mb.part.1
include nova/tests/public_key/*
include nova/tests/db/nova.austin.sqlite include nova/tests/db/nova.austin.sqlite
include plugins/xenapi/README include plugins/xenapi/README
include plugins/xenapi/etc/xapi.d/plugins/objectstore include plugins/xenapi/etc/xapi.d/plugins/objectstore

View File

@@ -28,11 +28,11 @@ import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that # If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python... # it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]), POSSIBLE_TOPDIR = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir, os.pardir,
os.pardir)) os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')): if os.path.exists(os.path.join(POSSIBLE_TOPDIR, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir) sys.path.insert(0, POSSIBLE_TOPDIR)
gettext.install('nova', unicode=1) gettext.install('nova', unicode=1)

View File

@@ -108,6 +108,13 @@ def main():
interface = os.environ.get('DNSMASQ_INTERFACE', FLAGS.dnsmasq_interface) interface = os.environ.get('DNSMASQ_INTERFACE', FLAGS.dnsmasq_interface)
if int(os.environ.get('TESTING', '0')): if int(os.environ.get('TESTING', '0')):
from nova.tests import fake_flags from nova.tests import fake_flags
#if FLAGS.fake_rabbit:
# LOG.debug(_("leasing ip"))
# network_manager = utils.import_object(FLAGS.network_manager)
## reload(fake_flags)
# from nova.tests import fake_flags
action = argv[1] action = argv[1]
if action in ['add', 'del', 'old']: if action in ['add', 'del', 'old']:
mac = argv[2] mac = argv[2]

View File

@@ -58,7 +58,6 @@ import gettext
import glob import glob
import json import json
import os import os
import re
import sys import sys
import time import time
@@ -66,11 +65,11 @@ import IPy
# If ../nova/__init__.py exists, add ../ to Python search path, so that # If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python... # it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]), POSSIBLE_TOPDIR = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir, os.pardir,
os.pardir)) os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')): if os.path.exists(os.path.join(POSSIBLE_TOPDIR, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir) sys.path.insert(0, POSSIBLE_TOPDIR)
gettext.install('nova', unicode=1) gettext.install('nova', unicode=1)
@@ -83,6 +82,7 @@ from nova import log as logging
from nova import quota from nova import quota
from nova import rpc from nova import rpc
from nova import utils from nova import utils
from nova import version
from nova.api.ec2 import ec2utils from nova.api.ec2 import ec2utils
from nova.auth import manager from nova.auth import manager
from nova.cloudpipe import pipelib from nova.cloudpipe import pipelib
@@ -97,7 +97,7 @@ flags.DECLARE('vlan_start', 'nova.network.manager')
flags.DECLARE('vpn_start', 'nova.network.manager') flags.DECLARE('vpn_start', 'nova.network.manager')
flags.DECLARE('fixed_range_v6', 'nova.network.manager') flags.DECLARE('fixed_range_v6', 'nova.network.manager')
flags.DECLARE('images_path', 'nova.image.local') flags.DECLARE('images_path', 'nova.image.local')
flags.DECLARE('libvirt_type', 'nova.virt.libvirt_conn') flags.DECLARE('libvirt_type', 'nova.virt.libvirt.connection')
flags.DEFINE_flag(flags.HelpFlag()) flags.DEFINE_flag(flags.HelpFlag())
flags.DEFINE_flag(flags.HelpshortFlag()) flags.DEFINE_flag(flags.HelpshortFlag())
flags.DEFINE_flag(flags.HelpXMLFlag()) flags.DEFINE_flag(flags.HelpXMLFlag())
@@ -151,7 +151,7 @@ class VpnCommands(object):
state = 'up' state = 'up'
print address, print address,
print vpn['host'], print vpn['host'],
print vpn['ec2_id'], print ec2utils.id_to_ec2_id(vpn['id']),
print vpn['state_description'], print vpn['state_description'],
print state print state
else: else:
@@ -362,34 +362,54 @@ class ProjectCommands(object):
def add(self, project_id, user_id): def add(self, project_id, user_id):
"""Adds user to project """Adds user to project
arguments: project_id user_id""" arguments: project_id user_id"""
self.manager.add_to_project(user_id, project_id) try:
self.manager.add_to_project(user_id, project_id)
except exception.UserNotFound as ex:
print ex
raise
def create(self, name, project_manager, description=None): def create(self, name, project_manager, description=None):
"""Creates a new project """Creates a new project
arguments: name project_manager [description]""" arguments: name project_manager [description]"""
self.manager.create_project(name, project_manager, description) try:
self.manager.create_project(name, project_manager, description)
except exception.UserNotFound as ex:
print ex
raise
def modify(self, name, project_manager, description=None): def modify(self, name, project_manager, description=None):
"""Modifies a project """Modifies a project
arguments: name project_manager [description]""" arguments: name project_manager [description]"""
self.manager.modify_project(name, project_manager, description) try:
self.manager.modify_project(name, project_manager, description)
except exception.UserNotFound as ex:
print ex
raise
def delete(self, name): def delete(self, name):
"""Deletes an existing project """Deletes an existing project
arguments: name""" arguments: name"""
self.manager.delete_project(name) try:
self.manager.delete_project(name)
except exception.ProjectNotFound as ex:
print ex
raise
def environment(self, project_id, user_id, filename='novarc'): def environment(self, project_id, user_id, filename='novarc'):
"""Exports environment variables to an sourcable file """Exports environment variables to an sourcable file
arguments: project_id user_id [filename='novarc]""" arguments: project_id user_id [filename='novarc]"""
rc = self.manager.get_environment_rc(user_id, project_id) try:
rc = self.manager.get_environment_rc(user_id, project_id)
except (exception.UserNotFound, exception.ProjectNotFound) as ex:
print ex
raise
with open(filename, 'w') as f: with open(filename, 'w') as f:
f.write(rc) f.write(rc)
def list(self): def list(self, username=None):
"""Lists all projects """Lists all projects
arguments: <none>""" arguments: [username]"""
for project in self.manager.get_projects(): for project in self.manager.get_projects(username):
print project.name print project.name
def quota(self, project_id, key=None, value=None): def quota(self, project_id, key=None, value=None):
@@ -397,19 +417,26 @@ class ProjectCommands(object):
arguments: project_id [key] [value]""" arguments: project_id [key] [value]"""
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
if key: if key:
quo = {'project_id': project_id, key: value} if value.lower() == 'unlimited':
value = None
try: try:
db.quota_update(ctxt, project_id, quo) db.quota_update(ctxt, project_id, key, value)
except exception.NotFound: except exception.ProjectQuotaNotFound:
db.quota_create(ctxt, quo) db.quota_create(ctxt, project_id, key, value)
project_quota = quota.get_quota(ctxt, project_id) project_quota = quota.get_project_quotas(ctxt, project_id)
for key, value in project_quota.iteritems(): for key, value in project_quota.iteritems():
if value is None:
value = 'unlimited'
print '%s: %s' % (key, value) print '%s: %s' % (key, value)
def remove(self, project_id, user_id): def remove(self, project_id, user_id):
"""Removes user from project """Removes user from project
arguments: project_id user_id""" arguments: project_id user_id"""
self.manager.remove_from_project(user_id, project_id) try:
self.manager.remove_from_project(user_id, project_id)
except (exception.UserNotFound, exception.ProjectNotFound) as ex:
print ex
raise
def scrub(self, project_id): def scrub(self, project_id):
"""Deletes data associated with project """Deletes data associated with project
@@ -428,6 +455,9 @@ class ProjectCommands(object):
zip_file = self.manager.get_credentials(user_id, project_id) zip_file = self.manager.get_credentials(user_id, project_id)
with open(filename, 'w') as f: with open(filename, 'w') as f:
f.write(zip_file) f.write(zip_file)
except (exception.UserNotFound, exception.ProjectNotFound) as ex:
print ex
raise
except db.api.NoMoreNetworks: except db.api.NoMoreNetworks:
print _('No more networks available. If this is a new ' print _('No more networks available. If this is a new '
'installation, you need\nto call something like this:\n\n' 'installation, you need\nto call something like this:\n\n'
@@ -449,7 +479,7 @@ class FixedIpCommands(object):
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
try: try:
if host == None: if host is None:
fixed_ips = db.fixed_ip_get_all(ctxt) fixed_ips = db.fixed_ip_get_all(ctxt)
else: else:
fixed_ips = db.fixed_ip_get_all_by_host(ctxt, host) fixed_ips = db.fixed_ip_get_all_by_host(ctxt, host)
@@ -499,7 +529,7 @@ class FloatingIpCommands(object):
"""Lists all floating ips (optionally by host) """Lists all floating ips (optionally by host)
arguments: [host]""" arguments: [host]"""
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
if host == None: if host is None:
floating_ips = db.floating_ip_get_all(ctxt) floating_ips = db.floating_ip_get_all(ctxt)
else: else:
floating_ips = db.floating_ip_get_all_by_host(ctxt, host) floating_ips = db.floating_ip_get_all_by_host(ctxt, host)
@@ -523,8 +553,10 @@ class NetworkCommands(object):
[network_size=FLAG], [vlan_start=FLAG], [network_size=FLAG], [vlan_start=FLAG],
[vpn_start=FLAG], [fixed_range_v6=FLAG]""" [vpn_start=FLAG], [fixed_range_v6=FLAG]"""
if not fixed_range: if not fixed_range:
raise TypeError(_('Fixed range in the form of 10.0.0.0/8 is ' msg = _('Fixed range in the form of 10.0.0.0/8 is '
'required to create networks.')) 'required to create networks.')
print msg
raise TypeError(msg)
if not num_networks: if not num_networks:
num_networks = FLAGS.num_networks num_networks = FLAGS.num_networks
if not network_size: if not network_size:
@@ -536,14 +568,18 @@ class NetworkCommands(object):
if not fixed_range_v6: if not fixed_range_v6:
fixed_range_v6 = FLAGS.fixed_range_v6 fixed_range_v6 = FLAGS.fixed_range_v6
net_manager = utils.import_object(FLAGS.network_manager) net_manager = utils.import_object(FLAGS.network_manager)
net_manager.create_networks(context.get_admin_context(), try:
cidr=fixed_range, net_manager.create_networks(context.get_admin_context(),
num_networks=int(num_networks), cidr=fixed_range,
network_size=int(network_size), num_networks=int(num_networks),
vlan_start=int(vlan_start), network_size=int(network_size),
vpn_start=int(vpn_start), vlan_start=int(vlan_start),
cidr_v6=fixed_range_v6, vpn_start=int(vpn_start),
label=label) cidr_v6=fixed_range_v6,
label=label)
except ValueError, e:
print e
raise e
def list(self): def list(self):
"""List all created networks""" """List all created networks"""
@@ -591,7 +627,7 @@ class VmCommands(object):
_('zone'), _('zone'),
_('index')) _('index'))
if host == None: if host is None:
instances = db.instance_get_all(context.get_admin_context()) instances = db.instance_get_all(context.get_admin_context())
else: else:
instances = db.instance_get_all_by_host( instances = db.instance_get_all_by_host(
@@ -759,6 +795,17 @@ class DbCommands(object):
print migration.db_version() print migration.db_version()
class VersionCommands(object):
"""Class for exposing the codebase version."""
def __init__(self):
pass
def list(self):
print _("%s (%s)") %\
(version.version_string(), version.version_string_with_vcs())
class VolumeCommands(object): class VolumeCommands(object):
"""Methods for dealing with a cloud in an odd state""" """Methods for dealing with a cloud in an odd state"""
@@ -809,11 +856,11 @@ class VolumeCommands(object):
class InstanceTypeCommands(object): class InstanceTypeCommands(object):
"""Class for managing instance types / flavors.""" """Class for managing instance types / flavors."""
def _print_instance_types(self, n, val): def _print_instance_types(self, name, val):
deleted = ('', ', inactive')[val["deleted"] == 1] deleted = ('', ', inactive')[val["deleted"] == 1]
print ("%s: Memory: %sMB, VCPUS: %s, Storage: %sGB, FlavorID: %s, " print ("%s: Memory: %sMB, VCPUS: %s, Storage: %sGB, FlavorID: %s, "
"Swap: %sGB, RXTX Quota: %sGB, RXTX Cap: %sMB%s") % ( "Swap: %sGB, RXTX Quota: %sGB, RXTX Cap: %sMB%s") % (
n, val["memory_mb"], val["vcpus"], val["local_gb"], name, val["memory_mb"], val["vcpus"], val["local_gb"],
val["flavorid"], val["swap"], val["rxtx_quota"], val["flavorid"], val["swap"], val["rxtx_quota"],
val["rxtx_cap"], deleted) val["rxtx_cap"], deleted)
@@ -827,11 +874,17 @@ class InstanceTypeCommands(object):
instance_types.create(name, memory, vcpus, local_gb, instance_types.create(name, memory, vcpus, local_gb,
flavorid, swap, rxtx_quota, rxtx_cap) flavorid, swap, rxtx_quota, rxtx_cap)
except exception.InvalidInputException: except exception.InvalidInputException:
print "Must supply valid parameters to create instance type" print "Must supply valid parameters to create instance_type"
print e print e
sys.exit(1) sys.exit(1)
except exception.DBError, e: except exception.ApiError, e:
print "DB Error: %s" % e print "\n\n"
print "\n%s" % e
print "Please ensure instance_type name and flavorid are unique."
print "To complete remove a instance_type, use the --purge flag:"
print "\n # nova-manage instance_type delete <name> --purge\n"
print "Currently defined instance_type names and flavorids:"
self.list("--all")
sys.exit(2) sys.exit(2)
except: except:
print "Unknown error" print "Unknown error"
@@ -864,7 +917,7 @@ class InstanceTypeCommands(object):
"""Lists all active or specific instance types / flavors """Lists all active or specific instance types / flavors
arguments: [name]""" arguments: [name]"""
try: try:
if name == None: if name is None:
inst_types = instance_types.get_all_types() inst_types = instance_types.get_all_types()
elif name == "--all": elif name == "--all":
inst_types = instance_types.get_all_types(True) inst_types = instance_types.get_all_types(True)
@@ -955,7 +1008,7 @@ class ImageCommands(object):
try: try:
internal_id = ec2utils.ec2_id_to_id(old_image_id) internal_id = ec2utils.ec2_id_to_id(old_image_id)
image = self.image_service.show(context, internal_id) image = self.image_service.show(context, internal_id)
except exception.NotFound: except (exception.InvalidEc2Id, exception.ImageNotFound):
image = self.image_service.show_by_name(context, old_image_id) image = self.image_service.show_by_name(context, old_image_id)
return image['id'] return image['id']
@@ -1009,7 +1062,7 @@ class ImageCommands(object):
if (FLAGS.image_service == 'nova.image.local.LocalImageService' if (FLAGS.image_service == 'nova.image.local.LocalImageService'
and directory == os.path.abspath(FLAGS.images_path)): and directory == os.path.abspath(FLAGS.images_path)):
new_dir = "%s_bak" % directory new_dir = "%s_bak" % directory
os.move(directory, new_dir) os.rename(directory, new_dir)
os.mkdir(directory) os.mkdir(directory)
directory = new_dir directory = new_dir
for fn in glob.glob("%s/*/info.json" % directory): for fn in glob.glob("%s/*/info.json" % directory):
@@ -1021,7 +1074,7 @@ class ImageCommands(object):
machine_images[image_path] = image_metadata machine_images[image_path] = image_metadata
else: else:
other_images[image_path] = image_metadata other_images[image_path] = image_metadata
except Exception as exc: except Exception:
print _("Failed to load %(fn)s.") % locals() print _("Failed to load %(fn)s.") % locals()
# NOTE(vish): do kernels and ramdisks first so images # NOTE(vish): do kernels and ramdisks first so images
self._convert_images(other_images) self._convert_images(other_images)
@@ -1044,7 +1097,8 @@ CATEGORIES = [
('volume', VolumeCommands), ('volume', VolumeCommands),
('instance_type', InstanceTypeCommands), ('instance_type', InstanceTypeCommands),
('image', ImageCommands), ('image', ImageCommands),
('flavor', InstanceTypeCommands)] ('flavor', InstanceTypeCommands),
('version', VersionCommands)]
def lazy_match(name, key_value_tuples): def lazy_match(name, key_value_tuples):
@@ -1086,6 +1140,8 @@ def main():
script_name = argv.pop(0) script_name = argv.pop(0)
if len(argv) < 1: if len(argv) < 1:
print _("\nOpenStack Nova version: %s (%s)\n") %\
(version.version_string(), version.version_string_with_vcs())
print script_name + " category action [<args>]" print script_name + " category action [<args>]"
print _("Available categories:") print _("Available categories:")
for k, _v in CATEGORIES: for k, _v in CATEGORIES:

View File

@@ -65,7 +65,7 @@ def format_help(d):
indent = MAX_INDENT - 6 indent = MAX_INDENT - 6
out = [] out = []
for k, v in d.iteritems(): for k, v in sorted(d.iteritems()):
if (len(k) + 6) > MAX_INDENT: if (len(k) + 6) > MAX_INDENT:
out.extend([' %s' % k]) out.extend([' %s' % k])
initial_indent = ' ' * (indent + 6) initial_indent = ' ' * (indent + 6)

View File

@@ -81,7 +81,7 @@ class DbDriver(object):
user_ref = db.user_create(context.get_admin_context(), values) user_ref = db.user_create(context.get_admin_context(), values)
return self._db_user_to_auth_user(user_ref) return self._db_user_to_auth_user(user_ref)
except exception.Duplicate, e: except exception.Duplicate, e:
raise exception.Duplicate(_('User %s already exists') % name) raise exception.UserExists(user=name)
def _db_user_to_auth_user(self, user_ref): def _db_user_to_auth_user(self, user_ref):
return {'id': user_ref['id'], return {'id': user_ref['id'],
@@ -103,9 +103,7 @@ class DbDriver(object):
"""Create a project""" """Create a project"""
manager = db.user_get(context.get_admin_context(), manager_uid) manager = db.user_get(context.get_admin_context(), manager_uid)
if not manager: if not manager:
raise exception.NotFound(_("Project can't be created because " raise exception.UserNotFound(user_id=manager_uid)
"manager %s doesn't exist")
% manager_uid)
# description is a required attribute # description is a required attribute
if description is None: if description is None:
@@ -115,13 +113,11 @@ class DbDriver(object):
# on to create the project. This way we won't have to destroy # on to create the project. This way we won't have to destroy
# the project again because a user turns out to be invalid. # the project again because a user turns out to be invalid.
members = set([manager]) members = set([manager])
if member_uids != None: if member_uids is not None:
for member_uid in member_uids: for member_uid in member_uids:
member = db.user_get(context.get_admin_context(), member_uid) member = db.user_get(context.get_admin_context(), member_uid)
if not member: if not member:
raise exception.NotFound(_("Project can't be created " raise exception.UserNotFound(user_id=member_uid)
"because user %s doesn't exist")
% member_uid)
members.add(member) members.add(member)
values = {'id': name, values = {'id': name,
@@ -132,8 +128,7 @@ class DbDriver(object):
try: try:
project = db.project_create(context.get_admin_context(), values) project = db.project_create(context.get_admin_context(), values)
except exception.Duplicate: except exception.Duplicate:
raise exception.Duplicate(_("Project can't be created because " raise exception.ProjectExists(project=name)
"project %s already exists") % name)
for member in members: for member in members:
db.project_add_member(context.get_admin_context(), db.project_add_member(context.get_admin_context(),
@@ -154,9 +149,7 @@ class DbDriver(object):
if manager_uid: if manager_uid:
manager = db.user_get(context.get_admin_context(), manager_uid) manager = db.user_get(context.get_admin_context(), manager_uid)
if not manager: if not manager:
raise exception.NotFound(_("Project can't be modified because " raise exception.UserNotFound(user_id=manager_uid)
"manager %s doesn't exist") %
manager_uid)
values['project_manager'] = manager['id'] values['project_manager'] = manager['id']
if description: if description:
values['description'] = description values['description'] = description
@@ -244,8 +237,8 @@ class DbDriver(object):
def _validate_user_and_project(self, user_id, project_id): def _validate_user_and_project(self, user_id, project_id):
user = db.user_get(context.get_admin_context(), user_id) user = db.user_get(context.get_admin_context(), user_id)
if not user: if not user:
raise exception.NotFound(_('User "%s" not found') % user_id) raise exception.UserNotFound(user_id=user_id)
project = db.project_get(context.get_admin_context(), project_id) project = db.project_get(context.get_admin_context(), project_id)
if not project: if not project:
raise exception.NotFound(_('Project "%s" not found') % project_id) raise exception.ProjectNotFound(project_id=project_id)
return user, project return user, project

View File

@@ -171,7 +171,7 @@ class LdapDriver(object):
def create_user(self, name, access_key, secret_key, is_admin): def create_user(self, name, access_key, secret_key, is_admin):
"""Create a user""" """Create a user"""
if self.__user_exists(name): if self.__user_exists(name):
raise exception.Duplicate(_("LDAP user %s already exists") % name) raise exception.LDAPUserExists(user=name)
if FLAGS.ldap_user_modify_only: if FLAGS.ldap_user_modify_only:
if self.__ldap_user_exists(name): if self.__ldap_user_exists(name):
# Retrieve user by name # Retrieve user by name
@@ -202,8 +202,7 @@ class LdapDriver(object):
self.conn.modify_s(self.__uid_to_dn(name), attr) self.conn.modify_s(self.__uid_to_dn(name), attr)
return self.get_user(name) return self.get_user(name)
else: else:
raise exception.NotFound(_("LDAP object for %s doesn't exist") raise exception.LDAPUserNotFound(user_id=name)
% name)
else: else:
attr = [ attr = [
('objectclass', ['person', ('objectclass', ['person',
@@ -226,12 +225,9 @@ class LdapDriver(object):
description=None, member_uids=None): description=None, member_uids=None):
"""Create a project""" """Create a project"""
if self.__project_exists(name): if self.__project_exists(name):
raise exception.Duplicate(_("Project can't be created because " raise exception.ProjectExists(project=name)
"project %s already exists") % name)
if not self.__user_exists(manager_uid): if not self.__user_exists(manager_uid):
raise exception.NotFound(_("Project can't be created because " raise exception.LDAPUserNotFound(user_id=manager_uid)
"manager %s doesn't exist")
% manager_uid)
manager_dn = self.__uid_to_dn(manager_uid) manager_dn = self.__uid_to_dn(manager_uid)
# description is a required attribute # description is a required attribute
if description is None: if description is None:
@@ -240,9 +236,7 @@ class LdapDriver(object):
if member_uids is not None: if member_uids is not None:
for member_uid in member_uids: for member_uid in member_uids:
if not self.__user_exists(member_uid): if not self.__user_exists(member_uid):
raise exception.NotFound(_("Project can't be created " raise exception.LDAPUserNotFound(user_id=member_uid)
"because user %s doesn't exist")
% member_uid)
members.append(self.__uid_to_dn(member_uid)) members.append(self.__uid_to_dn(member_uid))
# always add the manager as a member because members is required # always add the manager as a member because members is required
if not manager_dn in members: if not manager_dn in members:
@@ -265,9 +259,7 @@ class LdapDriver(object):
attr = [] attr = []
if manager_uid: if manager_uid:
if not self.__user_exists(manager_uid): if not self.__user_exists(manager_uid):
raise exception.NotFound(_("Project can't be modified because " raise exception.LDAPUserNotFound(user_id=manager_uid)
"manager %s doesn't exist")
% manager_uid)
manager_dn = self.__uid_to_dn(manager_uid) manager_dn = self.__uid_to_dn(manager_uid)
attr.append((self.ldap.MOD_REPLACE, LdapDriver.project_attribute, attr.append((self.ldap.MOD_REPLACE, LdapDriver.project_attribute,
manager_dn)) manager_dn))
@@ -347,7 +339,7 @@ class LdapDriver(object):
def delete_user(self, uid): def delete_user(self, uid):
"""Delete a user""" """Delete a user"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s doesn't exist") % uid) raise exception.LDAPUserNotFound(user_id=uid)
self.__remove_from_all(uid) self.__remove_from_all(uid)
if FLAGS.ldap_user_modify_only: if FLAGS.ldap_user_modify_only:
# Delete attributes # Delete attributes
@@ -471,15 +463,12 @@ class LdapDriver(object):
description, member_uids=None): description, member_uids=None):
"""Create a group""" """Create a group"""
if self.__group_exists(group_dn): if self.__group_exists(group_dn):
raise exception.Duplicate(_("Group can't be created because " raise exception.LDAPGroupExists(group=name)
"group %s already exists") % name)
members = [] members = []
if member_uids is not None: if member_uids is not None:
for member_uid in member_uids: for member_uid in member_uids:
if not self.__user_exists(member_uid): if not self.__user_exists(member_uid):
raise exception.NotFound(_("Group can't be created " raise exception.LDAPUserNotFound(user_id=member_uid)
"because user %s doesn't exist")
% member_uid)
members.append(self.__uid_to_dn(member_uid)) members.append(self.__uid_to_dn(member_uid))
dn = self.__uid_to_dn(uid) dn = self.__uid_to_dn(uid)
if not dn in members: if not dn in members:
@@ -494,8 +483,7 @@ class LdapDriver(object):
def __is_in_group(self, uid, group_dn): def __is_in_group(self, uid, group_dn):
"""Check if user is in group""" """Check if user is in group"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be searched in group " raise exception.LDAPUserNotFound(user_id=uid)
"because the user doesn't exist") % uid)
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
return False return False
res = self.__find_object(group_dn, res = self.__find_object(group_dn,
@@ -506,29 +494,23 @@ class LdapDriver(object):
def __add_to_group(self, uid, group_dn): def __add_to_group(self, uid, group_dn):
"""Add user to group""" """Add user to group"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be added to the group " raise exception.LDAPUserNotFound(user_id=uid)
"because the user doesn't exist") % uid)
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
raise exception.NotFound(_("The group at dn %s doesn't exist") % raise exception.LDAPGroupNotFound(group_id=group_dn)
group_dn)
if self.__is_in_group(uid, group_dn): if self.__is_in_group(uid, group_dn):
raise exception.Duplicate(_("User %(uid)s is already a member of " raise exception.LDAPMembershipExists(uid=uid, group_dn=group_dn)
"the group %(group_dn)s") % locals())
attr = [(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))] attr = [(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))]
self.conn.modify_s(group_dn, attr) self.conn.modify_s(group_dn, attr)
def __remove_from_group(self, uid, group_dn): def __remove_from_group(self, uid, group_dn):
"""Remove user from group""" """Remove user from group"""
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
raise exception.NotFound(_("The group at dn %s doesn't exist") raise exception.LDAPGroupNotFound(group_id=group_dn)
% group_dn)
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be removed from the " raise exception.LDAPUserNotFound(user_id=uid)
"group because the user doesn't exist")
% uid)
if not self.__is_in_group(uid, group_dn): if not self.__is_in_group(uid, group_dn):
raise exception.NotFound(_("User %s is not a member of the group") raise exception.LDAPGroupMembershipNotFound(user_id=uid,
% uid) group_id=group_dn)
# NOTE(vish): remove user from group and any sub_groups # NOTE(vish): remove user from group and any sub_groups
sub_dns = self.__find_group_dns_with_member(group_dn, uid) sub_dns = self.__find_group_dns_with_member(group_dn, uid)
for sub_dn in sub_dns: for sub_dn in sub_dns:
@@ -548,9 +530,7 @@ class LdapDriver(object):
def __remove_from_all(self, uid): def __remove_from_all(self, uid):
"""Remove user from all roles and projects""" """Remove user from all roles and projects"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be removed from all " raise exception.LDAPUserNotFound(user_id=uid)
"because the user doesn't exist")
% uid)
role_dns = self.__find_group_dns_with_member( role_dns = self.__find_group_dns_with_member(
FLAGS.role_project_subtree, uid) FLAGS.role_project_subtree, uid)
for role_dn in role_dns: for role_dn in role_dns:
@@ -563,8 +543,7 @@ class LdapDriver(object):
def __delete_group(self, group_dn): def __delete_group(self, group_dn):
"""Delete Group""" """Delete Group"""
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
raise exception.NotFound(_("Group at dn %s doesn't exist") raise exception.LDAPGroupNotFound(group_id=group_dn)
% group_dn)
self.conn.delete_s(group_dn) self.conn.delete_s(group_dn)
def __delete_roles(self, project_dn): def __delete_roles(self, project_dn):

View File

@@ -223,6 +223,13 @@ class AuthManager(object):
if driver or not getattr(self, 'driver', None): if driver or not getattr(self, 'driver', None):
self.driver = utils.import_class(driver or FLAGS.auth_driver) self.driver = utils.import_class(driver or FLAGS.auth_driver)
if FLAGS.memcached_servers:
import memcache
else:
from nova import fakememcache as memcache
self.mc = memcache.Client(FLAGS.memcached_servers,
debug=0)
def authenticate(self, access, signature, params, verb='GET', def authenticate(self, access, signature, params, verb='GET',
server_string='127.0.0.1:8773', path='/', server_string='127.0.0.1:8773', path='/',
check_type='ec2', headers=None): check_type='ec2', headers=None):
@@ -268,10 +275,9 @@ class AuthManager(object):
LOG.debug(_('Looking up user: %r'), access_key) LOG.debug(_('Looking up user: %r'), access_key)
user = self.get_user_from_access_key(access_key) user = self.get_user_from_access_key(access_key)
LOG.debug('user: %r', user) LOG.debug('user: %r', user)
if user == None: if user is None:
LOG.audit(_("Failed authorization for access key %s"), access_key) LOG.audit(_("Failed authorization for access key %s"), access_key)
raise exception.NotFound(_('No user found for access key %s') raise exception.AccessKeyNotFound(access_key=access_key)
% access_key)
# NOTE(vish): if we stop using project name as id we need better # NOTE(vish): if we stop using project name as id we need better
# logic to find a default project for user # logic to find a default project for user
@@ -280,13 +286,12 @@ class AuthManager(object):
project_id = user.name project_id = user.name
project = self.get_project(project_id) project = self.get_project(project_id)
if project == None: if project is None:
pjid = project_id pjid = project_id
uname = user.name uname = user.name
LOG.audit(_("failed authorization: no project named %(pjid)s" LOG.audit(_("failed authorization: no project named %(pjid)s"
" (user=%(uname)s)") % locals()) " (user=%(uname)s)") % locals())
raise exception.NotFound(_('No project called %s could be found') raise exception.ProjectNotFound(project_id=project_id)
% project_id)
if not self.is_admin(user) and not self.is_project_member(user, if not self.is_admin(user) and not self.is_project_member(user,
project): project):
uname = user.name uname = user.name
@@ -295,28 +300,40 @@ class AuthManager(object):
pjid = project.id pjid = project.id
LOG.audit(_("Failed authorization: user %(uname)s not admin" LOG.audit(_("Failed authorization: user %(uname)s not admin"
" and not member of project %(pjname)s") % locals()) " and not member of project %(pjname)s") % locals())
raise exception.NotFound(_('User %(uid)s is not a member of' raise exception.ProjectMembershipNotFound(project_id=pjid,
' project %(pjid)s') % locals()) user_id=uid)
if check_type == 's3': if check_type == 's3':
sign = signer.Signer(user.secret.encode()) sign = signer.Signer(user.secret.encode())
expected_signature = sign.s3_authorization(headers, verb, path) expected_signature = sign.s3_authorization(headers, verb, path)
LOG.debug('user.secret: %s', user.secret) LOG.debug(_('user.secret: %s'), user.secret)
LOG.debug('expected_signature: %s', expected_signature) LOG.debug(_('expected_signature: %s'), expected_signature)
LOG.debug('signature: %s', signature) LOG.debug(_('signature: %s'), signature)
if signature != expected_signature: if signature != expected_signature:
LOG.audit(_("Invalid signature for user %s"), user.name) LOG.audit(_("Invalid signature for user %s"), user.name)
raise exception.NotAuthorized(_('Signature does not match')) raise exception.InvalidSignature(signature=signature,
user=user)
elif check_type == 'ec2': elif check_type == 'ec2':
# NOTE(vish): hmac can't handle unicode, so encode ensures that # NOTE(vish): hmac can't handle unicode, so encode ensures that
# secret isn't unicode # secret isn't unicode
expected_signature = signer.Signer(user.secret.encode()).generate( expected_signature = signer.Signer(user.secret.encode()).generate(
params, verb, server_string, path) params, verb, server_string, path)
LOG.debug('user.secret: %s', user.secret) LOG.debug(_('user.secret: %s'), user.secret)
LOG.debug('expected_signature: %s', expected_signature) LOG.debug(_('expected_signature: %s'), expected_signature)
LOG.debug('signature: %s', signature) LOG.debug(_('signature: %s'), signature)
if signature != expected_signature: if signature != expected_signature:
(addr_str, port_str) = utils.parse_server_string(server_string)
# If the given server_string contains port num, try without it.
if port_str != '':
host_only_signature = signer.Signer(
user.secret.encode()).generate(params, verb,
addr_str, path)
LOG.debug(_('host_only_signature: %s'),
host_only_signature)
if signature == host_only_signature:
return (user, project)
LOG.audit(_("Invalid signature for user %s"), user.name) LOG.audit(_("Invalid signature for user %s"), user.name)
raise exception.NotAuthorized(_('Signature does not match')) raise exception.InvalidSignature(signature=signature,
user=user)
return (user, project) return (user, project)
def get_access_key(self, user, project): def get_access_key(self, user, project):
@@ -360,6 +377,27 @@ class AuthManager(object):
if self.has_role(user, role): if self.has_role(user, role):
return True return True
def _build_mc_key(self, user, role, project=None):
key_parts = ['rolecache', User.safe_id(user), str(role)]
if project:
key_parts.append(Project.safe_id(project))
return '-'.join(key_parts)
def _clear_mc_key(self, user, role, project=None):
# NOTE(anthony): it would be better to delete the key
self.mc.set(self._build_mc_key(user, role, project), None)
def _has_role(self, user, role, project=None):
mc_key = self._build_mc_key(user, role, project)
rslt = self.mc.get(mc_key)
if rslt is None:
with self.driver() as drv:
rslt = drv.has_role(user, role, project)
self.mc.set(mc_key, rslt)
return rslt
else:
return rslt
def has_role(self, user, role, project=None): def has_role(self, user, role, project=None):
"""Checks existence of role for user """Checks existence of role for user
@@ -383,24 +421,24 @@ class AuthManager(object):
@rtype: bool @rtype: bool
@return: True if the user has the role. @return: True if the user has the role.
""" """
with self.driver() as drv: if role == 'projectmanager':
if role == 'projectmanager': if not project:
if not project: raise exception.Error(_("Must specify project"))
raise exception.Error(_("Must specify project")) return self.is_project_manager(user, project)
return self.is_project_manager(user, project)
global_role = drv.has_role(User.safe_id(user), global_role = self._has_role(User.safe_id(user),
role, role,
None) None)
if not global_role:
return global_role
if not project or role in FLAGS.global_roles: if not global_role:
return global_role return global_role
return drv.has_role(User.safe_id(user), if not project or role in FLAGS.global_roles:
role, return global_role
Project.safe_id(project))
return self._has_role(User.safe_id(user),
role,
Project.safe_id(project))
def add_role(self, user, role, project=None): def add_role(self, user, role, project=None):
"""Adds role for user """Adds role for user
@@ -420,9 +458,9 @@ class AuthManager(object):
@param project: Project in which to add local role. @param project: Project in which to add local role.
""" """
if role not in FLAGS.allowed_roles: if role not in FLAGS.allowed_roles:
raise exception.NotFound(_("The %s role can not be found") % role) raise exception.UserRoleNotFound(role_id=role)
if project is not None and role in FLAGS.global_roles: if project is not None and role in FLAGS.global_roles:
raise exception.NotFound(_("The %s role is global only") % role) raise exception.GlobalRoleNotAllowed(role_id=role)
uid = User.safe_id(user) uid = User.safe_id(user)
pid = Project.safe_id(project) pid = Project.safe_id(project)
if project: if project:
@@ -432,6 +470,7 @@ class AuthManager(object):
LOG.audit(_("Adding sitewide role %(role)s to user %(uid)s") LOG.audit(_("Adding sitewide role %(role)s to user %(uid)s")
% locals()) % locals())
with self.driver() as drv: with self.driver() as drv:
self._clear_mc_key(uid, role, pid)
drv.add_role(uid, role, pid) drv.add_role(uid, role, pid)
def remove_role(self, user, role, project=None): def remove_role(self, user, role, project=None):
@@ -460,6 +499,7 @@ class AuthManager(object):
LOG.audit(_("Removing sitewide role %(role)s" LOG.audit(_("Removing sitewide role %(role)s"
" from user %(uid)s") % locals()) " from user %(uid)s") % locals())
with self.driver() as drv: with self.driver() as drv:
self._clear_mc_key(uid, role, pid)
drv.remove_role(uid, role, pid) drv.remove_role(uid, role, pid)
@staticmethod @staticmethod
@@ -646,9 +686,9 @@ class AuthManager(object):
@rtype: User @rtype: User
@return: The new user. @return: The new user.
""" """
if access == None: if access is None:
access = str(uuid.uuid4()) access = str(uuid.uuid4())
if secret == None: if secret is None:
secret = str(uuid.uuid4()) secret = str(uuid.uuid4())
with self.driver() as drv: with self.driver() as drv:
user_dict = drv.create_user(name, access, secret, admin) user_dict = drv.create_user(name, access, secret, admin)

View File

@@ -18,14 +18,14 @@
"""Super simple fake memcache client.""" """Super simple fake memcache client."""
import utils from nova import utils
class Client(object): class Client(object):
"""Replicates a tiny subset of memcached client interface.""" """Replicates a tiny subset of memcached client interface."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Ignores the passed in args""" """Ignores the passed in args."""
self.cache = {} self.cache = {}
def get(self, key): def get(self, key):

View File

@@ -16,9 +16,13 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" """Command-line flag library.
Wraps gflags.
Package-level global flags are defined here, the rest are defined Package-level global flags are defined here, the rest are defined
where they're used. where they're used.
""" """
import getopt import getopt
@@ -106,7 +110,7 @@ class FlagValues(gflags.FlagValues):
return name in self.__dict__['__dirty'] return name in self.__dict__['__dirty']
def ClearDirty(self): def ClearDirty(self):
self.__dict__['__is_dirty'] = [] self.__dict__['__dirty'] = []
def WasAlreadyParsed(self): def WasAlreadyParsed(self):
return self.__dict__['__was_already_parsed'] return self.__dict__['__was_already_parsed']
@@ -115,11 +119,12 @@ class FlagValues(gflags.FlagValues):
if '__stored_argv' not in self.__dict__: if '__stored_argv' not in self.__dict__:
return return
new_flags = FlagValues(self) new_flags = FlagValues(self)
for k in self.__dict__['__dirty']: for k in self.FlagDict().iterkeys():
new_flags[k] = gflags.FlagValues.__getitem__(self, k) new_flags[k] = gflags.FlagValues.__getitem__(self, k)
new_flags.Reset()
new_flags(self.__dict__['__stored_argv']) new_flags(self.__dict__['__stored_argv'])
for k in self.__dict__['__dirty']: for k in new_flags.FlagDict().iterkeys():
setattr(self, k, getattr(new_flags, k)) setattr(self, k, getattr(new_flags, k))
self.ClearDirty() self.ClearDirty()
@@ -145,10 +150,12 @@ class FlagValues(gflags.FlagValues):
class StrWrapper(object): class StrWrapper(object):
"""Wrapper around FlagValues objects """Wrapper around FlagValues objects.
Wraps FlagValues objects for string.Template so that we're Wraps FlagValues objects for string.Template so that we're
sure to return strings.""" sure to return strings.
"""
def __init__(self, context_objs): def __init__(self, context_objs):
self.context_objs = context_objs self.context_objs = context_objs
@@ -169,6 +176,7 @@ def _GetCallingModule():
We generally use this function to get the name of the module calling a We generally use this function to get the name of the module calling a
DEFINE_foo... function. DEFINE_foo... function.
""" """
# Walk down the stack to find the first globals dict that's not ours. # Walk down the stack to find the first globals dict that's not ours.
for depth in range(1, sys.getrecursionlimit()): for depth in range(1, sys.getrecursionlimit()):
@@ -192,6 +200,7 @@ def __GetModuleName(globals_dict):
Returns: Returns:
A string (the name of the module) or None (if the module could not A string (the name of the module) or None (if the module could not
be identified. be identified.
""" """
for name, module in sys.modules.iteritems(): for name, module in sys.modules.iteritems():
if getattr(module, '__dict__', None) is globals_dict: if getattr(module, '__dict__', None) is globals_dict:
@@ -316,7 +325,7 @@ DEFINE_string('null_kernel', 'nokernel',
'kernel image that indicates not to use a kernel,' 'kernel image that indicates not to use a kernel,'
' but to use a raw disk image instead') ' but to use a raw disk image instead')
DEFINE_string('vpn_image_id', 'ami-cloudpipe', 'AMI for cloudpipe vpn server') DEFINE_integer('vpn_image_id', 0, 'integer id for cloudpipe vpn server')
DEFINE_string('vpn_key_suffix', DEFINE_string('vpn_key_suffix',
'-vpn', '-vpn',
'Suffix to add to project name for vpn key and secgroups') 'Suffix to add to project name for vpn key and secgroups')
@@ -326,7 +335,7 @@ DEFINE_integer('auth_token_ttl', 3600, 'Seconds for auth tokens to linger')
DEFINE_string('state_path', os.path.join(os.path.dirname(__file__), '../'), DEFINE_string('state_path', os.path.join(os.path.dirname(__file__), '../'),
"Top-level directory for maintaining nova's state") "Top-level directory for maintaining nova's state")
DEFINE_string('lock_path', os.path.join(os.path.dirname(__file__), '../'), DEFINE_string('lock_path', os.path.join(os.path.dirname(__file__), '../'),
"Directory for lock files") 'Directory for lock files')
DEFINE_string('logdir', None, 'output to a per-service log file in named ' DEFINE_string('logdir', None, 'output to a per-service log file in named '
'directory') 'directory')
@@ -361,6 +370,12 @@ DEFINE_string('host', socket.gethostname(),
DEFINE_string('node_availability_zone', 'nova', DEFINE_string('node_availability_zone', 'nova',
'availability zone of this node') 'availability zone of this node')
DEFINE_string('notification_driver',
'nova.notifier.no_op_notifier',
'Default driver for sending notifications')
DEFINE_list('memcached_servers', None,
'Memcached servers or None for in process cache.')
DEFINE_string('zone_name', 'nova', 'name of this zone') DEFINE_string('zone_name', 'nova', 'name of this zone')
DEFINE_list('zone_capabilities', DEFINE_list('zone_capabilities',
['hypervisor=xenserver;kvm', 'os=linux;windows'], ['hypervisor=xenserver;kvm', 'os=linux;windows'],

View File

@@ -16,16 +16,15 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" """Nova logging handler.
Nova logging handler.
This module adds to logging functionality by adding the option to specify This module adds to logging functionality by adding the option to specify
a context object when calling the various log methods. If the context object a context object when calling the various log methods. If the context object
is not specified, default formatting is used. is not specified, default formatting is used.
It also allows setting of formatting information through flags. It also allows setting of formatting information through flags.
"""
"""
import cStringIO import cStringIO
import inspect import inspect
@@ -41,34 +40,28 @@ from nova import version
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('logging_context_format_string', flags.DEFINE_string('logging_context_format_string',
'%(asctime)s %(levelname)s %(name)s ' '%(asctime)s %(levelname)s %(name)s '
'[%(request_id)s %(user)s ' '[%(request_id)s %(user)s '
'%(project)s] %(message)s', '%(project)s] %(message)s',
'format string to use for log messages with context') 'format string to use for log messages with context')
flags.DEFINE_string('logging_default_format_string', flags.DEFINE_string('logging_default_format_string',
'%(asctime)s %(levelname)s %(name)s [-] ' '%(asctime)s %(levelname)s %(name)s [-] '
'%(message)s', '%(message)s',
'format string to use for log messages without context') 'format string to use for log messages without context')
flags.DEFINE_string('logging_debug_format_suffix', flags.DEFINE_string('logging_debug_format_suffix',
'from (pid=%(process)d) %(funcName)s' 'from (pid=%(process)d) %(funcName)s'
' %(pathname)s:%(lineno)d', ' %(pathname)s:%(lineno)d',
'data to append to log format when level is DEBUG') 'data to append to log format when level is DEBUG')
flags.DEFINE_string('logging_exception_prefix', flags.DEFINE_string('logging_exception_prefix',
'(%(name)s): TRACE: ', '(%(name)s): TRACE: ',
'prefix each line of exception output with this format') 'prefix each line of exception output with this format')
flags.DEFINE_list('default_log_levels', flags.DEFINE_list('default_log_levels',
['amqplib=WARN', ['amqplib=WARN',
'sqlalchemy=WARN', 'sqlalchemy=WARN',
'boto=WARN', 'boto=WARN',
'eventlet.wsgi.server=WARN'], 'eventlet.wsgi.server=WARN'],
'list of logger=LEVEL pairs') 'list of logger=LEVEL pairs')
flags.DEFINE_bool('use_syslog', False, 'output to syslog') flags.DEFINE_bool('use_syslog', False, 'output to syslog')
flags.DEFINE_string('logfile', None, 'output to named file') flags.DEFINE_string('logfile', None, 'output to named file')
@@ -83,6 +76,8 @@ WARN = logging.WARN
INFO = logging.INFO INFO = logging.INFO
DEBUG = logging.DEBUG DEBUG = logging.DEBUG
NOTSET = logging.NOTSET NOTSET = logging.NOTSET
# methods # methods
getLogger = logging.getLogger getLogger = logging.getLogger
debug = logging.debug debug = logging.debug
@@ -93,6 +88,8 @@ error = logging.error
exception = logging.exception exception = logging.exception
critical = logging.critical critical = logging.critical
log = logging.log log = logging.log
# handlers # handlers
StreamHandler = logging.StreamHandler StreamHandler = logging.StreamHandler
WatchedFileHandler = logging.handlers.WatchedFileHandler WatchedFileHandler = logging.handlers.WatchedFileHandler
@@ -106,7 +103,7 @@ logging.addLevelName(AUDIT, 'AUDIT')
def _dictify_context(context): def _dictify_context(context):
if context == None: if context is None:
return None return None
if not isinstance(context, dict) \ if not isinstance(context, dict) \
and getattr(context, 'to_dict', None): and getattr(context, 'to_dict', None):
@@ -127,17 +124,18 @@ def _get_log_file_path(binary=None):
class NovaLogger(logging.Logger): class NovaLogger(logging.Logger):
""" """NovaLogger manages request context and formatting.
NovaLogger manages request context and formatting.
This becomes the class that is instanciated by logging.getLogger. This becomes the class that is instanciated by logging.getLogger.
""" """
def __init__(self, name, level=NOTSET): def __init__(self, name, level=NOTSET):
logging.Logger.__init__(self, name, level) logging.Logger.__init__(self, name, level)
self.setup_from_flags() self.setup_from_flags()
def setup_from_flags(self): def setup_from_flags(self):
"""Setup logger from flags""" """Setup logger from flags."""
level = NOTSET level = NOTSET
for pair in FLAGS.default_log_levels: for pair in FLAGS.default_log_levels:
logger, _sep, level_name = pair.partition('=') logger, _sep, level_name = pair.partition('=')
@@ -148,7 +146,7 @@ class NovaLogger(logging.Logger):
self.setLevel(level) self.setLevel(level)
def _log(self, level, msg, args, exc_info=None, extra=None, context=None): def _log(self, level, msg, args, exc_info=None, extra=None, context=None):
"""Extract context from any log call""" """Extract context from any log call."""
if not extra: if not extra:
extra = {} extra = {}
if context: if context:
@@ -157,17 +155,17 @@ class NovaLogger(logging.Logger):
return logging.Logger._log(self, level, msg, args, exc_info, extra) return logging.Logger._log(self, level, msg, args, exc_info, extra)
def addHandler(self, handler): def addHandler(self, handler):
"""Each handler gets our custom formatter""" """Each handler gets our custom formatter."""
handler.setFormatter(_formatter) handler.setFormatter(_formatter)
return logging.Logger.addHandler(self, handler) return logging.Logger.addHandler(self, handler)
def audit(self, msg, *args, **kwargs): def audit(self, msg, *args, **kwargs):
"""Shortcut for our AUDIT level""" """Shortcut for our AUDIT level."""
if self.isEnabledFor(AUDIT): if self.isEnabledFor(AUDIT):
self._log(AUDIT, msg, args, **kwargs) self._log(AUDIT, msg, args, **kwargs)
def exception(self, msg, *args, **kwargs): def exception(self, msg, *args, **kwargs):
"""Logging.exception doesn't handle kwargs, so breaks context""" """Logging.exception doesn't handle kwargs, so breaks context."""
if not kwargs.get('exc_info'): if not kwargs.get('exc_info'):
kwargs['exc_info'] = 1 kwargs['exc_info'] = 1
self.error(msg, *args, **kwargs) self.error(msg, *args, **kwargs)
@@ -181,14 +179,13 @@ class NovaLogger(logging.Logger):
for k in env.keys(): for k in env.keys():
if not isinstance(env[k], str): if not isinstance(env[k], str):
env.pop(k) env.pop(k)
message = "Environment: %s" % json.dumps(env) message = 'Environment: %s' % json.dumps(env)
kwargs.pop('exc_info') kwargs.pop('exc_info')
self.error(message, **kwargs) self.error(message, **kwargs)
class NovaFormatter(logging.Formatter): class NovaFormatter(logging.Formatter):
""" """A nova.context.RequestContext aware formatter configured through flags.
A nova.context.RequestContext aware formatter configured through flags.
The flags used to set format strings are: logging_context_foramt_string The flags used to set format strings are: logging_context_foramt_string
and logging_default_format_string. You can also specify and logging_default_format_string. You can also specify
@@ -197,10 +194,11 @@ class NovaFormatter(logging.Formatter):
For information about what variables are available for the formatter see: For information about what variables are available for the formatter see:
http://docs.python.org/library/logging.html#formatter http://docs.python.org/library/logging.html#formatter
""" """
def format(self, record): def format(self, record):
"""Uses contextstring if request_id is set, otherwise default""" """Uses contextstring if request_id is set, otherwise default."""
if record.__dict__.get('request_id', None): if record.__dict__.get('request_id', None):
self._fmt = FLAGS.logging_context_format_string self._fmt = FLAGS.logging_context_format_string
else: else:
@@ -214,20 +212,21 @@ class NovaFormatter(logging.Formatter):
return logging.Formatter.format(self, record) return logging.Formatter.format(self, record)
def formatException(self, exc_info, record=None): def formatException(self, exc_info, record=None):
"""Format exception output with FLAGS.logging_exception_prefix""" """Format exception output with FLAGS.logging_exception_prefix."""
if not record: if not record:
return logging.Formatter.formatException(self, exc_info) return logging.Formatter.formatException(self, exc_info)
stringbuffer = cStringIO.StringIO() stringbuffer = cStringIO.StringIO()
traceback.print_exception(exc_info[0], exc_info[1], exc_info[2], traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
None, stringbuffer) None, stringbuffer)
lines = stringbuffer.getvalue().split("\n") lines = stringbuffer.getvalue().split('\n')
stringbuffer.close() stringbuffer.close()
formatted_lines = [] formatted_lines = []
for line in lines: for line in lines:
pl = FLAGS.logging_exception_prefix % record.__dict__ pl = FLAGS.logging_exception_prefix % record.__dict__
fl = "%s%s" % (pl, line) fl = '%s%s' % (pl, line)
formatted_lines.append(fl) formatted_lines.append(fl)
return "\n".join(formatted_lines) return '\n'.join(formatted_lines)
_formatter = NovaFormatter() _formatter = NovaFormatter()
@@ -241,7 +240,7 @@ class NovaRootLogger(NovaLogger):
NovaLogger.__init__(self, name, level) NovaLogger.__init__(self, name, level)
def setup_from_flags(self): def setup_from_flags(self):
"""Setup logger from flags""" """Setup logger from flags."""
global _filelog global _filelog
if FLAGS.use_syslog: if FLAGS.use_syslog:
self.syslog = SysLogHandler(address='/dev/log') self.syslog = SysLogHandler(address='/dev/log')

14
nova/notifier/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

83
nova/notifier/api.py Normal file
View File

@@ -0,0 +1,83 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.import datetime
import datetime
import uuid
from nova import flags
from nova import utils
FLAGS = flags.FLAGS
flags.DEFINE_string('default_notification_level', 'INFO',
'Default notification level for outgoing notifications')
WARN = 'WARN'
INFO = 'INFO'
ERROR = 'ERROR'
CRITICAL = 'CRITICAL'
DEBUG = 'DEBUG'
log_levels = (DEBUG, WARN, INFO, ERROR, CRITICAL)
class BadPriorityException(Exception):
pass
def notify(publisher_id, event_type, priority, payload):
    """Sends a notification using the configured driver.

    Notify parameters:

    publisher_id - the source worker_type.host of the message
    event_type - the literal type of event (ex. Instance Creation)
    priority - patterned after the enumeration of Python logging levels in
               the set (DEBUG, WARN, INFO, ERROR, CRITICAL)
    payload - A python dictionary of attributes

    Outgoing message format includes the above parameters, and appends the
    following:

    message_id - a UUID representing the id for this notification
    timestamp - the GMT timestamp the notification was sent at

    The composite message will be constructed as a dictionary of the above
    attributes, which will then be sent via the transport mechanism defined
    by the driver.

    Message example:

    {'message_id': str(uuid.uuid4()),
     'publisher_id': 'compute.host1',
     'timestamp': datetime.datetime.utcnow(),
     'priority': 'WARN',
     'event_type': 'compute.create_instance',
     'payload': {'instance_id': 12, ... }}

    Raises BadPriorityException if priority is not in log_levels.
    """
    if priority not in log_levels:
        # NOTE: interpolate *after* the _() lookup so the untranslated
        # format string is what gettext sees in the catalog.
        raise BadPriorityException(
                 _('%s not in valid priorities') % priority)
    driver = utils.import_object(FLAGS.notification_driver)
    msg = dict(message_id=str(uuid.uuid4()),
                   publisher_id=publisher_id,
                   event_type=event_type,
                   priority=priority,
                   payload=payload,
                   timestamp=str(datetime.datetime.utcnow()))
    driver.notify(msg)

View File

@@ -0,0 +1,34 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json
from nova import flags
from nova import log as logging
FLAGS = flags.FLAGS
def notify(message):
    """Log the notification message using nova's default logging system.

    The message's priority (falling back to the configured default
    notification level) selects the logger method to call; the event
    type selects the logger name.
    """
    level = message.get('priority',
                        FLAGS.default_notification_level).lower()
    log_name = 'nova.notification.%s' % message['event_type']
    log_method = getattr(logging.getLogger(log_name), level)
    log_method(json.dumps(message))

View File

@@ -1,7 +1,4 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2011 OpenStack LLC.
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved. # All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -16,11 +13,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from nova import flags
FLAGS = flags.FLAGS def notify(message):
"""Notifies the recipient of the desired event given the model"""
FLAGS.connection_type = 'libvirt' pass
FLAGS.fake_rabbit = False
FLAGS.fake_network = False
FLAGS.verbose = False

View File

@@ -0,0 +1,36 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import nova.context
from nova import flags
from nova import rpc
FLAGS = flags.FLAGS
flags.DEFINE_string('notification_topic', 'notifications',
'RabbitMQ topic used for Nova notifications')
def notify(message):
    """Cast the notification onto the RabbitMQ notification topic.

    The routing topic is '<notification_topic>.<priority>', where the
    priority falls back to the configured default notification level.
    """
    admin_context = nova.context.get_admin_context()
    level = message.get('priority',
                        FLAGS.default_notification_level)
    topic = '%s.%s' % (FLAGS.notification_topic, level.lower())
    rpc.cast(admin_context, topic, message)

View File

@@ -16,9 +16,12 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" """AMQP-based RPC.
AMQP-based RPC. Queues have consumers and publishers.
Queues have consumers and publishers.
No fan-out support yet. No fan-out support yet.
""" """
import json import json
@@ -40,17 +43,19 @@ from nova import log as logging
from nova import utils from nova import utils
FLAGS = flags.FLAGS
LOG = logging.getLogger('nova.rpc') LOG = logging.getLogger('nova.rpc')
FLAGS = flags.FLAGS
flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool') flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool')
class Connection(carrot_connection.BrokerConnection): class Connection(carrot_connection.BrokerConnection):
"""Connection instance object""" """Connection instance object."""
@classmethod @classmethod
def instance(cls, new=True): def instance(cls, new=True):
"""Returns the instance""" """Returns the instance."""
if new or not hasattr(cls, '_instance'): if new or not hasattr(cls, '_instance'):
params = dict(hostname=FLAGS.rabbit_host, params = dict(hostname=FLAGS.rabbit_host,
port=FLAGS.rabbit_port, port=FLAGS.rabbit_port,
@@ -71,9 +76,11 @@ class Connection(carrot_connection.BrokerConnection):
@classmethod @classmethod
def recreate(cls): def recreate(cls):
"""Recreates the connection instance """Recreates the connection instance.
This is necessary to recover from some network errors/disconnects""" This is necessary to recover from some network errors/disconnects.
"""
try: try:
del cls._instance del cls._instance
except AttributeError, e: except AttributeError, e:
@@ -84,10 +91,12 @@ class Connection(carrot_connection.BrokerConnection):
class Consumer(messaging.Consumer): class Consumer(messaging.Consumer):
"""Consumer base class """Consumer base class.
Contains methods for connecting the fetch method to async loops.
Contains methods for connecting the fetch method to async loops
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
for i in xrange(FLAGS.rabbit_max_retries): for i in xrange(FLAGS.rabbit_max_retries):
if i > 0: if i > 0:
@@ -100,19 +109,18 @@ class Consumer(messaging.Consumer):
fl_host = FLAGS.rabbit_host fl_host = FLAGS.rabbit_host
fl_port = FLAGS.rabbit_port fl_port = FLAGS.rabbit_port
fl_intv = FLAGS.rabbit_retry_interval fl_intv = FLAGS.rabbit_retry_interval
LOG.error(_("AMQP server on %(fl_host)s:%(fl_port)d is" LOG.error(_('AMQP server on %(fl_host)s:%(fl_port)d is'
" unreachable: %(e)s. Trying again in %(fl_intv)d" ' unreachable: %(e)s. Trying again in %(fl_intv)d'
" seconds.") ' seconds.') % locals())
% locals())
self.failed_connection = True self.failed_connection = True
if self.failed_connection: if self.failed_connection:
LOG.error(_("Unable to connect to AMQP server " LOG.error(_('Unable to connect to AMQP server '
"after %d tries. Shutting down."), 'after %d tries. Shutting down.'),
FLAGS.rabbit_max_retries) FLAGS.rabbit_max_retries)
sys.exit(1) sys.exit(1)
def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False): def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False):
"""Wraps the parent fetch with some logic for failed connections""" """Wraps the parent fetch with some logic for failed connection."""
# TODO(vish): the logic for failed connections and logging should be # TODO(vish): the logic for failed connections and logging should be
# refactored into some sort of connection manager object # refactored into some sort of connection manager object
try: try:
@@ -125,14 +133,14 @@ class Consumer(messaging.Consumer):
self.declare() self.declare()
super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks) super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks)
if self.failed_connection: if self.failed_connection:
LOG.error(_("Reconnected to queue")) LOG.error(_('Reconnected to queue'))
self.failed_connection = False self.failed_connection = False
# NOTE(vish): This is catching all errors because we really don't # NOTE(vish): This is catching all errors because we really don't
# want exceptions to be logged 10 times a second if some # want exceptions to be logged 10 times a second if some
# persistent failure occurs. # persistent failure occurs.
except Exception, e: # pylint: disable=W0703 except Exception, e: # pylint: disable=W0703
if not self.failed_connection: if not self.failed_connection:
LOG.exception(_("Failed to fetch message from queue: %s" % e)) LOG.exception(_('Failed to fetch message from queue: %s' % e))
self.failed_connection = True self.failed_connection = True
def attach_to_eventlet(self): def attach_to_eventlet(self):
@@ -143,8 +151,9 @@ class Consumer(messaging.Consumer):
class AdapterConsumer(Consumer): class AdapterConsumer(Consumer):
"""Calls methods on a proxy object based on method and args""" """Calls methods on a proxy object based on method and args."""
def __init__(self, connection=None, topic="broadcast", proxy=None):
def __init__(self, connection=None, topic='broadcast', proxy=None):
LOG.debug(_('Initing the Adapter Consumer for %s') % topic) LOG.debug(_('Initing the Adapter Consumer for %s') % topic)
self.proxy = proxy self.proxy = proxy
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
@@ -156,13 +165,14 @@ class AdapterConsumer(Consumer):
@exception.wrap_exception @exception.wrap_exception
def _receive(self, message_data, message): def _receive(self, message_data, message):
"""Magically looks for a method on the proxy object and calls it """Magically looks for a method on the proxy object and calls it.
Message data should be a dictionary with two keys: Message data should be a dictionary with two keys:
method: string representing the method to call method: string representing the method to call
args: dictionary of arg: value args: dictionary of arg: value
Example: {'method': 'echo', 'args': {'value': 42}} Example: {'method': 'echo', 'args': {'value': 42}}
""" """
LOG.debug(_('received %s') % message_data) LOG.debug(_('received %s') % message_data)
msg_id = message_data.pop('_msg_id', None) msg_id = message_data.pop('_msg_id', None)
@@ -189,22 +199,23 @@ class AdapterConsumer(Consumer):
if msg_id: if msg_id:
msg_reply(msg_id, rval, None) msg_reply(msg_id, rval, None)
except Exception as e: except Exception as e:
logging.exception("Exception during message handling") logging.exception('Exception during message handling')
if msg_id: if msg_id:
msg_reply(msg_id, None, sys.exc_info()) msg_reply(msg_id, None, sys.exc_info())
return return
class Publisher(messaging.Publisher): class Publisher(messaging.Publisher):
"""Publisher base class""" """Publisher base class."""
pass pass
class TopicAdapterConsumer(AdapterConsumer): class TopicAdapterConsumer(AdapterConsumer):
"""Consumes messages on a specific topic""" """Consumes messages on a specific topic."""
exchange_type = "topic"
def __init__(self, connection=None, topic="broadcast", proxy=None): exchange_type = 'topic'
def __init__(self, connection=None, topic='broadcast', proxy=None):
self.queue = topic self.queue = topic
self.routing_key = topic self.routing_key = topic
self.exchange = FLAGS.control_exchange self.exchange = FLAGS.control_exchange
@@ -214,27 +225,29 @@ class TopicAdapterConsumer(AdapterConsumer):
class FanoutAdapterConsumer(AdapterConsumer): class FanoutAdapterConsumer(AdapterConsumer):
"""Consumes messages from a fanout exchange""" """Consumes messages from a fanout exchange."""
exchange_type = "fanout"
def __init__(self, connection=None, topic="broadcast", proxy=None): exchange_type = 'fanout'
self.exchange = "%s_fanout" % topic
def __init__(self, connection=None, topic='broadcast', proxy=None):
self.exchange = '%s_fanout' % topic
self.routing_key = topic self.routing_key = topic
unique = uuid.uuid4().hex unique = uuid.uuid4().hex
self.queue = "%s_fanout_%s" % (topic, unique) self.queue = '%s_fanout_%s' % (topic, unique)
self.durable = False self.durable = False
LOG.info(_("Created '%(exchange)s' fanout exchange " LOG.info(_('Created "%(exchange)s" fanout exchange '
"with '%(key)s' routing key"), 'with "%(key)s" routing key'),
dict(exchange=self.exchange, key=self.routing_key)) dict(exchange=self.exchange, key=self.routing_key))
super(FanoutAdapterConsumer, self).__init__(connection=connection, super(FanoutAdapterConsumer, self).__init__(connection=connection,
topic=topic, proxy=proxy) topic=topic, proxy=proxy)
class TopicPublisher(Publisher): class TopicPublisher(Publisher):
"""Publishes messages on a specific topic""" """Publishes messages on a specific topic."""
exchange_type = "topic"
def __init__(self, connection=None, topic="broadcast"): exchange_type = 'topic'
def __init__(self, connection=None, topic='broadcast'):
self.routing_key = topic self.routing_key = topic
self.exchange = FLAGS.control_exchange self.exchange = FLAGS.control_exchange
self.durable = False self.durable = False
@@ -243,20 +256,22 @@ class TopicPublisher(Publisher):
class FanoutPublisher(Publisher): class FanoutPublisher(Publisher):
"""Publishes messages to a fanout exchange.""" """Publishes messages to a fanout exchange."""
exchange_type = "fanout"
exchange_type = 'fanout'
def __init__(self, topic, connection=None): def __init__(self, topic, connection=None):
self.exchange = "%s_fanout" % topic self.exchange = '%s_fanout' % topic
self.queue = "%s_fanout" % topic self.queue = '%s_fanout' % topic
self.durable = False self.durable = False
LOG.info(_("Creating '%(exchange)s' fanout exchange"), LOG.info(_('Creating "%(exchange)s" fanout exchange'),
dict(exchange=self.exchange)) dict(exchange=self.exchange))
super(FanoutPublisher, self).__init__(connection=connection) super(FanoutPublisher, self).__init__(connection=connection)
class DirectConsumer(Consumer): class DirectConsumer(Consumer):
"""Consumes messages directly on a channel specified by msg_id""" """Consumes messages directly on a channel specified by msg_id."""
exchange_type = "direct"
exchange_type = 'direct'
def __init__(self, connection=None, msg_id=None): def __init__(self, connection=None, msg_id=None):
self.queue = msg_id self.queue = msg_id
@@ -268,8 +283,9 @@ class DirectConsumer(Consumer):
class DirectPublisher(Publisher): class DirectPublisher(Publisher):
"""Publishes messages directly on a channel specified by msg_id""" """Publishes messages directly on a channel specified by msg_id."""
exchange_type = "direct"
exchange_type = 'direct'
def __init__(self, connection=None, msg_id=None): def __init__(self, connection=None, msg_id=None):
self.routing_key = msg_id self.routing_key = msg_id
@@ -279,9 +295,9 @@ class DirectPublisher(Publisher):
def msg_reply(msg_id, reply=None, failure=None): def msg_reply(msg_id, reply=None, failure=None):
"""Sends a reply or an error on the channel signified by msg_id """Sends a reply or an error on the channel signified by msg_id.
failure should be a sys.exc_info() tuple. Failure should be a sys.exc_info() tuple.
""" """
if failure: if failure:
@@ -303,17 +319,20 @@ def msg_reply(msg_id, reply=None, failure=None):
class RemoteError(exception.Error): class RemoteError(exception.Error):
"""Signifies that a remote class has raised an exception """Signifies that a remote class has raised an exception.
Containes a string representation of the type of the original exception, Containes a string representation of the type of the original exception,
the value of the original exception, and the traceback. These are the value of the original exception, and the traceback. These are
sent to the parent as a joined string so printing the exception sent to the parent as a joined string so printing the exception
contains all of the relevent info.""" contains all of the relevent info.
"""
def __init__(self, exc_type, value, traceback): def __init__(self, exc_type, value, traceback):
self.exc_type = exc_type self.exc_type = exc_type
self.value = value self.value = value
self.traceback = traceback self.traceback = traceback
super(RemoteError, self).__init__("%s %s\n%s" % (exc_type, super(RemoteError, self).__init__('%s %s\n%s' % (exc_type,
value, value,
traceback)) traceback))
@@ -339,6 +358,7 @@ def _pack_context(msg, context):
context out into a bunch of separate keys. If we want to support context out into a bunch of separate keys. If we want to support
more arguments in rabbit messages, we may want to do the same more arguments in rabbit messages, we may want to do the same
for args at some point. for args at some point.
""" """
context = dict([('_context_%s' % key, value) context = dict([('_context_%s' % key, value)
for (key, value) in context.to_dict().iteritems()]) for (key, value) in context.to_dict().iteritems()])
@@ -346,11 +366,11 @@ def _pack_context(msg, context):
def call(context, topic, msg): def call(context, topic, msg):
"""Sends a message on a topic and wait for a response""" """Sends a message on a topic and wait for a response."""
LOG.debug(_("Making asynchronous call on %s ..."), topic) LOG.debug(_('Making asynchronous call on %s ...'), topic)
msg_id = uuid.uuid4().hex msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id}) msg.update({'_msg_id': msg_id})
LOG.debug(_("MSG_ID is %s") % (msg_id)) LOG.debug(_('MSG_ID is %s') % (msg_id))
_pack_context(msg, context) _pack_context(msg, context)
class WaitMessage(object): class WaitMessage(object):
@@ -387,8 +407,8 @@ def call(context, topic, msg):
def cast(context, topic, msg): def cast(context, topic, msg):
"""Sends a message on a topic without waiting for a response""" """Sends a message on a topic without waiting for a response."""
LOG.debug(_("Making asynchronous cast on %s..."), topic) LOG.debug(_('Making asynchronous cast on %s...'), topic)
_pack_context(msg, context) _pack_context(msg, context)
conn = Connection.instance() conn = Connection.instance()
publisher = TopicPublisher(connection=conn, topic=topic) publisher = TopicPublisher(connection=conn, topic=topic)
@@ -397,8 +417,8 @@ def cast(context, topic, msg):
def fanout_cast(context, topic, msg): def fanout_cast(context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response""" """Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_("Making asynchronous fanout cast...")) LOG.debug(_('Making asynchronous fanout cast...'))
_pack_context(msg, context) _pack_context(msg, context)
conn = Connection.instance() conn = Connection.instance()
publisher = FanoutPublisher(topic, connection=conn) publisher = FanoutPublisher(topic, connection=conn)
@@ -407,14 +427,14 @@ def fanout_cast(context, topic, msg):
def generic_response(message_data, message): def generic_response(message_data, message):
"""Logs a result and exits""" """Logs a result and exits."""
LOG.debug(_('response %s'), message_data) LOG.debug(_('response %s'), message_data)
message.ack() message.ack()
sys.exit(0) sys.exit(0)
def send_message(topic, message, wait=True): def send_message(topic, message, wait=True):
"""Sends a message for testing""" """Sends a message for testing."""
msg_id = uuid.uuid4().hex msg_id = uuid.uuid4().hex
message.update({'_msg_id': msg_id}) message.update({'_msg_id': msg_id})
LOG.debug(_('topic is %s'), topic) LOG.debug(_('topic is %s'), topic)
@@ -425,14 +445,14 @@ def send_message(topic, message, wait=True):
queue=msg_id, queue=msg_id,
exchange=msg_id, exchange=msg_id,
auto_delete=True, auto_delete=True,
exchange_type="direct", exchange_type='direct',
routing_key=msg_id) routing_key=msg_id)
consumer.register_callback(generic_response) consumer.register_callback(generic_response)
publisher = messaging.Publisher(connection=Connection.instance(), publisher = messaging.Publisher(connection=Connection.instance(),
exchange=FLAGS.control_exchange, exchange=FLAGS.control_exchange,
durable=False, durable=False,
exchange_type="topic", exchange_type='topic',
routing_key=topic) routing_key=topic)
publisher.send(message) publisher.send(message)
publisher.close() publisher.close()
@@ -441,8 +461,8 @@ def send_message(topic, message, wait=True):
consumer.wait() consumer.wait()
if __name__ == "__main__": if __name__ == '__main__':
# NOTE(vish): you can send messages from the command line using # You can send messages from the command line using
# topic and a json sting representing a dictionary # topic and a json string representing a dictionary
# for the method # for the method
send_message(sys.argv[1], json.loads(sys.argv[2])) send_message(sys.argv[1], json.loads(sys.argv[2]))

View File

@@ -76,11 +76,15 @@ def zone_update(context, zone_id, data):
return db.zone_update(context, zone_id, data) return db.zone_update(context, zone_id, data)
def get_zone_capabilities(context, service=None): def get_zone_capabilities(context):
"""Returns a dict of key, value capabilities for this zone, """Returns a dict of key, value capabilities for this zone."""
or for a particular class of services running in this zone.""" return _call_scheduler('get_zone_capabilities', context=context)
return _call_scheduler('get_zone_capabilities', context=context,
params=dict(service=service))
def select(context, specs=None):
    """Ask the scheduler service for hosts matching *specs*.

    Returns the scheduler's list of candidate hosts.
    """
    params = {"specs": specs}
    return _call_scheduler('select', context=context, params=params)
def update_service_capabilities(context, service_name, host, capabilities): def update_service_capabilities(context, service_name, host, capabilities):
@@ -107,6 +111,45 @@ def _process(func, zone):
return func(nova, zone) return func(nova, zone)
def call_zone_method(context, method, errors_to_ignore=None, *args, **kwargs):
"""Returns a list of (zone, call_result) objects."""
if not isinstance(errors_to_ignore, (list, tuple)):
# This will also handle the default None
errors_to_ignore = [errors_to_ignore]
pool = greenpool.GreenPool()
results = []
for zone in db.zone_get_all(context):
try:
nova = novaclient.OpenStack(zone.username, zone.password,
zone.api_url)
nova.authenticate()
except novaclient.exceptions.BadRequest, e:
url = zone.api_url
LOG.warn(_("Failed request to zone; URL=%(url)s: %(e)s")
% locals())
#TODO (dabo) - add logic for failure counts per zone,
# with escalation after a given number of failures.
continue
zone_method = getattr(nova.zones, method)
def _error_trap(*args, **kwargs):
try:
return zone_method(*args, **kwargs)
except Exception as e:
if type(e) in errors_to_ignore:
return None
# TODO (dabo) - want to be able to re-raise here.
# Returning a string now; raising was causing issues.
# raise e
return "ERROR", "%s" % e
res = pool.spawn(_error_trap, *args, **kwargs)
results.append((zone, res))
pool.waitall()
return [(zone.id, res.wait()) for zone, res in results]
def child_zone_helper(zone_list, func): def child_zone_helper(zone_list, func):
"""Fire off a command to each zone in the list. """Fire off a command to each zone in the list.
The return is [novaclient return objects] from each child zone. The return is [novaclient return objects] from each child zone.

View File

@@ -0,0 +1,288 @@
# Copyright (c) 2011 Openstack, LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Host Filter is a driver mechanism for requesting instance resources.
Three drivers are included: AllHosts, Flavor & JSON. AllHosts just
returns the full, unfiltered list of hosts. Flavor is a hard coded
matching mechanism based on flavor criteria and JSON is an ad-hoc
filter grammar.
Why JSON? The requests for instances may come in through the
REST interface from a user or a parent Zone.
Currently Flavors and/or InstanceTypes are used for
specifying the type of instance desired. Specific Nova users have
noted a need for a more expressive way of specifying instances.
Since we don't want to get into building full DSL this is a simple
form as an example of how this could be done. In reality, most
consumers will use the more rigid filters such as FlavorFilter.
Note: These are "required" capability filters. These capabilities
used must be present or the host will be excluded. The hosts
returned are then weighed by the Weighted Scheduler. Weights
can take the more esoteric factors into consideration (such as
server affinity and customer separation).
"""
import json
from nova import exception
from nova import flags
from nova import log as logging
from nova import utils
LOG = logging.getLogger('nova.scheduler.host_filter')
FLAGS = flags.FLAGS
flags.DEFINE_string('default_host_filter_driver',
'nova.scheduler.host_filter.AllHostsFilter',
'Which driver to use for filtering hosts.')
class HostFilter(object):
    """Abstract base class for host filter drivers.

    Subclasses turn an instance type into a driver-specific query and
    filter the hosts known to the ZoneManager against that query.
    """

    def instance_type_to_filter(self, instance_type):
        """Convert instance_type into a filter for most common use-case."""
        raise NotImplementedError()

    def filter_hosts(self, zone_manager, query):
        """Return a list of hosts that fulfill the filter."""
        raise NotImplementedError()

    def _full_name(self):
        """Return 'module.classname' identifying this filter driver."""
        cls = self.__class__
        return "%s.%s" % (cls.__module__, cls.__name__)
class AllHostsFilter(HostFilter):
    """No-op host filter driver: every host in the ZoneManager passes.

    This essentially does what the old Scheduler+Chance combination
    used to give us.
    """

    def instance_type_to_filter(self, instance_type):
        """Return a valid (driver, query) pair; no real conversion
        is needed since nothing is filtered out."""
        return (self._full_name(), instance_type)

    def filter_hosts(self, zone_manager, query):
        """Return every (host, services) pair the ZoneManager knows."""
        return list(zone_manager.service_states.iteritems())
class FlavorFilter(HostFilter):
    """HostFilter driver hard-coded to work with flavors.

    A host qualifies when its reported free memory and available disk
    meet or exceed the instance type's requirements.
    """

    def instance_type_to_filter(self, instance_type):
        """Use instance_type itself as the filter query."""
        return (self._full_name(), instance_type)

    def filter_hosts(self, zone_manager, query):
        """Return a list of (host, capabilities) that can create query.

        Hosts that report no compute capabilities are skipped rather
        than crashing the scheduler.
        """
        instance_type = query
        selected_hosts = []
        for host, services in zone_manager.service_states.iteritems():
            capabilities = services.get('compute', {})
            # Use .get() with a 0 default: the 'compute' lookup above is
            # already tolerant of missing services, so indexing here
            # would raise KeyError on exactly the hosts it tolerated.
            host_ram_mb = capabilities.get('host_memory_free', 0)
            disk_bytes = capabilities.get('disk_available', 0)
            # NOTE(review): disk_available appears to be in bytes while
            # local_gb is in GB -- units look inconsistent; confirm.
            if (host_ram_mb >= instance_type['memory_mb'] and
                    disk_bytes >= instance_type['local_gb']):
                selected_hosts.append((host, capabilities))
        return selected_hosts
#host entries (currently) are like:
# {'host_name-description': 'Default install of XenServer',
# 'host_hostname': 'xs-mini',
# 'host_memory_total': 8244539392,
# 'host_memory_overhead': 184225792,
# 'host_memory_free': 3868327936,
# 'host_memory_free_computed': 3840843776},
# 'host_other-config': {},
# 'host_ip_address': '192.168.1.109',
# 'host_cpu_info': {},
# 'disk_available': 32954957824,
# 'disk_total': 50394562560,
# 'disk_used': 17439604736},
# 'host_uuid': 'cedb9b39-9388-41df-8891-c5c9a0c0fe5f',
# 'host_name-label': 'xs-mini'}
# instance_type table has:
#name = Column(String(255), unique=True)
#memory_mb = Column(Integer)
#vcpus = Column(Integer)
#local_gb = Column(Integer)
#flavorid = Column(Integer, unique=True)
#swap = Column(Integer, nullable=False, default=0)
#rxtx_quota = Column(Integer, nullable=False, default=0)
#rxtx_cap = Column(Integer, nullable=False, default=0)
class JsonFilter(HostFilter):
    """Host Filter driver allowing a simple JSON-based grammar for
    selecting hosts.

    A query is a JSON list: the first element names an operator in
    ``commands`` and the remaining elements are its arguments, which may
    themselves be nested queries. String arguments prefixed with '$' are
    capability lookups of the form '$service.capability[.subcap*]'.
    """

    def _equals(self, args):
        """First term is == all the other terms."""
        if len(args) < 2:
            return False
        lhs = args[0]
        return all(lhs == rhs for rhs in args[1:])

    def _less_than(self, args):
        """First term is < all the other terms."""
        if len(args) < 2:
            return False
        lhs = args[0]
        return all(lhs < rhs for rhs in args[1:])

    def _greater_than(self, args):
        """First term is > all the other terms."""
        if len(args) < 2:
            return False
        lhs = args[0]
        return all(lhs > rhs for rhs in args[1:])

    def _in(self, args):
        """First term is in the set of remaining terms."""
        if len(args) < 2:
            return False
        return args[0] in args[1:]

    def _less_than_equal(self, args):
        """First term is <= all the other terms."""
        if len(args) < 2:
            return False
        lhs = args[0]
        return all(lhs <= rhs for rhs in args[1:])

    def _greater_than_equal(self, args):
        """First term is >= all the other terms."""
        if len(args) < 2:
            return False
        lhs = args[0]
        return all(lhs >= rhs for rhs in args[1:])

    def _not(self, args):
        """Flip each of the arguments.

        Note: returns a *list* of negated values, not a single bool;
        filter_hosts treats a list result as truthy when any item is True.
        """
        if len(args) == 0:
            return False
        return [not arg for arg in args]

    def _or(self, args):
        """True if any arg is True."""
        return True in args

    def _and(self, args):
        """True if all args are True."""
        return False not in args

    # Operator name -> unbound handler; invoked as method(self, args).
    commands = {
        '=': _equals,
        '<': _less_than,
        '>': _greater_than,
        'in': _in,
        '<=': _less_than_equal,
        '>=': _greater_than_equal,
        'not': _not,
        'or': _or,
        'and': _and,
    }

    def instance_type_to_filter(self, instance_type):
        """Convert instance_type into a JSON query requiring enough
        free host memory and available disk."""
        required_ram = instance_type['memory_mb']
        required_disk = instance_type['local_gb']
        query = ['and',
                    ['>=', '$compute.host_memory_free', required_ram],
                    ['>=', '$compute.disk_available', required_disk]
                ]
        return (self._full_name(), json.dumps(query))

    def _parse_string(self, string, host, services):
        """Resolve a string argument of the query.

        Strings prefixed with $ are capability lookups in the
        form '$service.capability[.subcap*]'; anything else is returned
        verbatim. Returns None when any lookup step is missing/falsy.
        """
        if not string:
            return None
        if string[0] != '$':
            return string
        path = string[1:].split('.')
        for item in path:
            services = services.get(item, None)
            if not services:
                return None
        return services

    def _process_filter(self, zone_manager, query, host, services):
        """Recursively evaluate the query structure for one host."""
        if len(query) == 0:
            return True
        cmd = query[0]
        method = self.commands[cmd]  # Let exception fly.
        cooked_args = []
        for arg in query[1:]:
            if isinstance(arg, list):
                arg = self._process_filter(zone_manager, arg, host, services)
            elif isinstance(arg, basestring):
                arg = self._parse_string(arg, host, services)
            # 'is not None' (not '!= None') so falsy-but-valid values
            # like 0 or False still reach the operator.
            if arg is not None:
                cooked_args.append(arg)
        result = method(self, cooked_args)
        return result

    def filter_hosts(self, zone_manager, query):
        """Return a list of (host, services) that fulfill the filter."""
        expanded = json.loads(query)
        hosts = []
        for host, services in zone_manager.service_states.iteritems():
            r = self._process_filter(zone_manager, expanded, host, services)
            if isinstance(r, list):
                # 'not' returns a list of bools; any True passes.
                r = True in r
            if r:
                hosts.append((host, services))
        return hosts
DRIVERS = [AllHostsFilter, FlavorFilter, JsonFilter]


def choose_driver(driver_name=None):
    """Look up a filter driver by its fully-qualified class name.

    Since the caller may specify which driver to use, the name is
    checked against the authoritative DRIVERS list rather than
    importing arbitrary code. Falls back to the
    default_host_filter_driver flag when no name is given.

    Raises SchedulerHostFilterDriverNotFound when nothing matches.
    """
    name = driver_name or FLAGS.default_host_filter_driver
    for driver_cls in DRIVERS:
        full_name = "%s.%s" % (driver_cls.__module__, driver_cls.__name__)
        if full_name == name:
            return driver_cls()
    raise exception.SchedulerHostFilterDriverNotFound(driver_name=name)

View File

@@ -0,0 +1,119 @@
# Copyright (c) 2011 Openstack, LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
The Zone Aware Scheduler is a base class Scheduler for creating instances
across zones. There are two expansion points to this class for:
1. Assigning Weights to hosts for requested instances
2. Filtering Hosts based on required instance capabilities
"""
import operator
from nova import log as logging
from nova.scheduler import api
from nova.scheduler import driver
LOG = logging.getLogger('nova.scheduler.zone_aware_scheduler')
class ZoneAwareScheduler(driver.Scheduler):
    """Base class for creating Zone Aware Schedulers."""

    def _call_zone_method(self, context, method, specs):
        """Call novaclient zone method. Broken out for testing."""
        return api.call_zone_method(context, method, specs=specs)

    def schedule_run_instance(self, context, topic='compute', specs=None,
                              *args, **kwargs):
        """This method is called from nova.compute.api to provision
        an instance. However we need to look at the parameters being
        passed in to see if this is a request to:
        1. Create a Build Plan and then provision, or
        2. Use the Build Plan information in the request parameters
           to simply create the instance (either in this zone or
           a child zone).
        """
        # NOTE: 'specs' previously defaulted to a shared mutable dict
        # ({}); default to None and build a fresh dict per call.
        if specs is None:
            specs = {}

        # A 'blob' in the specs means a build plan was already created.
        if 'blob' in specs:
            return self.provision_instance(context, topic, specs)

        # Create build plan and provision ...
        build_plan = self.select(context, specs)
        for item in build_plan:
            self.provision_instance(context, topic, item)

    def provision_instance(self, context, topic, item):
        """Create the requested instance in this Zone or a child zone.

        NOTE: 'self' was missing from the original signature, so calls
        via self.provision_instance() shifted every argument by one.
        """
        pass

    def select(self, context, *args, **kwargs):
        """Select returns a list of weights and zone/host information
        corresponding to the best hosts to service the request. Any
        child zone information has been encrypted so as not to reveal
        anything about the children.
        """
        return self._schedule(context, "compute", *args, **kwargs)

    def schedule(self, context, topic, *args, **kwargs):
        """The schedule() contract requires we return the one
        best-suited host for this request.
        """
        res = self._schedule(context, topic, *args, **kwargs)
        # TODO(sirp): should this be a host object rather than a weight-dict?
        if not res:
            raise driver.NoValidHost(_('No hosts were available'))
        return res[0]

    def _schedule(self, context, topic, *args, **kwargs):
        """Returns a list of hosts that meet the required specs,
        ordered by their fitness.
        """
        #TODO(sandy): extract these from args.
        num_instances = 1
        specs = {}

        # Filter local hosts based on requirements ...
        host_list = self.filter_hosts(num_instances, specs)

        # then weigh the selected hosts.
        # weighted = [{weight=weight, name=hostname}, ...]
        weighted = self.weigh_hosts(num_instances, specs, host_list)

        # Next, tack on the best weights from the child zones ...
        child_results = self._call_zone_method(context, "select",
                specs=specs)
        for child_zone, result in child_results:
            for weighting in result:
                # Remember the child_zone so we can get back to
                # it later if needed. This implicitly builds a zone
                # path structure.
                host_dict = {
                        "weight": weighting["weight"],
                        "child_zone": child_zone,
                        "child_blob": weighting["blob"]}
                weighted.append(host_dict)

        # Lower weight == better fit; best candidates come first.
        weighted.sort(key=operator.itemgetter('weight'))
        return weighted

    def filter_hosts(self, num, specs):
        """Derived classes must override this method and return
        a list of hosts in [(hostname, capability_dict)] format.
        """
        # NOTE: was 'raise NotImplemented()', which actually raises
        # TypeError (NotImplemented is not callable); use the proper
        # NotImplementedError exception.
        raise NotImplementedError()

    def weigh_hosts(self, num, specs, hosts):
        """Derived classes must override this method and return
        a lists of hosts in [{weight, hostname}] format.
        """
        raise NotImplementedError()

View File

@@ -106,28 +106,26 @@ class ZoneManager(object):
def __init__(self): def __init__(self):
self.last_zone_db_check = datetime.min self.last_zone_db_check = datetime.min
self.zone_states = {} # { <zone_id> : ZoneState } self.zone_states = {} # { <zone_id> : ZoneState }
self.service_states = {} # { <service> : { <host> : { cap k : v }}} self.service_states = {} # { <host> : { <service> : { cap k : v }}}
self.green_pool = greenpool.GreenPool() self.green_pool = greenpool.GreenPool()
def get_zone_list(self): def get_zone_list(self):
"""Return the list of zones we know about.""" """Return the list of zones we know about."""
return [zone.to_dict() for zone in self.zone_states.values()] return [zone.to_dict() for zone in self.zone_states.values()]
def get_zone_capabilities(self, context, service=None): def get_zone_capabilities(self, context):
"""Roll up all the individual host info to generic 'service' """Roll up all the individual host info to generic 'service'
capabilities. Each capability is aggregated into capabilities. Each capability is aggregated into
<cap>_min and <cap>_max values.""" <cap>_min and <cap>_max values."""
service_dict = self.service_states hosts_dict = self.service_states
if service:
service_dict = {service: self.service_states.get(service, {})}
# TODO(sandy) - be smarter about fabricating this structure. # TODO(sandy) - be smarter about fabricating this structure.
# But it's likely to change once we understand what the Best-Match # But it's likely to change once we understand what the Best-Match
# code will need better. # code will need better.
combined = {} # { <service>_<cap> : (min, max), ... } combined = {} # { <service>_<cap> : (min, max), ... }
for service_name, host_dict in service_dict.iteritems(): for host, host_dict in hosts_dict.iteritems():
for host, caps_dict in host_dict.iteritems(): for service_name, service_dict in host_dict.iteritems():
for cap, value in caps_dict.iteritems(): for cap, value in service_dict.iteritems():
key = "%s_%s" % (service_name, cap) key = "%s_%s" % (service_name, cap)
min_value, max_value = combined.get(key, (value, value)) min_value, max_value = combined.get(key, (value, value))
min_value = min(min_value, value) min_value = min(min_value, value)
@@ -171,6 +169,6 @@ class ZoneManager(object):
"""Update the per-service capabilities based on this notification.""" """Update the per-service capabilities based on this notification."""
logging.debug(_("Received %(service_name)s service update from " logging.debug(_("Received %(service_name)s service update from "
"%(host)s: %(capabilities)s") % locals()) "%(host)s: %(capabilities)s") % locals())
service_caps = self.service_states.get(service_name, {}) service_caps = self.service_states.get(host, {})
service_caps[host] = capabilities service_caps[service_name] = capabilities
self.service_states[service_name] = service_caps self.service_states[host] = service_caps

View File

@@ -21,24 +21,24 @@ from nova import flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DECLARE('volume_driver', 'nova.volume.manager') flags.DECLARE('volume_driver', 'nova.volume.manager')
FLAGS.volume_driver = 'nova.volume.driver.FakeISCSIDriver' FLAGS['volume_driver'].SetDefault('nova.volume.driver.FakeISCSIDriver')
FLAGS.connection_type = 'fake' FLAGS['connection_type'].SetDefault('fake')
FLAGS.fake_rabbit = True FLAGS['fake_rabbit'].SetDefault(True)
flags.DECLARE('auth_driver', 'nova.auth.manager') flags.DECLARE('auth_driver', 'nova.auth.manager')
FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver' FLAGS['auth_driver'].SetDefault('nova.auth.dbdriver.DbDriver')
flags.DECLARE('network_size', 'nova.network.manager') flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('num_networks', 'nova.network.manager') flags.DECLARE('num_networks', 'nova.network.manager')
flags.DECLARE('fake_network', 'nova.network.manager') flags.DECLARE('fake_network', 'nova.network.manager')
FLAGS.network_size = 8 FLAGS['network_size'].SetDefault(8)
FLAGS.num_networks = 2 FLAGS['num_networks'].SetDefault(2)
FLAGS.fake_network = True FLAGS['fake_network'].SetDefault(True)
FLAGS.image_service = 'nova.image.local.LocalImageService' FLAGS['image_service'].SetDefault('nova.image.local.LocalImageService')
flags.DECLARE('num_shelves', 'nova.volume.driver') flags.DECLARE('num_shelves', 'nova.volume.driver')
flags.DECLARE('blades_per_shelf', 'nova.volume.driver') flags.DECLARE('blades_per_shelf', 'nova.volume.driver')
flags.DECLARE('iscsi_num_targets', 'nova.volume.driver') flags.DECLARE('iscsi_num_targets', 'nova.volume.driver')
FLAGS.num_shelves = 2 FLAGS['num_shelves'].SetDefault(2)
FLAGS.blades_per_shelf = 4 FLAGS['blades_per_shelf'].SetDefault(4)
FLAGS.iscsi_num_targets = 8 FLAGS['iscsi_num_targets'].SetDefault(8)
FLAGS.verbose = True FLAGS['verbose'].SetDefault(True)
FLAGS.sqlite_db = "tests.sqlite" FLAGS['sqlite_db'].SetDefault("tests.sqlite")
FLAGS.use_ipv6 = True FLAGS['use_ipv6'].SetDefault(True)

View File

@@ -0,0 +1 @@
1c:87:d1:d9:32:fd:62:3c:78:2b:c0:ad:c0:15:88:df

View File

@@ -0,0 +1 @@
ssh-dss AAAAB3NzaC1kc3MAAACBAMGJlY9XEIm2X234pdO5yFWMp2JuOQx8U0E815IVXhmKxYCBK9ZakgZOIQmPbXoGYyV+mziDPp6HJ0wKYLQxkwLEFr51fAZjWQvRss0SinURRuLkockDfGFtD4pYJthekr/rlqMKlBSDUSpGq8jUWW60UJ18FGooFpxR7ESqQRx/AAAAFQC96LRglaUeeP+E8U/yblEJocuiWwAAAIA3XiMR8Skiz/0aBm5K50SeQznQuMJTyzt9S9uaz5QZWiFu69hOyGSFGw8fqgxEkXFJIuHobQQpGYQubLW0NdaYRqyE/Vud3JUJUb8Texld6dz8vGemyB5d1YvtSeHIo8/BGv2msOqR3u5AZTaGCBD9DhpSGOKHEdNjTtvpPd8S8gAAAIBociGZ5jf09iHLVENhyXujJbxfGRPsyNTyARJfCOGl0oFV6hEzcQyw8U/ePwjgvjc2UizMWLl8tsb2FXKHRdc2v+ND3Us+XqKQ33X3ADP4FZ/+Oj213gMyhCmvFTP0u5FmHog9My4CB7YcIWRuUR42WlhQ2IfPvKwUoTk3R+T6Og== www-data@mk

View File

@@ -28,10 +28,12 @@ import StringIO
import webob import webob
from nova import context from nova import context
from nova import exception
from nova import test from nova import test
from nova.api import ec2 from nova.api import ec2
from nova.api.ec2 import cloud
from nova.api.ec2 import apirequest from nova.api.ec2 import apirequest
from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils
from nova.auth import manager from nova.auth import manager
@@ -101,6 +103,21 @@ class XmlConversionTestCase(test.TestCase):
self.assertEqual(conv('-0'), 0) self.assertEqual(conv('-0'), 0)
class Ec2utilsTestCase(test.TestCase):
    """Unit tests for the ec2utils id conversion helpers."""

    def test_ec2_id_to_id(self):
        # The suffix after the dash is hex: 0x1e == 30, 0x1d == 29.
        self.assertEqual(ec2utils.ec2_id_to_id('i-0000001e'), 30)
        self.assertEqual(ec2utils.ec2_id_to_id('ami-1d'), 29)

    def test_bad_ec2_id(self):
        # An id without a valid hex suffix must raise InvalidEc2Id.
        self.assertRaises(exception.InvalidEc2Id,
                          ec2utils.ec2_id_to_id,
                          'badone')

    def test_id_to_ec2_id(self):
        # Default template pads to 8 hex digits; a custom template
        # (e.g. for images) may be supplied.
        self.assertEqual(ec2utils.id_to_ec2_id(30), 'i-0000001e')
        self.assertEqual(ec2utils.id_to_ec2_id(29, 'ami-%08x'), 'ami-0000001d')
class ApiEc2TestCase(test.TestCase): class ApiEc2TestCase(test.TestCase):
"""Unit test for the cloud controller on an EC2 API""" """Unit test for the cloud controller on an EC2 API"""
def setUp(self): def setUp(self):
@@ -207,6 +224,29 @@ class ApiEc2TestCase(test.TestCase):
self.manager.delete_project(project) self.manager.delete_project(project)
self.manager.delete_user(user) self.manager.delete_user(user)
    def test_create_duplicate_key_pair(self):
        """Test that, after successfully generating a keypair,
        requesting a second keypair with the same name fails sanely"""
        self.expect_http()
        self.mox.ReplayAll()
        # NOTE(review): keyname, user and project are created but never
        # used below ('test' is used instead) and the user/project are
        # not cleaned up -- confirm whether that is intentional.
        keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
            for x in range(random.randint(4, 8)))
        user = self.manager.create_user('fake', 'fake', 'fake')
        project = self.manager.create_project('fake', 'fake', 'fake')
        # NOTE(vish): create depends on pool, so call helper directly
        self.ec2.create_key_pair('test')

        try:
            self.ec2.create_key_pair('test')
        except EC2ResponseError, e:
            # The duplicate request must come back as 'KeyPairExists';
            # any other EC2 error code is a real failure.
            if e.code == 'KeyPairExists':
                pass
            else:
                self.fail("Unexpected EC2ResponseError: %s "
                          "(expected KeyPairExists)" % e.code)
        else:
            self.fail('Exception not raised.')
def test_get_all_security_groups(self): def test_get_all_security_groups(self):
"""Test that we can retrieve security groups""" """Test that we can retrieve security groups"""
self.expect_http() self.expect_http()

View File

@@ -101,9 +101,43 @@ class _AuthManagerBaseTestCase(test.TestCase):
self.assertEqual('private-party', u.access) self.assertEqual('private-party', u.access)
def test_004_signature_is_valid(self): def test_004_signature_is_valid(self):
#self.assertTrue(self.manager.authenticate(**boto.generate_url ...? )) with user_generator(self.manager, name='admin', secret='admin',
pass access='admin'):
#raise NotImplementedError with project_generator(self.manager, name="admin",
manager_user='admin'):
accesskey = 'admin:admin'
expected_result = (self.manager.get_user('admin'),
self.manager.get_project('admin'))
# captured sig and query string using boto 1.9b/euca2ools 1.2
sig = 'd67Wzd9Bwz8xid9QU+lzWXcF2Y3tRicYABPJgrqfrwM='
auth_params = {'AWSAccessKeyId': 'admin:admin',
'Action': 'DescribeAvailabilityZones',
'SignatureMethod': 'HmacSHA256',
'SignatureVersion': '2',
'Timestamp': '2011-04-22T11:29:29',
'Version': '2009-11-30'}
self.assertTrue(expected_result, self.manager.authenticate(
accesskey,
sig,
auth_params,
'GET',
'127.0.0.1:8773',
'/services/Cloud/'))
# captured sig and query string using RightAWS 1.10.0
sig = 'ECYLU6xdFG0ZqRVhQybPJQNJ5W4B9n8fGs6+/fuGD2c='
auth_params = {'AWSAccessKeyId': 'admin:admin',
'Action': 'DescribeAvailabilityZones',
'SignatureMethod': 'HmacSHA256',
'SignatureVersion': '2',
'Timestamp': '2011-04-22T11:29:49.000Z',
'Version': '2008-12-01'}
self.assertTrue(expected_result, self.manager.authenticate(
accesskey,
sig,
auth_params,
'GET',
'127.0.0.1',
'/services/Cloud'))
def test_005_can_get_credentials(self): def test_005_can_get_credentials(self):
return return

View File

@@ -36,6 +36,7 @@ from nova import rpc
from nova import service from nova import service
from nova import test from nova import test
from nova import utils from nova import utils
from nova import exception
from nova.auth import manager from nova.auth import manager
from nova.compute import power_state from nova.compute import power_state
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
@@ -278,6 +279,26 @@ class CloudTestCase(test.TestCase):
user_group=['all']) user_group=['all'])
self.assertEqual(True, result['is_public']) self.assertEqual(True, result['is_public'])
    def test_deregister_image(self):
        """deregister_image succeeds for a known image and raises
        ImageNotFound for an unknown one."""
        deregister_image = self.cloud.deregister_image

        def fake_delete(self, context, id):
            # Stub the image service delete so no backend is touched.
            return None

        self.stubs.Set(local.LocalImageService, 'delete', fake_delete)
        # valid image
        result = deregister_image(self.context, 'ami-00000001')
        self.assertEqual(result['imageId'], 'ami-00000001')
        # invalid image
        self.stubs.UnsetAll()

        def fake_detail_empty(self, context):
            # With no images registered, any lookup must fail.
            return []

        self.stubs.Set(local.LocalImageService, 'detail', fake_detail_empty)
        self.assertRaises(exception.ImageNotFound, deregister_image,
                          self.context, 'ami-bad001')
def test_console_output(self): def test_console_output(self):
instance_type = FLAGS.default_instance_type instance_type = FLAGS.default_instance_type
max_count = 1 max_count = 1
@@ -289,7 +310,7 @@ class CloudTestCase(test.TestCase):
instance_id = rv['instancesSet'][0]['instanceId'] instance_id = rv['instancesSet'][0]['instanceId']
output = self.cloud.get_console_output(context=self.context, output = self.cloud.get_console_output(context=self.context,
instance_id=[instance_id]) instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT') self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code # TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests. # for unit tests.
greenthread.sleep(0.3) greenthread.sleep(0.3)
@@ -333,44 +354,52 @@ class CloudTestCase(test.TestCase):
self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys)) self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys)) self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
def test_import_public_key(self):
# test when user provides all values
result1 = self.cloud.import_public_key(self.context,
'testimportkey1',
'mytestpubkey',
'mytestfprint')
self.assertTrue(result1)
keydata = db.key_pair_get(self.context,
self.context.user.id,
'testimportkey1')
self.assertEqual('mytestpubkey', keydata['public_key'])
self.assertEqual('mytestfprint', keydata['fingerprint'])
# test when user omits fingerprint
pubkey_path = os.path.join(os.path.dirname(__file__), 'public_key')
f = open(pubkey_path + '/dummy.pub', 'r')
dummypub = f.readline().rstrip()
f.close
f = open(pubkey_path + '/dummy.fingerprint', 'r')
dummyfprint = f.readline().rstrip()
f.close
result2 = self.cloud.import_public_key(self.context,
'testimportkey2',
dummypub)
self.assertTrue(result2)
keydata = db.key_pair_get(self.context,
self.context.user.id,
'testimportkey2')
self.assertEqual(dummypub, keydata['public_key'])
self.assertEqual(dummyfprint, keydata['fingerprint'])
def test_delete_key_pair(self): def test_delete_key_pair(self):
self._create_key('test') self._create_key('test')
self.cloud.delete_key_pair(self.context, 'test') self.cloud.delete_key_pair(self.context, 'test')
def test_run_instances(self): def test_terminate_instances(self):
if FLAGS.connection_type == 'fake': inst1 = db.instance_create(self.context, {'reservation_id': 'a',
LOG.debug(_("Can't test instances without a real virtual env.")) 'image_id': 1,
return 'host': 'host1'})
image_id = FLAGS.default_image terminate_instances = self.cloud.terminate_instances
instance_type = FLAGS.default_instance_type # valid instance_id
max_count = 1 result = terminate_instances(self.context, ['i-00000001'])
kwargs = {'image_id': image_id, self.assertTrue(result)
'instance_type': instance_type, # non-existing instance_id
'max_count': max_count} self.assertRaises(exception.InstanceNotFound, terminate_instances,
rv = self.cloud.run_instances(self.context, **kwargs) self.context, ['i-2'])
# TODO: check for proper response db.instance_destroy(self.context, inst1['id'])
instance_id = rv['reservationSet'][0].keys()[0]
instance = rv['reservationSet'][0][instance_id][0]
LOG.debug(_("Need to watch instance %s until it's running..."),
instance['instance_id'])
while True:
greenthread.sleep(1)
info = self.cloud._get_instance(instance['instance_id'])
LOG.debug(info['state'])
if info['state'] == power_state.RUNNING:
break
self.assert_(rv)
if FLAGS.connection_type != 'fake':
time.sleep(45) # Should use boto for polling here
for reservations in rv['reservationSet']:
# for res_id in reservations.keys():
# LOG.debug(reservations[res_id])
# for instance in reservations[res_id]:
for instance in reservations[reservations.keys()[0]]:
instance_id = instance['instance_id']
LOG.debug(_("Terminating instance %s"), instance_id)
rv = self.compute.terminate_instance(instance_id)
def test_update_of_instance_display_fields(self): def test_update_of_instance_display_fields(self):
inst = db.instance_create(self.context, {}) inst = db.instance_create(self.context, {})

View File

@@ -21,6 +21,7 @@ Tests For Compute
import datetime import datetime
import mox import mox
import stubout
from nova import compute from nova import compute
from nova import context from nova import context
@@ -52,6 +53,10 @@ class FakeTime(object):
self.counter += t self.counter += t
def nop_report_driver_status(self):
    """No-op stand-in for ComputeManager._report_driver_status in tests."""
    pass
class ComputeTestCase(test.TestCase): class ComputeTestCase(test.TestCase):
"""Test case for compute""" """Test case for compute"""
def setUp(self): def setUp(self):
@@ -329,6 +334,28 @@ class ComputeTestCase(test.TestCase):
self.compute.terminate_instance(self.context, instance_id) self.compute.terminate_instance(self.context, instance_id)
    def test_finish_resize(self):
        """Contrived test to ensure finish_resize doesn't raise anything"""

        def fake(*args, **kwargs):
            # Stub so the virt driver's finish_resize does no real work.
            pass

        self.stubs.Set(self.compute.driver, 'finish_resize', fake)
        context = self.context.elevated()
        instance_id = self._create_instance()
        self.compute.prep_resize(context, instance_id, 1)
        migration_ref = db.migration_get_by_instance_and_status(context,
                instance_id, 'pre-migrating')
        try:
            self.compute.finish_resize(context, instance_id,
                    int(migration_ref['id']), {})
        except KeyError, e:
            # Only catch key errors. We want other reasons for the test to
            # fail to actually error out so we don't obscure anything
            self.fail()
        self.compute.terminate_instance(self.context, instance_id)
def test_resize_instance(self): def test_resize_instance(self):
"""Ensure instance can be migrated/resized""" """Ensure instance can be migrated/resized"""
instance_id = self._create_instance() instance_id = self._create_instance()
@@ -649,6 +676,10 @@ class ComputeTestCase(test.TestCase):
def test_run_kill_vm(self): def test_run_kill_vm(self):
"""Detect when a vm is terminated behind the scenes""" """Detect when a vm is terminated behind the scenes"""
self.stubs = stubout.StubOutForTesting()
self.stubs.Set(compute_manager.ComputeManager,
'_report_driver_status', nop_report_driver_status)
instance_id = self._create_instance() instance_id = self._create_instance()
self.compute.run_instance(self.context, instance_id) self.compute.run_instance(self.context, instance_id)

View File

@@ -91,6 +91,20 @@ class FlagsTestCase(test.TestCase):
self.assert_('runtime_answer' in self.global_FLAGS) self.assert_('runtime_answer' in self.global_FLAGS)
self.assertEqual(self.global_FLAGS.runtime_answer, 60) self.assertEqual(self.global_FLAGS.runtime_answer, 60)
    def test_long_vs_short_flags(self):
        """A long flag name must not be shadowed when a shorter flag
        with a common prefix is defined later."""
        flags.DEFINE_string('duplicate_answer_long', 'val', 'desc',
                            flag_values=self.global_FLAGS)
        argv = ['flags_test', '--duplicate_answer=60', 'extra_arg']
        args = self.global_FLAGS(argv)

        self.assert_('duplicate_answer' not in self.global_FLAGS)
        # NOTE(review): assert_ takes (expr, msg); 60 here is only the
        # failure message, so this checks truthiness, not equality --
        # confirm whether assertEqual was intended.
        self.assert_(self.global_FLAGS.duplicate_answer_long, 60)

        flags.DEFINE_integer('duplicate_answer', 60, 'desc',
                             flag_values=self.global_FLAGS)
        # After the short flag exists, each keeps its own value.
        self.assertEqual(self.global_FLAGS.duplicate_answer, 60)
        self.assertEqual(self.global_FLAGS.duplicate_answer_long, 'val')
def test_flag_leak_left(self): def test_flag_leak_left(self):
self.assertEqual(FLAGS.flags_unittest, 'foo') self.assertEqual(FLAGS.flags_unittest, 'foo')
FLAGS.flags_unittest = 'bar' FLAGS.flags_unittest = 'bar'

View File

@@ -0,0 +1,208 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Scheduler Host Filter Drivers.
"""
import json
from nova import exception
from nova import flags
from nova import test
from nova.scheduler import host_filter
FLAGS = flags.FLAGS
class FakeZoneManager:
    """Minimal ZoneManager stand-in; tests assign service_states directly."""
    pass
class HostFilterTestCase(test.TestCase):
    """Test case for host filter drivers."""

    def _host_caps(self, multiplier):
        # Returns host capabilities in the following way:
        # host1 = memory:free 10 (100max)
        #         disk:available 100 (1000max)
        # hostN = memory:free 10 + 10N
        #         disk:available 100 + 100N
        # in other words: hostN has more resources than host0
        # which means ... don't go above 10 hosts.
        return {'host_name-description': 'XenServer %s' % multiplier,
                'host_hostname': 'xs-%s' % multiplier,
                'host_memory_total': 100,
                'host_memory_overhead': 10,
                'host_memory_free': 10 + multiplier * 10,
                'host_memory_free-computed': 10 + multiplier * 10,
                'host_other-config': {},
                'host_ip_address': '192.168.1.%d' % (100 + multiplier),
                'host_cpu_info': {},
                'disk_available': 100 + multiplier * 100,
                'disk_total': 1000,
                'disk_used': 0,
                'host_uuid': 'xxx-%d' % multiplier,
                'host_name-label': 'xs-%s' % multiplier}

    def setUp(self):
        # NOTE(review): super(...).setUp() is not called here -- confirm
        # test.TestCase tolerates that.
        # Remember the default-driver flag so tearDown can restore it.
        self.old_flag = FLAGS.default_host_filter_driver
        FLAGS.default_host_filter_driver = \
            'nova.scheduler.host_filter.AllHostsFilter'
        # A 'tiny' flavor needing 50 MB RAM and 500 GB disk; in the
        # fake topology below only hosts 05..10 can satisfy it.
        self.instance_type = dict(name='tiny',
                memory_mb=50,
                vcpus=10,
                local_gb=500,
                flavorid=1,
                swap=500,
                rxtx_quota=30000,
                rxtx_cap=200)

        self.zone_manager = FakeZoneManager()
        states = {}
        for x in xrange(10):
            states['host%02d' % (x + 1)] = {'compute': self._host_caps(x)}
        self.zone_manager.service_states = states

    def tearDown(self):
        FLAGS.default_host_filter_driver = self.old_flag

    def test_choose_driver(self):
        # Test default driver ...
        driver = host_filter.choose_driver()
        self.assertEquals(driver._full_name(),
                          'nova.scheduler.host_filter.AllHostsFilter')
        # Test valid driver ...
        driver = host_filter.choose_driver(
                    'nova.scheduler.host_filter.FlavorFilter')
        self.assertEquals(driver._full_name(),
                          'nova.scheduler.host_filter.FlavorFilter')
        # Test invalid driver ...
        try:
            host_filter.choose_driver('does not exist')
            self.fail("Should not find driver")
        except exception.SchedulerHostFilterDriverNotFound:
            pass

    def test_all_host_driver(self):
        # AllHostsFilter passes every host through unfiltered.
        driver = host_filter.AllHostsFilter()
        cooked = driver.instance_type_to_filter(self.instance_type)
        hosts = driver.filter_hosts(self.zone_manager, cooked)
        self.assertEquals(10, len(hosts))
        for host, capabilities in hosts:
            self.assertTrue(host.startswith('host'))

    def test_flavor_driver(self):
        driver = host_filter.FlavorFilter()
        # filter all hosts that can support 50 ram and 500 disk
        name, cooked = driver.instance_type_to_filter(self.instance_type)
        self.assertEquals('nova.scheduler.host_filter.FlavorFilter', name)
        hosts = driver.filter_hosts(self.zone_manager, cooked)
        self.assertEquals(6, len(hosts))
        just_hosts = [host for host, caps in hosts]
        just_hosts.sort()
        self.assertEquals('host05', just_hosts[0])
        self.assertEquals('host10', just_hosts[5])

    def test_json_driver(self):
        driver = host_filter.JsonFilter()
        # filter all hosts that can support 50 ram and 500 disk
        name, cooked = driver.instance_type_to_filter(self.instance_type)
        self.assertEquals('nova.scheduler.host_filter.JsonFilter', name)
        hosts = driver.filter_hosts(self.zone_manager, cooked)
        self.assertEquals(6, len(hosts))
        just_hosts = [host for host, caps in hosts]
        just_hosts.sort()
        self.assertEquals('host05', just_hosts[0])
        self.assertEquals('host10', just_hosts[5])

        # Try some custom queries

        raw = ['or',
                   ['and',
                       ['<', '$compute.host_memory_free', 30],
                       ['<', '$compute.disk_available', 300]
                   ],
                   ['and',
                       ['>', '$compute.host_memory_free', 70],
                       ['>', '$compute.disk_available', 700]
                   ]
              ]
        cooked = json.dumps(raw)
        hosts = driver.filter_hosts(self.zone_manager, cooked)

        self.assertEquals(5, len(hosts))
        just_hosts = [host for host, caps in hosts]
        just_hosts.sort()
        for index, host in zip([1, 2, 8, 9, 10], just_hosts):
            self.assertEquals('host%02d' % index, host)

        raw = ['not',
                  ['=', '$compute.host_memory_free', 30],
              ]
        cooked = json.dumps(raw)
        hosts = driver.filter_hosts(self.zone_manager, cooked)

        self.assertEquals(9, len(hosts))
        just_hosts = [host for host, caps in hosts]
        just_hosts.sort()
        for index, host in zip([1, 2, 4, 5, 6, 7, 8, 9, 10], just_hosts):
            self.assertEquals('host%02d' % index, host)

        raw = ['in', '$compute.host_memory_free', 20, 40, 60, 80, 100]
        cooked = json.dumps(raw)
        hosts = driver.filter_hosts(self.zone_manager, cooked)

        self.assertEquals(5, len(hosts))
        just_hosts = [host for host, caps in hosts]
        just_hosts.sort()
        for index, host in zip([2, 4, 6, 8, 10], just_hosts):
            self.assertEquals('host%02d' % index, host)

        # Try some bogus input ...
        raw = ['unknown command', ]
        cooked = json.dumps(raw)
        try:
            driver.filter_hosts(self.zone_manager, cooked)
            self.fail("Should give KeyError")
        except KeyError, e:
            pass

        # Empty / trivially-true queries match everything.
        self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps([])))
        self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps({})))
        self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps(
                ['not', True, False, True, False]
            )))

        # NOTE(review): the arguments below go to json.dumps itself (not
        # as one list), so only 'not' is serialized -- it still triggers
        # the KeyError path, but a bracketed list was likely intended.
        try:
            driver.filter_hosts(self.zone_manager, json.dumps(
                    'not', True, False, True, False
                ))
            self.fail("Should give KeyError")
        except KeyError, e:
            pass

        # Lookups of missing capabilities resolve to no match.
        self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
                ['=', '$foo', 100]
            )))
        self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
                ['=', '$.....', 100]
            )))
        self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
                ['>', ['and', ['or', ['not', ['<', ['>=', ['<=', ['in', ]]]]]]]]
            )))
        self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
                ['=', {}, ['>', '$missing....foo']]
            )))

View File

@@ -75,16 +75,25 @@ class InstanceTypeTestCase(test.TestCase):
def test_invalid_create_args_should_fail(self): def test_invalid_create_args_should_fail(self):
"""Ensures that instance type creation fails with invalid args""" """Ensures that instance type creation fails with invalid args"""
self.assertRaises( self.assertRaises(
exception.InvalidInputException, exception.InvalidInput,
instance_types.create, self.name, 0, 1, 120, self.flavorid) instance_types.create, self.name, 0, 1, 120, self.flavorid)
self.assertRaises( self.assertRaises(
exception.InvalidInputException, exception.InvalidInput,
instance_types.create, self.name, 256, -1, 120, self.flavorid) instance_types.create, self.name, 256, -1, 120, self.flavorid)
self.assertRaises( self.assertRaises(
exception.InvalidInputException, exception.InvalidInput,
instance_types.create, self.name, 256, 1, "aa", self.flavorid) instance_types.create, self.name, 256, 1, "aa", self.flavorid)
def test_non_existant_inst_type_shouldnt_delete(self): def test_non_existant_inst_type_shouldnt_delete(self):
"""Ensures that instance type creation fails with invalid args""" """Ensures that instance type creation fails with invalid args"""
self.assertRaises(exception.ApiError, self.assertRaises(exception.ApiError,
instance_types.destroy, "sfsfsdfdfs") instance_types.destroy, "sfsfsdfdfs")
def test_repeated_inst_types_should_raise_api_error(self):
"""Ensures that instance duplicates raises ApiError"""
new_name = self.name + "dup"
instance_types.create(new_name, 256, 1, 120, self.flavorid + 1)
instance_types.destroy(new_name)
self.assertRaises(
exception.ApiError,
instance_types.create, new_name, 256, 1, 120, self.flavorid)

View File

@@ -31,10 +31,9 @@ from nova import test
from nova import utils from nova import utils
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
from nova.auth import manager from nova.auth import manager
from nova.compute import manager as compute_manager
from nova.compute import power_state from nova.compute import power_state
from nova.db.sqlalchemy import models from nova.virt.libvirt import connection
from nova.virt import libvirt_conn from nova.virt.libvirt import firewall
libvirt = None libvirt = None
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@@ -46,6 +45,27 @@ def _concurrency(wait, done, target):
done.send() done.send()
def _create_network_info(count=1, ipv6=None):
if ipv6 is None:
ipv6 = FLAGS.use_ipv6
fake = 'fake'
fake_ip = '0.0.0.0/0'
fake_ip_2 = '0.0.0.1/0'
fake_ip_3 = '0.0.0.1/0'
network = {'gateway': fake,
'gateway_v6': fake,
'bridge': fake,
'cidr': fake_ip,
'cidr_v6': fake_ip}
mapping = {'mac': fake,
'ips': [{'ip': fake_ip}, {'ip': fake_ip}]}
if ipv6:
mapping['ip6s'] = [{'ip': fake_ip},
{'ip': fake_ip_2},
{'ip': fake_ip_3}]
return [(network, mapping) for x in xrange(0, count)]
class CacheConcurrencyTestCase(test.TestCase): class CacheConcurrencyTestCase(test.TestCase):
def setUp(self): def setUp(self):
super(CacheConcurrencyTestCase, self).setUp() super(CacheConcurrencyTestCase, self).setUp()
@@ -64,7 +84,7 @@ class CacheConcurrencyTestCase(test.TestCase):
def test_same_fname_concurrency(self): def test_same_fname_concurrency(self):
"""Ensures that the same fname cache runs at a sequentially""" """Ensures that the same fname cache runs at a sequentially"""
conn = libvirt_conn.LibvirtConnection conn = connection.LibvirtConnection
wait1 = eventlet.event.Event() wait1 = eventlet.event.Event()
done1 = eventlet.event.Event() done1 = eventlet.event.Event()
eventlet.spawn(conn._cache_image, _concurrency, eventlet.spawn(conn._cache_image, _concurrency,
@@ -85,7 +105,7 @@ class CacheConcurrencyTestCase(test.TestCase):
def test_different_fname_concurrency(self): def test_different_fname_concurrency(self):
"""Ensures that two different fname caches are concurrent""" """Ensures that two different fname caches are concurrent"""
conn = libvirt_conn.LibvirtConnection conn = connection.LibvirtConnection
wait1 = eventlet.event.Event() wait1 = eventlet.event.Event()
done1 = eventlet.event.Event() done1 = eventlet.event.Event()
eventlet.spawn(conn._cache_image, _concurrency, eventlet.spawn(conn._cache_image, _concurrency,
@@ -106,7 +126,7 @@ class CacheConcurrencyTestCase(test.TestCase):
class LibvirtConnTestCase(test.TestCase): class LibvirtConnTestCase(test.TestCase):
def setUp(self): def setUp(self):
super(LibvirtConnTestCase, self).setUp() super(LibvirtConnTestCase, self).setUp()
libvirt_conn._late_load_cheetah() connection._late_load_cheetah()
self.flags(fake_call=True) self.flags(fake_call=True)
self.manager = manager.AuthManager() self.manager = manager.AuthManager()
@@ -152,8 +172,8 @@ class LibvirtConnTestCase(test.TestCase):
return False return False
global libvirt global libvirt
libvirt = __import__('libvirt') libvirt = __import__('libvirt')
libvirt_conn.libvirt = __import__('libvirt') connection.libvirt = __import__('libvirt')
libvirt_conn.libxml2 = __import__('libxml2') connection.libxml2 = __import__('libxml2')
return True return True
def create_fake_libvirt_mock(self, **kwargs): def create_fake_libvirt_mock(self, **kwargs):
@@ -163,7 +183,7 @@ class LibvirtConnTestCase(test.TestCase):
class FakeLibvirtConnection(object): class FakeLibvirtConnection(object):
pass pass
# A fake libvirt_conn.IptablesFirewallDriver # A fake connection.IptablesFirewallDriver
class FakeIptablesFirewallDriver(object): class FakeIptablesFirewallDriver(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
@@ -179,11 +199,11 @@ class LibvirtConnTestCase(test.TestCase):
for key, val in kwargs.items(): for key, val in kwargs.items():
fake.__setattr__(key, val) fake.__setattr__(key, val)
# Inevitable mocks for libvirt_conn.LibvirtConnection # Inevitable mocks for connection.LibvirtConnection
self.mox.StubOutWithMock(libvirt_conn.utils, 'import_class') self.mox.StubOutWithMock(connection.utils, 'import_class')
libvirt_conn.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip) connection.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip)
self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection, '_conn') self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
libvirt_conn.LibvirtConnection._conn = fake connection.LibvirtConnection._conn = fake
def create_service(self, **kwargs): def create_service(self, **kwargs):
service_ref = {'host': kwargs.get('host', 'dummy'), service_ref = {'host': kwargs.get('host', 'dummy'),
@@ -194,6 +214,37 @@ class LibvirtConnTestCase(test.TestCase):
return db.service_create(context.get_admin_context(), service_ref) return db.service_create(context.get_admin_context(), service_ref)
def test_preparing_xml_info(self):
conn = connection.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, self.test_instance)
result = conn._prepare_xml_info(instance_ref, False)
self.assertFalse(result['nics'])
result = conn._prepare_xml_info(instance_ref, False,
_create_network_info())
self.assertTrue(len(result['nics']) == 1)
result = conn._prepare_xml_info(instance_ref, False,
_create_network_info(2))
self.assertTrue(len(result['nics']) == 2)
def test_get_nic_for_xml_v4(self):
conn = connection.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=False)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
self.assertTrue(params.find('PROJNETV6') == -1)
self.assertTrue(params.find('PROJMASKV6') == -1)
def test_get_nic_for_xml_v6(self):
conn = connection.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=True)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
self.assertTrue(params.find('PROJNETV6') > -1)
self.assertTrue(params.find('PROJMASKV6') > -1)
def test_xml_and_uri_no_ramdisk_no_kernel(self): def test_xml_and_uri_no_ramdisk_no_kernel(self):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
self._check_xml_and_uri(instance_data, self._check_xml_and_uri(instance_data,
@@ -229,6 +280,22 @@ class LibvirtConnTestCase(test.TestCase):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
self._check_xml_and_container(instance_data) self._check_xml_and_container(instance_data)
def test_multi_nic(self):
instance_data = dict(self.test_instance)
network_info = _create_network_info(2)
conn = connection.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, instance_data)
xml = conn.to_xml(instance_ref, False, network_info)
tree = xml_to_tree(xml)
interfaces = tree.findall("./devices/interface")
self.assertEquals(len(interfaces), 2)
parameters = interfaces[0].findall('./filterref/parameter')
self.assertEquals(interfaces[0].get('type'), 'bridge')
self.assertEquals(parameters[0].get('name'), 'IP')
self.assertEquals(parameters[0].get('value'), '0.0.0.0/0')
self.assertEquals(parameters[1].get('name'), 'DHCPSERVER')
self.assertEquals(parameters[1].get('value'), 'fake')
def _check_xml_and_container(self, instance): def _check_xml_and_container(self, instance):
user_context = context.RequestContext(project=self.project, user_context = context.RequestContext(project=self.project,
user=self.user) user=self.user)
@@ -247,7 +314,7 @@ class LibvirtConnTestCase(test.TestCase):
'instance_id': instance_ref['id']}) 'instance_id': instance_ref['id']})
self.flags(libvirt_type='lxc') self.flags(libvirt_type='lxc')
conn = libvirt_conn.LibvirtConnection(True) conn = connection.LibvirtConnection(True)
uri = conn.get_uri() uri = conn.get_uri()
self.assertEquals(uri, 'lxc:///') self.assertEquals(uri, 'lxc:///')
@@ -327,19 +394,13 @@ class LibvirtConnTestCase(test.TestCase):
check = (lambda t: t.find('./os/initrd'), None) check = (lambda t: t.find('./os/initrd'), None)
check_list.append(check) check_list.append(check)
parameter = './devices/interface/filterref/parameter'
common_checks = [ common_checks = [
(lambda t: t.find('.').tag, 'domain'), (lambda t: t.find('.').tag, 'domain'),
(lambda t: t.find( (lambda t: t.find(parameter).get('name'), 'IP'),
'./devices/interface/filterref/parameter').get('name'), 'IP'), (lambda t: t.find(parameter).get('value'), '10.11.12.13'),
(lambda t: t.find( (lambda t: t.findall(parameter)[1].get('name'), 'DHCPSERVER'),
'./devices/interface/filterref/parameter').get( (lambda t: t.findall(parameter)[1].get('value'), '10.0.0.1'),
'value'), '10.11.12.13'),
(lambda t: t.findall(
'./devices/interface/filterref/parameter')[1].get(
'name'), 'DHCPSERVER'),
(lambda t: t.findall(
'./devices/interface/filterref/parameter')[1].get(
'value'), '10.0.0.1'),
(lambda t: t.find('./devices/serial/source').get( (lambda t: t.find('./devices/serial/source').get(
'path').split('/')[1], 'console.log'), 'path').split('/')[1], 'console.log'),
(lambda t: t.find('./memory').text, '2097152')] (lambda t: t.find('./memory').text, '2097152')]
@@ -359,7 +420,7 @@ class LibvirtConnTestCase(test.TestCase):
for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems(): for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True) conn = connection.LibvirtConnection(True)
uri = conn.get_uri() uri = conn.get_uri()
self.assertEquals(uri, expected_uri) self.assertEquals(uri, expected_uri)
@@ -386,7 +447,7 @@ class LibvirtConnTestCase(test.TestCase):
FLAGS.libvirt_uri = testuri FLAGS.libvirt_uri = testuri
for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems(): for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True) conn = connection.LibvirtConnection(True)
uri = conn.get_uri() uri = conn.get_uri()
self.assertEquals(uri, testuri) self.assertEquals(uri, testuri)
db.instance_destroy(user_context, instance_ref['id']) db.instance_destroy(user_context, instance_ref['id'])
@@ -410,13 +471,13 @@ class LibvirtConnTestCase(test.TestCase):
self.create_fake_libvirt_mock(getVersion=getVersion, self.create_fake_libvirt_mock(getVersion=getVersion,
getType=getType, getType=getType,
listDomainsID=listDomainsID) listDomainsID=listDomainsID)
self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection, self.mox.StubOutWithMock(connection.LibvirtConnection,
'get_cpu_info') 'get_cpu_info')
libvirt_conn.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo') connection.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo')
# Start test # Start test
self.mox.ReplayAll() self.mox.ReplayAll()
conn = libvirt_conn.LibvirtConnection(False) conn = connection.LibvirtConnection(False)
conn.update_available_resource(self.context, 'dummy') conn.update_available_resource(self.context, 'dummy')
service_ref = db.service_get(self.context, service_ref['id']) service_ref = db.service_get(self.context, service_ref['id'])
compute_node = service_ref['compute_node'][0] compute_node = service_ref['compute_node'][0]
@@ -450,8 +511,8 @@ class LibvirtConnTestCase(test.TestCase):
self.create_fake_libvirt_mock() self.create_fake_libvirt_mock()
self.mox.ReplayAll() self.mox.ReplayAll()
conn = libvirt_conn.LibvirtConnection(False) conn = connection.LibvirtConnection(False)
self.assertRaises(exception.Invalid, self.assertRaises(exception.ComputeServiceUnavailable,
conn.update_available_resource, conn.update_available_resource,
self.context, 'dummy') self.context, 'dummy')
@@ -485,7 +546,7 @@ class LibvirtConnTestCase(test.TestCase):
# Start test # Start test
self.mox.ReplayAll() self.mox.ReplayAll()
try: try:
conn = libvirt_conn.LibvirtConnection(False) conn = connection.LibvirtConnection(False)
conn.firewall_driver.setattr('setup_basic_filtering', fake_none) conn.firewall_driver.setattr('setup_basic_filtering', fake_none)
conn.firewall_driver.setattr('prepare_instance_filter', fake_none) conn.firewall_driver.setattr('prepare_instance_filter', fake_none)
conn.firewall_driver.setattr('instance_filter_exists', fake_none) conn.firewall_driver.setattr('instance_filter_exists', fake_none)
@@ -534,7 +595,7 @@ class LibvirtConnTestCase(test.TestCase):
# Start test # Start test
self.mox.ReplayAll() self.mox.ReplayAll()
conn = libvirt_conn.LibvirtConnection(False) conn = connection.LibvirtConnection(False)
self.assertRaises(libvirt.libvirtError, self.assertRaises(libvirt.libvirtError,
conn._live_migration, conn._live_migration,
self.context, instance_ref, 'dest', '', self.context, instance_ref, 'dest', '',
@@ -549,6 +610,48 @@ class LibvirtConnTestCase(test.TestCase):
db.volume_destroy(self.context, volume_ref['id']) db.volume_destroy(self.context, volume_ref['id'])
db.instance_destroy(self.context, instance_ref['id']) db.instance_destroy(self.context, instance_ref['id'])
def test_spawn_with_network_info(self):
# Skip if non-libvirt environment
if not self.lazy_load_library_exists():
return
# Preparing mocks
def fake_none(self, instance):
return
self.create_fake_libvirt_mock()
instance = db.instance_create(self.context, self.test_instance)
# Start test
self.mox.ReplayAll()
conn = connection.LibvirtConnection(False)
conn.firewall_driver.setattr('setup_basic_filtering', fake_none)
conn.firewall_driver.setattr('prepare_instance_filter', fake_none)
network = db.project_get_network(context.get_admin_context(),
self.project.id)
ip_dict = {'ip': self.test_ip,
'netmask': network['netmask'],
'enabled': '1'}
mapping = {'label': network['label'],
'gateway': network['gateway'],
'mac': instance['mac_address'],
'dns': [network['dns']],
'ips': [ip_dict]}
network_info = [(network, mapping)]
try:
conn.spawn(instance, network_info)
except Exception, e:
count = (0 <= str(e.message).find('Unexpected method call'))
self.assertTrue(count)
def test_get_host_ip_addr(self):
conn = connection.LibvirtConnection(False)
ip = conn.get_host_ip_addr()
self.assertEquals(ip, FLAGS.my_ip)
def tearDown(self): def tearDown(self):
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
@@ -569,7 +672,7 @@ class IptablesFirewallTestCase(test.TestCase):
class FakeLibvirtConnection(object): class FakeLibvirtConnection(object):
pass pass
self.fake_libvirt_connection = FakeLibvirtConnection() self.fake_libvirt_connection = FakeLibvirtConnection()
self.fw = libvirt_conn.IptablesFirewallDriver( self.fw = firewall.IptablesFirewallDriver(
get_connection=lambda: self.fake_libvirt_connection) get_connection=lambda: self.fake_libvirt_connection)
def tearDown(self): def tearDown(self):
@@ -614,11 +717,15 @@ class IptablesFirewallTestCase(test.TestCase):
'# Completed on Tue Jan 18 23:47:56 2011', '# Completed on Tue Jan 18 23:47:56 2011',
] ]
def _create_instance_ref(self):
return db.instance_create(self.context,
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '56:12:12:12:12:12',
'instance_type_id': 1})
def test_static_filters(self): def test_static_filters(self):
instance_ref = db.instance_create(self.context, instance_ref = self._create_instance_ref()
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '56:12:12:12:12:12'})
ip = '10.11.12.13' ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context,
@@ -729,6 +836,50 @@ class IptablesFirewallTestCase(test.TestCase):
"TCP port 80/81 acceptance rule wasn't added") "TCP port 80/81 acceptance rule wasn't added")
db.instance_destroy(admin_ctxt, instance_ref['id']) db.instance_destroy(admin_ctxt, instance_ref['id'])
def test_filters_for_instance_with_ip_v6(self):
self.flags(use_ipv6=True)
network_info = _create_network_info()
rulesv4, rulesv6 = self.fw._filters_for_instance("fake", network_info)
self.assertEquals(len(rulesv4), 2)
self.assertEquals(len(rulesv6), 3)
def test_filters_for_instance_without_ip_v6(self):
self.flags(use_ipv6=False)
network_info = _create_network_info()
rulesv4, rulesv6 = self.fw._filters_for_instance("fake", network_info)
self.assertEquals(len(rulesv4), 2)
self.assertEquals(len(rulesv6), 0)
def test_multinic_iptables(self):
ipv4_rules_per_network = 2
ipv6_rules_per_network = 3
networks_count = 5
instance_ref = self._create_instance_ref()
network_info = _create_network_info(networks_count)
ipv4_len = len(self.fw.iptables.ipv4['filter'].rules)
ipv6_len = len(self.fw.iptables.ipv6['filter'].rules)
inst_ipv4, inst_ipv6 = self.fw.instance_rules(instance_ref,
network_info)
self.fw.add_filters_for_instance(instance_ref, network_info)
ipv4 = self.fw.iptables.ipv4['filter'].rules
ipv6 = self.fw.iptables.ipv6['filter'].rules
ipv4_network_rules = len(ipv4) - len(inst_ipv4) - ipv4_len
ipv6_network_rules = len(ipv6) - len(inst_ipv6) - ipv6_len
self.assertEquals(ipv4_network_rules,
ipv4_rules_per_network * networks_count)
self.assertEquals(ipv6_network_rules,
ipv6_rules_per_network * networks_count)
def test_do_refresh_security_group_rules(self):
instance_ref = self._create_instance_ref()
self.mox.StubOutWithMock(self.fw,
'add_filters_for_instance',
use_mock_anything=True)
self.fw.add_filters_for_instance(instance_ref, mox.IgnoreArg())
self.fw.instances[instance_ref['id']] = instance_ref
self.mox.ReplayAll()
self.fw.do_refresh_security_group_rules("fake")
class NWFilterTestCase(test.TestCase): class NWFilterTestCase(test.TestCase):
def setUp(self): def setUp(self):
@@ -745,7 +896,7 @@ class NWFilterTestCase(test.TestCase):
self.fake_libvirt_connection = Mock() self.fake_libvirt_connection = Mock()
self.fw = libvirt_conn.NWFilterFirewall( self.fw = firewall.NWFilterFirewall(
lambda: self.fake_libvirt_connection) lambda: self.fake_libvirt_connection)
def tearDown(self): def tearDown(self):
@@ -810,6 +961,28 @@ class NWFilterTestCase(test.TestCase):
return db.security_group_get_by_name(self.context, 'fake', 'testgroup') return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
def _create_instance(self):
return db.instance_create(self.context,
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '00:A0:C9:14:C8:29',
'instance_type_id': 1})
def _create_instance_type(self, params={}):
"""Create a test instance"""
context = self.context.elevated()
inst = {}
inst['name'] = 'm1.small'
inst['memory_mb'] = '1024'
inst['vcpus'] = '1'
inst['local_gb'] = '20'
inst['flavorid'] = '1'
inst['swap'] = '2048'
inst['rxtx_quota'] = 100
inst['rxtx_cap'] = 200
inst.update(params)
return db.instance_type_create(context, inst)['id']
def test_creates_base_rule_first(self): def test_creates_base_rule_first(self):
# These come pre-defined by libvirt # These come pre-defined by libvirt
self.defined_filters = ['no-mac-spoofing', self.defined_filters = ['no-mac-spoofing',
@@ -838,24 +1011,18 @@ class NWFilterTestCase(test.TestCase):
self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
instance_ref = db.instance_create(self.context, instance_ref = self._create_instance()
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '00:A0:C9:14:C8:29'})
inst_id = instance_ref['id'] inst_id = instance_ref['id']
ip = '10.11.12.13' ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context, 'fake')
'fake') fixed_ip = {'address': ip, 'network_id': network_ref['id']}
fixed_ip = {'address': ip,
'network_id': network_ref['id']}
admin_ctxt = context.get_admin_context() admin_ctxt = context.get_admin_context()
db.fixed_ip_create(admin_ctxt, fixed_ip) db.fixed_ip_create(admin_ctxt, fixed_ip)
db.fixed_ip_update(admin_ctxt, ip, {'allocated': True, db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
'instance_id': instance_ref['id']}) 'instance_id': inst_id})
def _ensure_all_called(): def _ensure_all_called():
instance_filter = 'nova-instance-%s-%s' % (instance_ref['name'], instance_filter = 'nova-instance-%s-%s' % (instance_ref['name'],
@@ -881,3 +1048,11 @@ class NWFilterTestCase(test.TestCase):
_ensure_all_called() _ensure_all_called()
self.teardown_security_group() self.teardown_security_group()
db.instance_destroy(admin_ctxt, instance_ref['id']) db.instance_destroy(admin_ctxt, instance_ref['id'])
def test_create_network_filters(self):
instance_ref = self._create_instance()
network_info = _create_network_info(3)
result = self.fw._create_network_filters(instance_ref,
network_info,
"fake")
self.assertEquals(len(result), 3)

View File

@@ -29,11 +29,12 @@ from nova.utils import parse_mailmap, str_dict_replace
class ProjectTestCase(test.TestCase): class ProjectTestCase(test.TestCase):
def test_authors_up_to_date(self): def test_authors_up_to_date(self):
topdir = os.path.normpath(os.path.dirname(__file__) + '/../../') topdir = os.path.normpath(os.path.dirname(__file__) + '/../../')
missing = set()
contributors = set()
mailmap = parse_mailmap(os.path.join(topdir, '.mailmap'))
authors_file = open(os.path.join(topdir, 'Authors'), 'r').read()
if os.path.exists(os.path.join(topdir, '.bzr')): if os.path.exists(os.path.join(topdir, '.bzr')):
contributors = set()
mailmap = parse_mailmap(os.path.join(topdir, '.mailmap'))
import bzrlib.workingtree import bzrlib.workingtree
tree = bzrlib.workingtree.WorkingTree.open(topdir) tree = bzrlib.workingtree.WorkingTree.open(topdir)
tree.lock_read() tree.lock_read()
@@ -47,23 +48,37 @@ class ProjectTestCase(test.TestCase):
for r in revs: for r in revs:
for author in r.get_apparent_authors(): for author in r.get_apparent_authors():
email = author.split(' ')[-1] email = author.split(' ')[-1]
contributors.add(str_dict_replace(email, mailmap)) contributors.add(str_dict_replace(email,
mailmap))
authors_file = open(os.path.join(topdir, 'Authors'),
'r').read()
missing = set()
for contributor in contributors:
if contributor == 'nova-core':
continue
if not contributor in authors_file:
missing.add(contributor)
self.assertTrue(len(missing) == 0,
'%r not listed in Authors' % missing)
finally: finally:
tree.unlock() tree.unlock()
elif os.path.exists(os.path.join(topdir, '.git')):
import git
repo = git.Repo(topdir)
for commit in repo.head.commit.iter_parents():
email = commit.author.email
if email is None:
email = commit.author.name
if 'nova-core' in email:
continue
if email.split(' ')[-1] == '<>':
email = email.split(' ')[-2]
email = '<' + email + '>'
contributors.add(str_dict_replace(email, mailmap))
else:
return
for contributor in contributors:
if contributor == 'nova-core':
continue
if not contributor in authors_file:
missing.add(contributor)
self.assertTrue(len(missing) == 0,
'%r not listed in Authors' % missing)
class LockTestCase(test.TestCase): class LockTestCase(test.TestCase):
def test_synchronized_wrapped_function_metadata(self): def test_synchronized_wrapped_function_metadata(self):

117
nova/tests/test_notifier.py Normal file
View File

@@ -0,0 +1,117 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import nova
from nova import context
from nova import flags
from nova import rpc
import nova.notifier.api
from nova.notifier.api import notify
from nova.notifier import no_op_notifier
from nova.notifier import rabbit_notifier
from nova import test
import stubout
class NotifierTestCase(test.TestCase):
"""Test case for notifications"""
def setUp(self):
super(NotifierTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting()
def tearDown(self):
self.stubs.UnsetAll()
super(NotifierTestCase, self).tearDown()
def test_send_notification(self):
self.notify_called = False
def mock_notify(cls, *args):
self.notify_called = True
self.stubs.Set(nova.notifier.no_op_notifier, 'notify',
mock_notify)
class Mock(object):
pass
notify('publisher_id', 'event_type',
nova.notifier.api.WARN, dict(a=3))
self.assertEqual(self.notify_called, True)
def test_verify_message_format(self):
"""A test to ensure changing the message format is prohibitively
annoying"""
def message_assert(message):
fields = [('publisher_id', 'publisher_id'),
('event_type', 'event_type'),
('priority', 'WARN'),
('payload', dict(a=3))]
for k, v in fields:
self.assertEqual(message[k], v)
self.assertTrue(len(message['message_id']) > 0)
self.assertTrue(len(message['timestamp']) > 0)
self.stubs.Set(nova.notifier.no_op_notifier, 'notify',
message_assert)
notify('publisher_id', 'event_type',
nova.notifier.api.WARN, dict(a=3))
def test_send_rabbit_notification(self):
self.stubs.Set(nova.flags.FLAGS, 'notification_driver',
'nova.notifier.rabbit_notifier')
self.mock_cast = False
def mock_cast(cls, *args):
self.mock_cast = True
class Mock(object):
pass
self.stubs.Set(nova.rpc, 'cast', mock_cast)
notify('publisher_id', 'event_type',
nova.notifier.api.WARN, dict(a=3))
self.assertEqual(self.mock_cast, True)
def test_invalid_priority(self):
def mock_cast(cls, *args):
pass
class Mock(object):
pass
self.stubs.Set(nova.rpc, 'cast', mock_cast)
self.assertRaises(nova.notifier.api.BadPriorityException,
notify, 'publisher_id',
'event_type', 'not a priority', dict(a=3))
def test_rabbit_priority_queue(self):
self.stubs.Set(nova.flags.FLAGS, 'notification_driver',
'nova.notifier.rabbit_notifier')
self.stubs.Set(nova.flags.FLAGS, 'notification_topic',
'testnotify')
self.test_topic = None
def mock_cast(context, topic, msg):
self.test_topic = topic
self.stubs.Set(nova.rpc, 'cast', mock_cast)
notify('publisher_id',
'event_type', 'DEBUG', dict(a=3))
self.assertEqual(self.test_topic, 'testnotify.debug')

View File

@@ -120,12 +120,11 @@ class SchedulerTestCase(test.TestCase):
dest = 'dummydest' dest = 'dummydest'
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
try: self.assertRaises(exception.NotFound, scheduler.show_host_resources,
scheduler.show_host_resources(ctxt, dest) ctxt, dest)
except exception.NotFound, e: #TODO(bcwaldon): reimplement this functionality
c1 = (e.message.find(_("does not exist or is not a " #c1 = (e.message.find(_("does not exist or is not a "
"compute node.")) >= 0) # "compute node.")) >= 0)
self.assertTrue(c1)
def _dic_is_equal(self, dic1, dic2, keys=None): def _dic_is_equal(self, dic1, dic2, keys=None):
"""Compares 2 dictionary contents(Helper method)""" """Compares 2 dictionary contents(Helper method)"""
@@ -698,14 +697,10 @@ class SimpleDriverTestCase(test.TestCase):
'topic': 'volume', 'report_count': 0} 'topic': 'volume', 'report_count': 0}
s_ref = db.service_create(self.context, dic) s_ref = db.service_create(self.context, dic)
try: self.assertRaises(exception.VolumeServiceUnavailable,
self.scheduler.driver.schedule_live_migration(self.context, self.scheduler.driver.schedule_live_migration,
instance_id, self.context, instance_id, i_ref['host'])
i_ref['host'])
except exception.Invalid, e:
c = (e.message.find('volume node is not alive') >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
db.volume_destroy(self.context, v_ref['id']) db.volume_destroy(self.context, v_ref['id'])
@@ -718,13 +713,10 @@ class SimpleDriverTestCase(test.TestCase):
s_ref = self._create_compute_service(created_at=t, updated_at=t, s_ref = self._create_compute_service(created_at=t, updated_at=t,
host=i_ref['host']) host=i_ref['host'])
try: self.assertRaises(exception.ComputeServiceUnavailable,
self.scheduler.driver._live_migration_src_check(self.context, self.scheduler.driver._live_migration_src_check,
i_ref) self.context, i_ref)
except exception.Invalid, e:
c = (e.message.find('is not alive') >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -737,7 +729,7 @@ class SimpleDriverTestCase(test.TestCase):
ret = self.scheduler.driver._live_migration_src_check(self.context, ret = self.scheduler.driver._live_migration_src_check(self.context,
i_ref) i_ref)
self.assertTrue(ret == None) self.assertTrue(ret is None)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -749,14 +741,10 @@ class SimpleDriverTestCase(test.TestCase):
s_ref = self._create_compute_service(created_at=t, updated_at=t, s_ref = self._create_compute_service(created_at=t, updated_at=t,
host=i_ref['host']) host=i_ref['host'])
try: self.assertRaises(exception.ComputeServiceUnavailable,
self.scheduler.driver._live_migration_dest_check(self.context, self.scheduler.driver._live_migration_dest_check,
i_ref, self.context, i_ref, i_ref['host'])
i_ref['host'])
except exception.Invalid, e:
c = (e.message.find('is not alive') >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -766,14 +754,10 @@ class SimpleDriverTestCase(test.TestCase):
i_ref = db.instance_get(self.context, instance_id) i_ref = db.instance_get(self.context, instance_id)
s_ref = self._create_compute_service(host=i_ref['host']) s_ref = self._create_compute_service(host=i_ref['host'])
try: self.assertRaises(exception.UnableToMigrateToSelf,
self.scheduler.driver._live_migration_dest_check(self.context, self.scheduler.driver._live_migration_dest_check,
i_ref, self.context, i_ref, i_ref['host'])
i_ref['host'])
except exception.Invalid, e:
c = (e.message.find('choose other host') >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -784,14 +768,10 @@ class SimpleDriverTestCase(test.TestCase):
s_ref = self._create_compute_service(host='somewhere', s_ref = self._create_compute_service(host='somewhere',
memory_mb_used=12) memory_mb_used=12)
try: self.assertRaises(exception.MigrationError,
self.scheduler.driver._live_migration_dest_check(self.context, self.scheduler.driver._live_migration_dest_check,
i_ref, self.context, i_ref, 'somewhere')
'somewhere')
except exception.NotEmpty, e:
c = (e.message.find('Unable to migrate') >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -805,7 +785,7 @@ class SimpleDriverTestCase(test.TestCase):
ret = self.scheduler.driver._live_migration_dest_check(self.context, ret = self.scheduler.driver._live_migration_dest_check(self.context,
i_ref, i_ref,
'somewhere') 'somewhere')
self.assertTrue(ret == None) self.assertTrue(ret is None)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -837,14 +817,10 @@ class SimpleDriverTestCase(test.TestCase):
"args": {'filename': fpath}}) "args": {'filename': fpath}})
self.mox.ReplayAll() self.mox.ReplayAll()
try: self.assertRaises(exception.SourceHostUnavailable,
self.scheduler.driver._live_migration_common_check(self.context, self.scheduler.driver._live_migration_common_check,
i_ref, self.context, i_ref, dest)
dest)
except exception.Invalid, e:
c = (e.message.find('does not exist') >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -865,14 +841,10 @@ class SimpleDriverTestCase(test.TestCase):
driver.mounted_on_same_shared_storage(mox.IgnoreArg(), i_ref, dest) driver.mounted_on_same_shared_storage(mox.IgnoreArg(), i_ref, dest)
self.mox.ReplayAll() self.mox.ReplayAll()
try: self.assertRaises(exception.InvalidHypervisorType,
self.scheduler.driver._live_migration_common_check(self.context, self.scheduler.driver._live_migration_common_check,
i_ref, self.context, i_ref, dest)
dest)
except exception.Invalid, e:
c = (e.message.find(_('Different hypervisor type')) >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
db.service_destroy(self.context, s_ref2['id']) db.service_destroy(self.context, s_ref2['id'])
@@ -895,14 +867,10 @@ class SimpleDriverTestCase(test.TestCase):
driver.mounted_on_same_shared_storage(mox.IgnoreArg(), i_ref, dest) driver.mounted_on_same_shared_storage(mox.IgnoreArg(), i_ref, dest)
self.mox.ReplayAll() self.mox.ReplayAll()
try: self.assertRaises(exception.DestinationHypervisorTooOld,
self.scheduler.driver._live_migration_common_check(self.context, self.scheduler.driver._live_migration_common_check,
i_ref, self.context, i_ref, dest)
dest)
except exception.Invalid, e:
c = (e.message.find(_('Older hypervisor version')) >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
db.service_destroy(self.context, s_ref2['id']) db.service_destroy(self.context, s_ref2['id'])
@@ -944,7 +912,8 @@ class SimpleDriverTestCase(test.TestCase):
class FakeZone(object): class FakeZone(object):
def __init__(self, api_url, username, password): def __init__(self, id, api_url, username, password):
self.id = id
self.api_url = api_url self.api_url = api_url
self.username = username self.username = username
self.password = password self.password = password
@@ -952,7 +921,7 @@ class FakeZone(object):
def zone_get_all(context): def zone_get_all(context):
return [ return [
FakeZone('http://example.com', 'bob', 'xxx'), FakeZone(1, 'http://example.com', 'bob', 'xxx'),
] ]
@@ -968,7 +937,7 @@ class FakeRerouteCompute(api.reroute_compute):
def go_boom(self, context, instance): def go_boom(self, context, instance):
raise exception.InstanceNotFound("boom message", instance) raise exception.InstanceNotFound(instance_id=instance)
def found_instance(self, context, instance): def found_instance(self, context, instance):
@@ -1017,11 +986,8 @@ class ZoneRedirectTest(test.TestCase):
def test_routing_flags(self): def test_routing_flags(self):
FLAGS.enable_zone_routing = False FLAGS.enable_zone_routing = False
decorator = FakeRerouteCompute("foo") decorator = FakeRerouteCompute("foo")
try: self.assertRaises(exception.InstanceNotFound, decorator(go_boom),
result = decorator(go_boom)(None, None, 1) None, None, 1)
self.assertFail(_("Should have thrown exception."))
except exception.InstanceNotFound, e:
self.assertEquals(e.message, 'boom message')
def test_get_collection_context_and_id(self): def test_get_collection_context_and_id(self):
decorator = api.reroute_compute("foo") decorator = api.reroute_compute("foo")
@@ -1072,7 +1038,7 @@ class FakeNovaClient(object):
class DynamicNovaClientTest(test.TestCase): class DynamicNovaClientTest(test.TestCase):
def test_issue_novaclient_command_found(self): def test_issue_novaclient_command_found(self):
zone = FakeZone('http://example.com', 'bob', 'xxx') zone = FakeZone(1, 'http://example.com', 'bob', 'xxx')
self.assertEquals(api._issue_novaclient_command( self.assertEquals(api._issue_novaclient_command(
FakeNovaClient(FakeServerCollection()), FakeNovaClient(FakeServerCollection()),
zone, "servers", "get", 100).a, 10) zone, "servers", "get", 100).a, 10)
@@ -1086,7 +1052,7 @@ class DynamicNovaClientTest(test.TestCase):
zone, "servers", "pause", 100), None) zone, "servers", "pause", 100), None)
def test_issue_novaclient_command_not_found(self): def test_issue_novaclient_command_not_found(self):
zone = FakeZone('http://example.com', 'bob', 'xxx') zone = FakeZone(1, 'http://example.com', 'bob', 'xxx')
self.assertEquals(api._issue_novaclient_command( self.assertEquals(api._issue_novaclient_command(
FakeNovaClient(FakeEmptyServerCollection()), FakeNovaClient(FakeEmptyServerCollection()),
zone, "servers", "get", 100), None) zone, "servers", "get", 100), None)
@@ -1098,3 +1064,55 @@ class DynamicNovaClientTest(test.TestCase):
self.assertEquals(api._issue_novaclient_command( self.assertEquals(api._issue_novaclient_command(
FakeNovaClient(FakeEmptyServerCollection()), FakeNovaClient(FakeEmptyServerCollection()),
zone, "servers", "any", "name"), None) zone, "servers", "any", "name"), None)
class FakeZonesProxy(object):
def do_something(*args, **kwargs):
return 42
def raises_exception(*args, **kwargs):
raise Exception('testing')
class FakeNovaClientOpenStack(object):
def __init__(self, *args, **kwargs):
self.zones = FakeZonesProxy()
def authenticate(self):
pass
class CallZoneMethodTest(test.TestCase):
def setUp(self):
super(CallZoneMethodTest, self).setUp()
self.stubs = stubout.StubOutForTesting()
self.stubs.Set(db, 'zone_get_all', zone_get_all)
self.stubs.Set(novaclient, 'OpenStack', FakeNovaClientOpenStack)
def tearDown(self):
self.stubs.UnsetAll()
super(CallZoneMethodTest, self).tearDown()
def test_call_zone_method(self):
context = {}
method = 'do_something'
results = api.call_zone_method(context, method)
expected = [(1, 42)]
self.assertEqual(expected, results)
def test_call_zone_method_not_present(self):
context = {}
method = 'not_present'
self.assertRaises(AttributeError, api.call_zone_method,
context, method)
def test_call_zone_method_generates_exception(self):
context = {}
method = 'raises_exception'
results = api.call_zone_method(context, method)
# FIXME(sirp): for now the _error_trap code is catching errors and
# converting them to a ("ERROR", "string") tuples. The code (and this
# test) should eventually handle real exceptions.
expected = [(1, ('ERROR', 'testing'))]
self.assertEqual(expected, results)

View File

@@ -142,7 +142,7 @@ class VolumeTestCase(test.TestCase):
self.assertEqual(vol['status'], "available") self.assertEqual(vol['status'], "available")
self.volume.delete_volume(self.context, volume_id) self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.Error, self.assertRaises(exception.VolumeNotFound,
db.volume_get, db.volume_get,
self.context, self.context,
volume_id) volume_id)

View File

@@ -16,7 +16,9 @@
"""Test suite for XenAPI.""" """Test suite for XenAPI."""
import eventlet
import functools import functools
import json
import os import os
import re import re
import stubout import stubout
@@ -197,6 +199,28 @@ class XenAPIVMTestCase(test.TestCase):
self.context = context.RequestContext('fake', 'fake', False) self.context = context.RequestContext('fake', 'fake', False)
self.conn = xenapi_conn.get_connection(False) self.conn = xenapi_conn.get_connection(False)
def test_parallel_builds(self):
stubs.stubout_loopingcall_delay(self.stubs)
def _do_build(id, proj, user, *args):
values = {
'id': id,
'project_id': proj,
'user_id': user,
'image_id': 1,
'kernel_id': 2,
'ramdisk_id': 3,
'instance_type_id': '3', # m1.large
'mac_address': 'aa:bb:cc:dd:ee:ff',
'os_type': 'linux'}
instance = db.instance_create(self.context, values)
self.conn.spawn(instance)
gt1 = eventlet.spawn(_do_build, 1, self.project.id, self.user.id)
gt2 = eventlet.spawn(_do_build, 2, self.project.id, self.user.id)
gt1.wait()
gt2.wait()
def test_list_instances_0(self): def test_list_instances_0(self):
instances = self.conn.list_instances() instances = self.conn.list_instances()
self.assertEquals(instances, []) self.assertEquals(instances, [])
@@ -665,3 +689,52 @@ class XenAPIDetermineDiskImageTestCase(test.TestCase):
self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_VHD self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_VHD
self.fake_instance.kernel_id = None self.fake_instance.kernel_id = None
self.assert_disk_type(vm_utils.ImageType.DISK_VHD) self.assert_disk_type(vm_utils.ImageType.DISK_VHD)
class FakeXenApi(object):
"""Fake XenApi for testing HostState."""
class FakeSR(object):
def get_record(self, ref):
return {'virtual_allocation': 10000,
'physical_utilisation': 20000}
SR = FakeSR()
class FakeSession(object):
"""Fake Session class for HostState testing."""
def async_call_plugin(self, *args):
return None
def wait_for_task(self, *args):
vm = {'total': 10,
'overhead': 20,
'free': 30,
'free-computed': 40}
return json.dumps({'host_memory': vm})
def get_xenapi(self):
return FakeXenApi()
class HostStateTestCase(test.TestCase):
"""Tests HostState, which holds metrics from XenServer that get
reported back to the Schedulers."""
def _fake_safe_find_sr(self, session):
"""None SR ref since we're ignoring it in FakeSR."""
return None
def test_host_state(self):
self.stubs = stubout.StubOutForTesting()
self.stubs.Set(vm_utils, 'safe_find_sr', self._fake_safe_find_sr)
host_state = xenapi_conn.HostState(FakeSession())
stats = host_state._stats
self.assertEquals(stats['disk_total'], 10000)
self.assertEquals(stats['disk_used'], 20000)
self.assertEquals(stats['host_memory_total'], 10)
self.assertEquals(stats['host_memory_overhead'], 20)
self.assertEquals(stats['host_memory_free'], 30)
self.assertEquals(stats['host_memory_free_computed'], 40)

View File

@@ -0,0 +1,119 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Zone Aware Scheduler.
"""
from nova import test
from nova.scheduler import driver
from nova.scheduler import zone_aware_scheduler
from nova.scheduler import zone_manager
class FakeZoneAwareScheduler(zone_aware_scheduler.ZoneAwareScheduler):
def filter_hosts(self, num, specs):
# NOTE(sirp): this is returning [(hostname, services)]
return self.zone_manager.service_states.items()
def weigh_hosts(self, num, specs, hosts):
fake_weight = 99
weighted = []
for hostname, caps in hosts:
weighted.append(dict(weight=fake_weight, name=hostname))
return weighted
class FakeZoneManager(zone_manager.ZoneManager):
def __init__(self):
self.service_states = {
'host1': {
'compute': {'ram': 1000}
},
'host2': {
'compute': {'ram': 2000}
},
'host3': {
'compute': {'ram': 3000}
}
}
class FakeEmptyZoneManager(zone_manager.ZoneManager):
def __init__(self):
self.service_states = {}
def fake_empty_call_zone_method(context, method, specs):
return []
def fake_call_zone_method(context, method, specs):
return [
('zone1', [
dict(weight=1, blob='AAAAAAA'),
dict(weight=111, blob='BBBBBBB'),
dict(weight=112, blob='CCCCCCC'),
dict(weight=113, blob='DDDDDDD'),
]),
('zone2', [
dict(weight=120, blob='EEEEEEE'),
dict(weight=2, blob='FFFFFFF'),
dict(weight=122, blob='GGGGGGG'),
dict(weight=123, blob='HHHHHHH'),
]),
('zone3', [
dict(weight=130, blob='IIIIIII'),
dict(weight=131, blob='JJJJJJJ'),
dict(weight=132, blob='KKKKKKK'),
dict(weight=3, blob='LLLLLLL'),
]),
]
class ZoneAwareSchedulerTestCase(test.TestCase):
"""Test case for Zone Aware Scheduler."""
def test_zone_aware_scheduler(self):
"""
Create a nested set of FakeZones, ensure that a select call returns the
appropriate build plan.
"""
sched = FakeZoneAwareScheduler()
self.stubs.Set(sched, '_call_zone_method', fake_call_zone_method)
zm = FakeZoneManager()
sched.set_zone_manager(zm)
fake_context = {}
build_plan = sched.select(fake_context, {})
self.assertEqual(15, len(build_plan))
hostnames = [plan_item['name']
for plan_item in build_plan if 'name' in plan_item]
self.assertEqual(3, len(hostnames))
def test_empty_zone_aware_scheduler(self):
"""
Ensure empty hosts & child_zones result in NoValidHosts exception.
"""
sched = FakeZoneAwareScheduler()
self.stubs.Set(sched, '_call_zone_method', fake_empty_call_zone_method)
zm = FakeEmptyZoneManager()
sched.set_zone_manager(zm)
fake_context = {}
self.assertRaises(driver.NoValidHost, sched.schedule, fake_context, {})

View File

@@ -78,38 +78,32 @@ class ZoneManagerTestCase(test.TestCase):
def test_service_capabilities(self): def test_service_capabilities(self):
zm = zone_manager.ZoneManager() zm = zone_manager.ZoneManager()
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, {}) self.assertEquals(caps, {})
zm.update_service_capabilities("svc1", "host1", dict(a=1, b=2)) zm.update_service_capabilities("svc1", "host1", dict(a=1, b=2))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(1, 1), svc1_b=(2, 2))) self.assertEquals(caps, dict(svc1_a=(1, 1), svc1_b=(2, 2)))
zm.update_service_capabilities("svc1", "host1", dict(a=2, b=3)) zm.update_service_capabilities("svc1", "host1", dict(a=2, b=3))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 2), svc1_b=(3, 3))) self.assertEquals(caps, dict(svc1_a=(2, 2), svc1_b=(3, 3)))
zm.update_service_capabilities("svc1", "host2", dict(a=20, b=30)) zm.update_service_capabilities("svc1", "host2", dict(a=20, b=30))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30))) self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30)))
zm.update_service_capabilities("svc10", "host1", dict(a=99, b=99)) zm.update_service_capabilities("svc10", "host1", dict(a=99, b=99))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30), self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30),
svc10_a=(99, 99), svc10_b=(99, 99))) svc10_a=(99, 99), svc10_b=(99, 99)))
zm.update_service_capabilities("svc1", "host3", dict(c=5)) zm.update_service_capabilities("svc1", "host3", dict(c=5))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30), self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30),
svc1_c=(5, 5), svc10_a=(99, 99), svc1_c=(5, 5), svc10_a=(99, 99),
svc10_b=(99, 99))) svc10_b=(99, 99)))
caps = zm.get_zone_capabilities(self, 'svc1')
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30),
svc1_c=(5, 5)))
caps = zm.get_zone_capabilities(self, 'svc10')
self.assertEquals(caps, dict(svc10_a=(99, 99), svc10_b=(99, 99)))
def test_refresh_from_db_replace_existing(self): def test_refresh_from_db_replace_existing(self):
zm = zone_manager.ZoneManager() zm = zone_manager.ZoneManager()
zone_state = zone_manager.ZoneState() zone_state = zone_manager.ZoneState()

View File

@@ -16,6 +16,7 @@
"""Stubouts, mocks and fixtures for the test suite""" """Stubouts, mocks and fixtures for the test suite"""
import eventlet
from nova.virt import xenapi_conn from nova.virt import xenapi_conn
from nova.virt.xenapi import fake from nova.virt.xenapi import fake
from nova.virt.xenapi import volume_utils from nova.virt.xenapi import volume_utils
@@ -28,29 +29,6 @@ def stubout_instance_snapshot(stubs):
@classmethod @classmethod
def fake_fetch_image(cls, session, instance_id, image, user, project, def fake_fetch_image(cls, session, instance_id, image, user, project,
type): type):
# Stubout wait_for_task
def fake_wait_for_task(self, task, id):
class FakeEvent:
def send(self, value):
self.rv = value
def wait(self):
return self.rv
done = FakeEvent()
self._poll_task(id, task, done)
rv = done.wait()
return rv
def fake_loop(self):
pass
stubs.Set(xenapi_conn.XenAPISession, 'wait_for_task',
fake_wait_for_task)
stubs.Set(xenapi_conn.XenAPISession, '_stop_loop', fake_loop)
from nova.virt.xenapi.fake import create_vdi from nova.virt.xenapi.fake import create_vdi
name_label = "instance-%s" % instance_id name_label = "instance-%s" % instance_id
#TODO: create fake SR record #TODO: create fake SR record
@@ -63,11 +41,6 @@ def stubout_instance_snapshot(stubs):
stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image) stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
def fake_parse_xmlrpc_value(val):
return val
stubs.Set(xenapi_conn, '_parse_xmlrpc_value', fake_parse_xmlrpc_value)
def fake_wait_for_vhd_coalesce(session, instance_id, sr_ref, vdi_ref, def fake_wait_for_vhd_coalesce(session, instance_id, sr_ref, vdi_ref,
original_parent_uuid): original_parent_uuid):
from nova.virt.xenapi.fake import create_vdi from nova.virt.xenapi.fake import create_vdi
@@ -144,6 +117,16 @@ def stubout_loopingcall_start(stubs):
stubs.Set(utils.LoopingCall, 'start', fake_start) stubs.Set(utils.LoopingCall, 'start', fake_start)
def stubout_loopingcall_delay(stubs):
def fake_start(self, interval, now=True):
self._running = True
eventlet.sleep(1)
self.f(*self.args, **self.kw)
# This would fail before parallel xenapi calls were fixed
assert self._running == False
stubs.Set(utils.LoopingCall, 'start', fake_start)
class FakeSessionForVMTests(fake.SessionBase): class FakeSessionForVMTests(fake.SessionBase):
""" Stubs out a XenAPISession for VM tests """ """ Stubs out a XenAPISession for VM tests """
def __init__(self, uri): def __init__(self, uri):