merged from trunk

This commit is contained in:
John Tran
2011-05-12 12:23:31 -07:00
21 changed files with 958 additions and 194 deletions

View File

@@ -1,3 +1,4 @@
Alex Meade <alex.meade@rackspace.com>
Andy Smith <code@term.ie> Andy Smith <code@term.ie>
Andy Southgate <andy.southgate@citrix.com> Andy Southgate <andy.southgate@citrix.com>
Anne Gentle <anne@openstack.org> Anne Gentle <anne@openstack.org>
@@ -43,11 +44,14 @@ Josh Kearney <josh@jk0.org>
Josh Kleinpeter <josh@kleinpeter.org> Josh Kleinpeter <josh@kleinpeter.org>
Joshua McKenty <jmckenty@gmail.com> Joshua McKenty <jmckenty@gmail.com>
Justin Santa Barbara <justin@fathomdb.com> Justin Santa Barbara <justin@fathomdb.com>
Justin Shepherd <jshepher@rackspace.com>
Kei Masumoto <masumotok@nttdata.co.jp> Kei Masumoto <masumotok@nttdata.co.jp>
Ken Pepple <ken.pepple@gmail.com> Ken Pepple <ken.pepple@gmail.com>
Kevin Bringard <kbringard@attinteractive.com>
Kevin L. Mitchell <kevin.mitchell@rackspace.com> Kevin L. Mitchell <kevin.mitchell@rackspace.com>
Koji Iida <iida.koji@lab.ntt.co.jp> Koji Iida <iida.koji@lab.ntt.co.jp>
Lorin Hochstein <lorin@isi.edu> Lorin Hochstein <lorin@isi.edu>
Lvov Maxim <usrleon@gmail.com>
Mark Washenberger <mark.washenberger@rackspace.com> Mark Washenberger <mark.washenberger@rackspace.com>
Masanori Itoh <itoumsn@nttdata.co.jp> Masanori Itoh <itoumsn@nttdata.co.jp>
Matt Dietz <matt.dietz@rackspace.com> Matt Dietz <matt.dietz@rackspace.com>
@@ -76,6 +80,8 @@ Trey Morris <trey.morris@rackspace.com>
Tushar Patil <tushar.vitthal.patil@gmail.com> Tushar Patil <tushar.vitthal.patil@gmail.com>
Vasiliy Shlykov <vash@vasiliyshlykov.org> Vasiliy Shlykov <vash@vasiliyshlykov.org>
Vishvananda Ishaya <vishvananda@gmail.com> Vishvananda Ishaya <vishvananda@gmail.com>
William Wolf <will.wolf@rackspace.com>
Yoshiaki Tamura <yoshi@midokura.jp> Yoshiaki Tamura <yoshi@midokura.jp>
Youcef Laribi <Youcef.Laribi@eu.citrix.com> Youcef Laribi <Youcef.Laribi@eu.citrix.com>
Yuriy Taraday <yorik.sar@gmail.com>
Zhixue Wu <Zhixue.Wu@citrix.com> Zhixue Wu <Zhixue.Wu@citrix.com>

View File

@@ -82,6 +82,7 @@ from nova import log as logging
from nova import quota from nova import quota
from nova import rpc from nova import rpc
from nova import utils from nova import utils
from nova import version
from nova.api.ec2 import ec2utils from nova.api.ec2 import ec2utils
from nova.auth import manager from nova.auth import manager
from nova.cloudpipe import pipelib from nova.cloudpipe import pipelib
@@ -150,7 +151,7 @@ class VpnCommands(object):
state = 'up' state = 'up'
print address, print address,
print vpn['host'], print vpn['host'],
print vpn['ec2_id'], print ec2utils.id_to_ec2_id(vpn['id']),
print vpn['state_description'], print vpn['state_description'],
print state print state
else: else:
@@ -385,10 +386,10 @@ class ProjectCommands(object):
with open(filename, 'w') as f: with open(filename, 'w') as f:
f.write(rc) f.write(rc)
def list(self): def list(self, username=None):
"""Lists all projects """Lists all projects
arguments: <none>""" arguments: [username]"""
for project in self.manager.get_projects(): for project in self.manager.get_projects(username):
print project.name print project.name
def quota(self, project_id, key=None, value=None): def quota(self, project_id, key=None, value=None):
@@ -758,6 +759,17 @@ class DbCommands(object):
print migration.db_version() print migration.db_version()
class VersionCommands(object):
"""Class for exposing the codebase version."""
def __init__(self):
pass
def list(self):
print _("%s (%s)") %\
(version.version_string(), version.version_string_with_vcs())
class VolumeCommands(object): class VolumeCommands(object):
"""Methods for dealing with a cloud in an odd state""" """Methods for dealing with a cloud in an odd state"""
@@ -960,7 +972,7 @@ class ImageCommands(object):
try: try:
internal_id = ec2utils.ec2_id_to_id(old_image_id) internal_id = ec2utils.ec2_id_to_id(old_image_id)
image = self.image_service.show(context, internal_id) image = self.image_service.show(context, internal_id)
except exception.NotFound: except (exception.InvalidEc2Id, exception.ImageNotFound):
image = self.image_service.show_by_name(context, old_image_id) image = self.image_service.show_by_name(context, old_image_id)
return image['id'] return image['id']
@@ -1049,7 +1061,8 @@ CATEGORIES = [
('volume', VolumeCommands), ('volume', VolumeCommands),
('instance_type', InstanceTypeCommands), ('instance_type', InstanceTypeCommands),
('image', ImageCommands), ('image', ImageCommands),
('flavor', InstanceTypeCommands)] ('flavor', InstanceTypeCommands),
('version', VersionCommands)]
def lazy_match(name, key_value_tuples): def lazy_match(name, key_value_tuples):
@@ -1091,6 +1104,8 @@ def main():
script_name = argv.pop(0) script_name = argv.pop(0)
if len(argv) < 1: if len(argv) < 1:
print _("\nOpenStack Nova version: %s (%s)\n") %\
(version.version_string(), version.version_string_with_vcs())
print script_name + " category action [<args>]" print script_name + " category action [<args>]"
print _("Available categories:") print _("Available categories:")
for k, _v in CATEGORIES: for k, _v in CATEGORIES:

View File

@@ -81,7 +81,7 @@ class DbDriver(object):
user_ref = db.user_create(context.get_admin_context(), values) user_ref = db.user_create(context.get_admin_context(), values)
return self._db_user_to_auth_user(user_ref) return self._db_user_to_auth_user(user_ref)
except exception.Duplicate, e: except exception.Duplicate, e:
raise exception.Duplicate(_('User %s already exists') % name) raise exception.UserExists(user=name)
def _db_user_to_auth_user(self, user_ref): def _db_user_to_auth_user(self, user_ref):
return {'id': user_ref['id'], return {'id': user_ref['id'],
@@ -103,9 +103,7 @@ class DbDriver(object):
"""Create a project""" """Create a project"""
manager = db.user_get(context.get_admin_context(), manager_uid) manager = db.user_get(context.get_admin_context(), manager_uid)
if not manager: if not manager:
raise exception.NotFound(_("Project can't be created because " raise exception.UserNotFound(user_id=manager_uid)
"manager %s doesn't exist")
% manager_uid)
# description is a required attribute # description is a required attribute
if description is None: if description is None:
@@ -119,9 +117,7 @@ class DbDriver(object):
for member_uid in member_uids: for member_uid in member_uids:
member = db.user_get(context.get_admin_context(), member_uid) member = db.user_get(context.get_admin_context(), member_uid)
if not member: if not member:
raise exception.NotFound(_("Project can't be created " raise exception.UserNotFound(user_id=member_uid)
"because user %s doesn't exist")
% member_uid)
members.add(member) members.add(member)
values = {'id': name, values = {'id': name,
@@ -132,8 +128,7 @@ class DbDriver(object):
try: try:
project = db.project_create(context.get_admin_context(), values) project = db.project_create(context.get_admin_context(), values)
except exception.Duplicate: except exception.Duplicate:
raise exception.Duplicate(_("Project can't be created because " raise exception.ProjectExists(project=name)
"project %s already exists") % name)
for member in members: for member in members:
db.project_add_member(context.get_admin_context(), db.project_add_member(context.get_admin_context(),
@@ -154,9 +149,7 @@ class DbDriver(object):
if manager_uid: if manager_uid:
manager = db.user_get(context.get_admin_context(), manager_uid) manager = db.user_get(context.get_admin_context(), manager_uid)
if not manager: if not manager:
raise exception.NotFound(_("Project can't be modified because " raise exception.UserNotFound(user_id=manager_uid)
"manager %s doesn't exist") %
manager_uid)
values['project_manager'] = manager['id'] values['project_manager'] = manager['id']
if description: if description:
values['description'] = description values['description'] = description
@@ -244,8 +237,8 @@ class DbDriver(object):
def _validate_user_and_project(self, user_id, project_id): def _validate_user_and_project(self, user_id, project_id):
user = db.user_get(context.get_admin_context(), user_id) user = db.user_get(context.get_admin_context(), user_id)
if not user: if not user:
raise exception.NotFound(_('User "%s" not found') % user_id) raise exception.UserNotFound(user_id=user_id)
project = db.project_get(context.get_admin_context(), project_id) project = db.project_get(context.get_admin_context(), project_id)
if not project: if not project:
raise exception.NotFound(_('Project "%s" not found') % project_id) raise exception.ProjectNotFound(project_id=project_id)
return user, project return user, project

View File

@@ -171,7 +171,7 @@ class LdapDriver(object):
def create_user(self, name, access_key, secret_key, is_admin): def create_user(self, name, access_key, secret_key, is_admin):
"""Create a user""" """Create a user"""
if self.__user_exists(name): if self.__user_exists(name):
raise exception.Duplicate(_("LDAP user %s already exists") % name) raise exception.LDAPUserExists(user=name)
if FLAGS.ldap_user_modify_only: if FLAGS.ldap_user_modify_only:
if self.__ldap_user_exists(name): if self.__ldap_user_exists(name):
# Retrieve user by name # Retrieve user by name
@@ -202,8 +202,7 @@ class LdapDriver(object):
self.conn.modify_s(self.__uid_to_dn(name), attr) self.conn.modify_s(self.__uid_to_dn(name), attr)
return self.get_user(name) return self.get_user(name)
else: else:
raise exception.NotFound(_("LDAP object for %s doesn't exist") raise exception.LDAPUserNotFound(user_id=name)
% name)
else: else:
attr = [ attr = [
('objectclass', ['person', ('objectclass', ['person',
@@ -226,12 +225,9 @@ class LdapDriver(object):
description=None, member_uids=None): description=None, member_uids=None):
"""Create a project""" """Create a project"""
if self.__project_exists(name): if self.__project_exists(name):
raise exception.Duplicate(_("Project can't be created because " raise exception.ProjectExists(project=name)
"project %s already exists") % name)
if not self.__user_exists(manager_uid): if not self.__user_exists(manager_uid):
raise exception.NotFound(_("Project can't be created because " raise exception.LDAPUserNotFound(user_id=manager_uid)
"manager %s doesn't exist")
% manager_uid)
manager_dn = self.__uid_to_dn(manager_uid) manager_dn = self.__uid_to_dn(manager_uid)
# description is a required attribute # description is a required attribute
if description is None: if description is None:
@@ -240,9 +236,7 @@ class LdapDriver(object):
if member_uids is not None: if member_uids is not None:
for member_uid in member_uids: for member_uid in member_uids:
if not self.__user_exists(member_uid): if not self.__user_exists(member_uid):
raise exception.NotFound(_("Project can't be created " raise exception.LDAPUserNotFound(user_id=member_uid)
"because user %s doesn't exist")
% member_uid)
members.append(self.__uid_to_dn(member_uid)) members.append(self.__uid_to_dn(member_uid))
# always add the manager as a member because members is required # always add the manager as a member because members is required
if not manager_dn in members: if not manager_dn in members:
@@ -265,9 +259,7 @@ class LdapDriver(object):
attr = [] attr = []
if manager_uid: if manager_uid:
if not self.__user_exists(manager_uid): if not self.__user_exists(manager_uid):
raise exception.NotFound(_("Project can't be modified because " raise exception.LDAPUserNotFound(user_id=manager_uid)
"manager %s doesn't exist")
% manager_uid)
manager_dn = self.__uid_to_dn(manager_uid) manager_dn = self.__uid_to_dn(manager_uid)
attr.append((self.ldap.MOD_REPLACE, LdapDriver.project_attribute, attr.append((self.ldap.MOD_REPLACE, LdapDriver.project_attribute,
manager_dn)) manager_dn))
@@ -347,7 +339,7 @@ class LdapDriver(object):
def delete_user(self, uid): def delete_user(self, uid):
"""Delete a user""" """Delete a user"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s doesn't exist") % uid) raise exception.LDAPUserNotFound(user_id=uid)
self.__remove_from_all(uid) self.__remove_from_all(uid)
if FLAGS.ldap_user_modify_only: if FLAGS.ldap_user_modify_only:
# Delete attributes # Delete attributes
@@ -471,15 +463,12 @@ class LdapDriver(object):
description, member_uids=None): description, member_uids=None):
"""Create a group""" """Create a group"""
if self.__group_exists(group_dn): if self.__group_exists(group_dn):
raise exception.Duplicate(_("Group can't be created because " raise exception.LDAPGroupExists(group=name)
"group %s already exists") % name)
members = [] members = []
if member_uids is not None: if member_uids is not None:
for member_uid in member_uids: for member_uid in member_uids:
if not self.__user_exists(member_uid): if not self.__user_exists(member_uid):
raise exception.NotFound(_("Group can't be created " raise exception.LDAPUserNotFound(user_id=member_uid)
"because user %s doesn't exist")
% member_uid)
members.append(self.__uid_to_dn(member_uid)) members.append(self.__uid_to_dn(member_uid))
dn = self.__uid_to_dn(uid) dn = self.__uid_to_dn(uid)
if not dn in members: if not dn in members:
@@ -494,8 +483,7 @@ class LdapDriver(object):
def __is_in_group(self, uid, group_dn): def __is_in_group(self, uid, group_dn):
"""Check if user is in group""" """Check if user is in group"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be searched in group " raise exception.LDAPUserNotFound(user_id=uid)
"because the user doesn't exist") % uid)
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
return False return False
res = self.__find_object(group_dn, res = self.__find_object(group_dn,
@@ -506,29 +494,23 @@ class LdapDriver(object):
def __add_to_group(self, uid, group_dn): def __add_to_group(self, uid, group_dn):
"""Add user to group""" """Add user to group"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be added to the group " raise exception.LDAPUserNotFound(user_id=uid)
"because the user doesn't exist") % uid)
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
raise exception.NotFound(_("The group at dn %s doesn't exist") % raise exception.LDAPGroupNotFound(group_id=group_dn)
group_dn)
if self.__is_in_group(uid, group_dn): if self.__is_in_group(uid, group_dn):
raise exception.Duplicate(_("User %(uid)s is already a member of " raise exception.LDAPMembershipExists(uid=uid, group_dn=group_dn)
"the group %(group_dn)s") % locals())
attr = [(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))] attr = [(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))]
self.conn.modify_s(group_dn, attr) self.conn.modify_s(group_dn, attr)
def __remove_from_group(self, uid, group_dn): def __remove_from_group(self, uid, group_dn):
"""Remove user from group""" """Remove user from group"""
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
raise exception.NotFound(_("The group at dn %s doesn't exist") raise exception.LDAPGroupNotFound(group_id=group_dn)
% group_dn)
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be removed from the " raise exception.LDAPUserNotFound(user_id=uid)
"group because the user doesn't exist")
% uid)
if not self.__is_in_group(uid, group_dn): if not self.__is_in_group(uid, group_dn):
raise exception.NotFound(_("User %s is not a member of the group") raise exception.LDAPGroupMembershipNotFound(user_id=uid,
% uid) group_id=group_dn)
# NOTE(vish): remove user from group and any sub_groups # NOTE(vish): remove user from group and any sub_groups
sub_dns = self.__find_group_dns_with_member(group_dn, uid) sub_dns = self.__find_group_dns_with_member(group_dn, uid)
for sub_dn in sub_dns: for sub_dn in sub_dns:
@@ -548,9 +530,7 @@ class LdapDriver(object):
def __remove_from_all(self, uid): def __remove_from_all(self, uid):
"""Remove user from all roles and projects""" """Remove user from all roles and projects"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound(_("User %s can't be removed from all " raise exception.LDAPUserNotFound(user_id=uid)
"because the user doesn't exist")
% uid)
role_dns = self.__find_group_dns_with_member( role_dns = self.__find_group_dns_with_member(
FLAGS.role_project_subtree, uid) FLAGS.role_project_subtree, uid)
for role_dn in role_dns: for role_dn in role_dns:
@@ -563,8 +543,7 @@ class LdapDriver(object):
def __delete_group(self, group_dn): def __delete_group(self, group_dn):
"""Delete Group""" """Delete Group"""
if not self.__group_exists(group_dn): if not self.__group_exists(group_dn):
raise exception.NotFound(_("Group at dn %s doesn't exist") raise exception.LDAPGroupNotFound(group_id=group_dn)
% group_dn)
self.conn.delete_s(group_dn) self.conn.delete_s(group_dn)
def __delete_roles(self, project_dn): def __delete_roles(self, project_dn):

View File

@@ -223,6 +223,13 @@ class AuthManager(object):
if driver or not getattr(self, 'driver', None): if driver or not getattr(self, 'driver', None):
self.driver = utils.import_class(driver or FLAGS.auth_driver) self.driver = utils.import_class(driver or FLAGS.auth_driver)
if FLAGS.memcached_servers:
import memcache
else:
from nova import fakememcache as memcache
self.mc = memcache.Client(FLAGS.memcached_servers,
debug=0)
def authenticate(self, access, signature, params, verb='GET', def authenticate(self, access, signature, params, verb='GET',
server_string='127.0.0.1:8773', path='/', server_string='127.0.0.1:8773', path='/',
check_type='ec2', headers=None): check_type='ec2', headers=None):
@@ -270,8 +277,7 @@ class AuthManager(object):
LOG.debug('user: %r', user) LOG.debug('user: %r', user)
if user is None: if user is None:
LOG.audit(_("Failed authorization for access key %s"), access_key) LOG.audit(_("Failed authorization for access key %s"), access_key)
raise exception.NotFound(_('No user found for access key %s') raise exception.AccessKeyNotFound(access_key=access_key)
% access_key)
# NOTE(vish): if we stop using project name as id we need better # NOTE(vish): if we stop using project name as id we need better
# logic to find a default project for user # logic to find a default project for user
@@ -285,8 +291,7 @@ class AuthManager(object):
uname = user.name uname = user.name
LOG.audit(_("failed authorization: no project named %(pjid)s" LOG.audit(_("failed authorization: no project named %(pjid)s"
" (user=%(uname)s)") % locals()) " (user=%(uname)s)") % locals())
raise exception.NotFound(_('No project called %s could be found') raise exception.ProjectNotFound(project_id=project_id)
% project_id)
if not self.is_admin(user) and not self.is_project_member(user, if not self.is_admin(user) and not self.is_project_member(user,
project): project):
uname = user.name uname = user.name
@@ -295,28 +300,40 @@ class AuthManager(object):
pjid = project.id pjid = project.id
LOG.audit(_("Failed authorization: user %(uname)s not admin" LOG.audit(_("Failed authorization: user %(uname)s not admin"
" and not member of project %(pjname)s") % locals()) " and not member of project %(pjname)s") % locals())
raise exception.NotFound(_('User %(uid)s is not a member of' raise exception.ProjectMembershipNotFound(project_id=pjid,
' project %(pjid)s') % locals()) user_id=uid)
if check_type == 's3': if check_type == 's3':
sign = signer.Signer(user.secret.encode()) sign = signer.Signer(user.secret.encode())
expected_signature = sign.s3_authorization(headers, verb, path) expected_signature = sign.s3_authorization(headers, verb, path)
LOG.debug('user.secret: %s', user.secret) LOG.debug(_('user.secret: %s'), user.secret)
LOG.debug('expected_signature: %s', expected_signature) LOG.debug(_('expected_signature: %s'), expected_signature)
LOG.debug('signature: %s', signature) LOG.debug(_('signature: %s'), signature)
if signature != expected_signature: if signature != expected_signature:
LOG.audit(_("Invalid signature for user %s"), user.name) LOG.audit(_("Invalid signature for user %s"), user.name)
raise exception.NotAuthorized(_('Signature does not match')) raise exception.InvalidSignature(signature=signature,
user=user)
elif check_type == 'ec2': elif check_type == 'ec2':
# NOTE(vish): hmac can't handle unicode, so encode ensures that # NOTE(vish): hmac can't handle unicode, so encode ensures that
# secret isn't unicode # secret isn't unicode
expected_signature = signer.Signer(user.secret.encode()).generate( expected_signature = signer.Signer(user.secret.encode()).generate(
params, verb, server_string, path) params, verb, server_string, path)
LOG.debug('user.secret: %s', user.secret) LOG.debug(_('user.secret: %s'), user.secret)
LOG.debug('expected_signature: %s', expected_signature) LOG.debug(_('expected_signature: %s'), expected_signature)
LOG.debug('signature: %s', signature) LOG.debug(_('signature: %s'), signature)
if signature != expected_signature: if signature != expected_signature:
(addr_str, port_str) = utils.parse_server_string(server_string)
# If the given server_string contains port num, try without it.
if port_str != '':
host_only_signature = signer.Signer(
user.secret.encode()).generate(params, verb,
addr_str, path)
LOG.debug(_('host_only_signature: %s'),
host_only_signature)
if signature == host_only_signature:
return (user, project)
LOG.audit(_("Invalid signature for user %s"), user.name) LOG.audit(_("Invalid signature for user %s"), user.name)
raise exception.NotAuthorized(_('Signature does not match')) raise exception.InvalidSignature(signature=signature,
user=user)
return (user, project) return (user, project)
def get_access_key(self, user, project): def get_access_key(self, user, project):
@@ -360,6 +377,27 @@ class AuthManager(object):
if self.has_role(user, role): if self.has_role(user, role):
return True return True
def _build_mc_key(self, user, role, project=None):
key_parts = ['rolecache', User.safe_id(user), str(role)]
if project:
key_parts.append(Project.safe_id(project))
return '-'.join(key_parts)
def _clear_mc_key(self, user, role, project=None):
# NOTE(anthony): it would be better to delete the key
self.mc.set(self._build_mc_key(user, role, project), None)
def _has_role(self, user, role, project=None):
mc_key = self._build_mc_key(user, role, project)
rslt = self.mc.get(mc_key)
if rslt is None:
with self.driver() as drv:
rslt = drv.has_role(user, role, project)
self.mc.set(mc_key, rslt)
return rslt
else:
return rslt
def has_role(self, user, role, project=None): def has_role(self, user, role, project=None):
"""Checks existence of role for user """Checks existence of role for user
@@ -383,24 +421,24 @@ class AuthManager(object):
@rtype: bool @rtype: bool
@return: True if the user has the role. @return: True if the user has the role.
""" """
with self.driver() as drv: if role == 'projectmanager':
if role == 'projectmanager': if not project:
if not project: raise exception.Error(_("Must specify project"))
raise exception.Error(_("Must specify project")) return self.is_project_manager(user, project)
return self.is_project_manager(user, project)
global_role = drv.has_role(User.safe_id(user), global_role = self._has_role(User.safe_id(user),
role, role,
None) None)
if not global_role:
return global_role
if not project or role in FLAGS.global_roles: if not global_role:
return global_role return global_role
return drv.has_role(User.safe_id(user), if not project or role in FLAGS.global_roles:
role, return global_role
Project.safe_id(project))
return self._has_role(User.safe_id(user),
role,
Project.safe_id(project))
def add_role(self, user, role, project=None): def add_role(self, user, role, project=None):
"""Adds role for user """Adds role for user
@@ -420,9 +458,9 @@ class AuthManager(object):
@param project: Project in which to add local role. @param project: Project in which to add local role.
""" """
if role not in FLAGS.allowed_roles: if role not in FLAGS.allowed_roles:
raise exception.NotFound(_("The %s role can not be found") % role) raise exception.UserRoleNotFound(role_id=role)
if project is not None and role in FLAGS.global_roles: if project is not None and role in FLAGS.global_roles:
raise exception.NotFound(_("The %s role is global only") % role) raise exception.GlobalRoleNotAllowed(role_id=role)
uid = User.safe_id(user) uid = User.safe_id(user)
pid = Project.safe_id(project) pid = Project.safe_id(project)
if project: if project:
@@ -432,6 +470,7 @@ class AuthManager(object):
LOG.audit(_("Adding sitewide role %(role)s to user %(uid)s") LOG.audit(_("Adding sitewide role %(role)s to user %(uid)s")
% locals()) % locals())
with self.driver() as drv: with self.driver() as drv:
self._clear_mc_key(uid, role, pid)
drv.add_role(uid, role, pid) drv.add_role(uid, role, pid)
def remove_role(self, user, role, project=None): def remove_role(self, user, role, project=None):
@@ -460,6 +499,7 @@ class AuthManager(object):
LOG.audit(_("Removing sitewide role %(role)s" LOG.audit(_("Removing sitewide role %(role)s"
" from user %(uid)s") % locals()) " from user %(uid)s") % locals())
with self.driver() as drv: with self.driver() as drv:
self._clear_mc_key(uid, role, pid)
drv.remove_role(uid, role, pid) drv.remove_role(uid, role, pid)
@staticmethod @staticmethod

View File

@@ -369,6 +369,9 @@ DEFINE_string('host', socket.gethostname(),
DEFINE_string('node_availability_zone', 'nova', DEFINE_string('node_availability_zone', 'nova',
'availability zone of this node') 'availability zone of this node')
DEFINE_list('memcached_servers', None,
'Memcached servers or None for in process cache.')
DEFINE_string('zone_name', 'nova', 'name of this zone') DEFINE_string('zone_name', 'nova', 'name of this zone')
DEFINE_list('zone_capabilities', DEFINE_list('zone_capabilities',
['hypervisor=xenserver;kvm', 'os=linux;windows'], ['hypervisor=xenserver;kvm', 'os=linux;windows'],

View File

@@ -76,11 +76,9 @@ def zone_update(context, zone_id, data):
return db.zone_update(context, zone_id, data) return db.zone_update(context, zone_id, data)
def get_zone_capabilities(context, service=None): def get_zone_capabilities(context):
"""Returns a dict of key, value capabilities for this zone, """Returns a dict of key, value capabilities for this zone."""
or for a particular class of services running in this zone.""" return _call_scheduler('get_zone_capabilities', context=context)
return _call_scheduler('get_zone_capabilities', context=context,
params=dict(service=service))
def update_service_capabilities(context, service_name, host, capabilities): def update_service_capabilities(context, service_name, host, capabilities):

View File

@@ -0,0 +1,288 @@
# Copyright (c) 2011 Openstack, LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Host Filter is a driver mechanism for requesting instance resources.
Three drivers are included: AllHosts, Flavor & JSON. AllHosts just
returns the full, unfiltered list of hosts. Flavor is a hard coded
matching mechanism based on flavor criteria and JSON is an ad-hoc
filter grammar.
Why JSON? The requests for instances may come in through the
REST interface from a user or a parent Zone.
Currently Flavors and/or InstanceTypes are used for
specifing the type of instance desired. Specific Nova users have
noted a need for a more expressive way of specifying instances.
Since we don't want to get into building full DSL this is a simple
form as an example of how this could be done. In reality, most
consumers will use the more rigid filters such as FlavorFilter.
Note: These are "required" capability filters. These capabilities
used must be present or the host will be excluded. The hosts
returned are then weighed by the Weighted Scheduler. Weights
can take the more esoteric factors into consideration (such as
server affinity and customer separation).
"""
import json
from nova import exception
from nova import flags
from nova import log as logging
from nova import utils
LOG = logging.getLogger('nova.scheduler.host_filter')
FLAGS = flags.FLAGS
flags.DEFINE_string('default_host_filter_driver',
'nova.scheduler.host_filter.AllHostsFilter',
'Which driver to use for filtering hosts.')
class HostFilter(object):
"""Base class for host filter drivers."""
def instance_type_to_filter(self, instance_type):
"""Convert instance_type into a filter for most common use-case."""
raise NotImplementedError()
def filter_hosts(self, zone_manager, query):
"""Return a list of hosts that fulfill the filter."""
raise NotImplementedError()
def _full_name(self):
"""module.classname of the filter driver"""
return "%s.%s" % (self.__module__, self.__class__.__name__)
class AllHostsFilter(HostFilter):
"""NOP host filter driver. Returns all hosts in ZoneManager.
This essentially does what the old Scheduler+Chance used
to give us."""
def instance_type_to_filter(self, instance_type):
"""Return anything to prevent base-class from raising
exception."""
return (self._full_name(), instance_type)
def filter_hosts(self, zone_manager, query):
"""Return a list of hosts from ZoneManager list."""
return [(host, services)
for host, services in zone_manager.service_states.iteritems()]
class FlavorFilter(HostFilter):
"""HostFilter driver hard-coded to work with flavors."""
def instance_type_to_filter(self, instance_type):
"""Use instance_type to filter hosts."""
return (self._full_name(), instance_type)
def filter_hosts(self, zone_manager, query):
"""Return a list of hosts that can create instance_type."""
instance_type = query
selected_hosts = []
for host, services in zone_manager.service_states.iteritems():
capabilities = services.get('compute', {})
host_ram_mb = capabilities['host_memory_free']
disk_bytes = capabilities['disk_available']
if host_ram_mb >= instance_type['memory_mb'] and \
disk_bytes >= instance_type['local_gb']:
selected_hosts.append((host, capabilities))
return selected_hosts
#host entries (currently) are like:
# {'host_name-description': 'Default install of XenServer',
# 'host_hostname': 'xs-mini',
# 'host_memory_total': 8244539392,
# 'host_memory_overhead': 184225792,
# 'host_memory_free': 3868327936,
# 'host_memory_free_computed': 3840843776},
# 'host_other-config': {},
# 'host_ip_address': '192.168.1.109',
# 'host_cpu_info': {},
# 'disk_available': 32954957824,
# 'disk_total': 50394562560,
# 'disk_used': 17439604736},
# 'host_uuid': 'cedb9b39-9388-41df-8891-c5c9a0c0fe5f',
# 'host_name-label': 'xs-mini'}
# instance_type table has:
#name = Column(String(255), unique=True)
#memory_mb = Column(Integer)
#vcpus = Column(Integer)
#local_gb = Column(Integer)
#flavorid = Column(Integer, unique=True)
#swap = Column(Integer, nullable=False, default=0)
#rxtx_quota = Column(Integer, nullable=False, default=0)
#rxtx_cap = Column(Integer, nullable=False, default=0)
class JsonFilter(HostFilter):
"""Host Filter driver to allow simple JSON-based grammar for
selecting hosts."""
def _equals(self, args):
"""First term is == all the other terms."""
if len(args) < 2:
return False
lhs = args[0]
for rhs in args[1:]:
if lhs != rhs:
return False
return True
def _less_than(self, args):
"""First term is < all the other terms."""
if len(args) < 2:
return False
lhs = args[0]
for rhs in args[1:]:
if lhs >= rhs:
return False
return True
def _greater_than(self, args):
"""First term is > all the other terms."""
if len(args) < 2:
return False
lhs = args[0]
for rhs in args[1:]:
if lhs <= rhs:
return False
return True
def _in(self, args):
"""First term is in set of remaining terms"""
if len(args) < 2:
return False
return args[0] in args[1:]
def _less_than_equal(self, args):
"""First term is <= all the other terms."""
if len(args) < 2:
return False
lhs = args[0]
for rhs in args[1:]:
if lhs > rhs:
return False
return True
def _greater_than_equal(self, args):
"""First term is >= all the other terms."""
if len(args) < 2:
return False
lhs = args[0]
for rhs in args[1:]:
if lhs < rhs:
return False
return True
def _not(self, args):
"""Flip each of the arguments."""
if len(args) == 0:
return False
return [not arg for arg in args]
def _or(self, args):
"""True if any arg is True."""
return True in args
def _and(self, args):
"""True if all args are True."""
return False not in args
commands = {
'=': _equals,
'<': _less_than,
'>': _greater_than,
'in': _in,
'<=': _less_than_equal,
'>=': _greater_than_equal,
'not': _not,
'or': _or,
'and': _and,
}
def instance_type_to_filter(self, instance_type):
"""Convert instance_type into JSON filter object."""
required_ram = instance_type['memory_mb']
required_disk = instance_type['local_gb']
query = ['and',
['>=', '$compute.host_memory_free', required_ram],
['>=', '$compute.disk_available', required_disk]
]
return (self._full_name(), json.dumps(query))
def _parse_string(self, string, host, services):
"""Strings prefixed with $ are capability lookups in the
form '$service.capability[.subcap*]'"""
if not string:
return None
if string[0] != '$':
return string
path = string[1:].split('.')
for item in path:
services = services.get(item, None)
if not services:
return None
return services
def _process_filter(self, zone_manager, query, host, services):
"""Recursively parse the query structure."""
if len(query) == 0:
return True
cmd = query[0]
method = self.commands[cmd] # Let exception fly.
cooked_args = []
for arg in query[1:]:
if isinstance(arg, list):
arg = self._process_filter(zone_manager, arg, host, services)
elif isinstance(arg, basestring):
arg = self._parse_string(arg, host, services)
if arg != None:
cooked_args.append(arg)
result = method(self, cooked_args)
return result
def filter_hosts(self, zone_manager, query):
"""Return a list of hosts that can fulfill filter."""
expanded = json.loads(query)
hosts = []
for host, services in zone_manager.service_states.iteritems():
r = self._process_filter(zone_manager, expanded, host, services)
if isinstance(r, list):
r = True in r
if r:
hosts.append((host, services))
return hosts
DRIVERS = [AllHostsFilter, FlavorFilter, JsonFilter]
def choose_driver(driver_name=None):
"""Since the caller may specify which driver to use we need
to have an authoritative list of what is permissible. This
function checks the driver name against a predefined set
of acceptable drivers."""
if not driver_name:
driver_name = FLAGS.default_host_filter_driver
for driver in DRIVERS:
if "%s.%s" % (driver.__module__, driver.__name__) == driver_name:
return driver()
raise exception.SchedulerHostFilterDriverNotFound(driver_name=driver_name)

View File

@@ -106,28 +106,26 @@ class ZoneManager(object):
def __init__(self): def __init__(self):
self.last_zone_db_check = datetime.min self.last_zone_db_check = datetime.min
self.zone_states = {} # { <zone_id> : ZoneState } self.zone_states = {} # { <zone_id> : ZoneState }
self.service_states = {} # { <service> : { <host> : { cap k : v }}} self.service_states = {} # { <host> : { <service> : { cap k : v }}}
self.green_pool = greenpool.GreenPool() self.green_pool = greenpool.GreenPool()
def get_zone_list(self): def get_zone_list(self):
"""Return the list of zones we know about.""" """Return the list of zones we know about."""
return [zone.to_dict() for zone in self.zone_states.values()] return [zone.to_dict() for zone in self.zone_states.values()]
def get_zone_capabilities(self, context, service=None): def get_zone_capabilities(self, context):
"""Roll up all the individual host info to generic 'service' """Roll up all the individual host info to generic 'service'
capabilities. Each capability is aggregated into capabilities. Each capability is aggregated into
<cap>_min and <cap>_max values.""" <cap>_min and <cap>_max values."""
service_dict = self.service_states hosts_dict = self.service_states
if service:
service_dict = {service: self.service_states.get(service, {})}
# TODO(sandy) - be smarter about fabricating this structure. # TODO(sandy) - be smarter about fabricating this structure.
# But it's likely to change once we understand what the Best-Match # But it's likely to change once we understand what the Best-Match
# code will need better. # code will need better.
combined = {} # { <service>_<cap> : (min, max), ... } combined = {} # { <service>_<cap> : (min, max), ... }
for service_name, host_dict in service_dict.iteritems(): for host, host_dict in hosts_dict.iteritems():
for host, caps_dict in host_dict.iteritems(): for service_name, service_dict in host_dict.iteritems():
for cap, value in caps_dict.iteritems(): for cap, value in service_dict.iteritems():
key = "%s_%s" % (service_name, cap) key = "%s_%s" % (service_name, cap)
min_value, max_value = combined.get(key, (value, value)) min_value, max_value = combined.get(key, (value, value))
min_value = min(min_value, value) min_value = min(min_value, value)
@@ -171,6 +169,6 @@ class ZoneManager(object):
"""Update the per-service capabilities based on this notification.""" """Update the per-service capabilities based on this notification."""
logging.debug(_("Received %(service_name)s service update from " logging.debug(_("Received %(service_name)s service update from "
"%(host)s: %(capabilities)s") % locals()) "%(host)s: %(capabilities)s") % locals())
service_caps = self.service_states.get(service_name, {}) service_caps = self.service_states.get(host, {})
service_caps[host] = capabilities service_caps[service_name] = capabilities
self.service_states[service_name] = service_caps self.service_states[host] = service_caps

View File

@@ -28,10 +28,12 @@ import StringIO
import webob import webob
from nova import context from nova import context
from nova import exception
from nova import test from nova import test
from nova.api import ec2 from nova.api import ec2
from nova.api.ec2 import cloud
from nova.api.ec2 import apirequest from nova.api.ec2 import apirequest
from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils
from nova.auth import manager from nova.auth import manager
@@ -101,6 +103,21 @@ class XmlConversionTestCase(test.TestCase):
self.assertEqual(conv('-0'), 0) self.assertEqual(conv('-0'), 0)
class Ec2utilsTestCase(test.TestCase):
def test_ec2_id_to_id(self):
self.assertEqual(ec2utils.ec2_id_to_id('i-0000001e'), 30)
self.assertEqual(ec2utils.ec2_id_to_id('ami-1d'), 29)
def test_bad_ec2_id(self):
self.assertRaises(exception.InvalidEc2Id,
ec2utils.ec2_id_to_id,
'badone')
def test_id_to_ec2_id(self):
self.assertEqual(ec2utils.id_to_ec2_id(30), 'i-0000001e')
self.assertEqual(ec2utils.id_to_ec2_id(29, 'ami-%08x'), 'ami-0000001d')
class ApiEc2TestCase(test.TestCase): class ApiEc2TestCase(test.TestCase):
"""Unit test for the cloud controller on an EC2 API""" """Unit test for the cloud controller on an EC2 API"""
def setUp(self): def setUp(self):

View File

@@ -101,9 +101,43 @@ class _AuthManagerBaseTestCase(test.TestCase):
self.assertEqual('private-party', u.access) self.assertEqual('private-party', u.access)
def test_004_signature_is_valid(self): def test_004_signature_is_valid(self):
#self.assertTrue(self.manager.authenticate(**boto.generate_url ...? )) with user_generator(self.manager, name='admin', secret='admin',
pass access='admin'):
#raise NotImplementedError with project_generator(self.manager, name="admin",
manager_user='admin'):
accesskey = 'admin:admin'
expected_result = (self.manager.get_user('admin'),
self.manager.get_project('admin'))
# captured sig and query string using boto 1.9b/euca2ools 1.2
sig = 'd67Wzd9Bwz8xid9QU+lzWXcF2Y3tRicYABPJgrqfrwM='
auth_params = {'AWSAccessKeyId': 'admin:admin',
'Action': 'DescribeAvailabilityZones',
'SignatureMethod': 'HmacSHA256',
'SignatureVersion': '2',
'Timestamp': '2011-04-22T11:29:29',
'Version': '2009-11-30'}
self.assertTrue(expected_result, self.manager.authenticate(
accesskey,
sig,
auth_params,
'GET',
'127.0.0.1:8773',
'/services/Cloud/'))
# captured sig and query string using RightAWS 1.10.0
sig = 'ECYLU6xdFG0ZqRVhQybPJQNJ5W4B9n8fGs6+/fuGD2c='
auth_params = {'AWSAccessKeyId': 'admin:admin',
'Action': 'DescribeAvailabilityZones',
'SignatureMethod': 'HmacSHA256',
'SignatureVersion': '2',
'Timestamp': '2011-04-22T11:29:49.000Z',
'Version': '2008-12-01'}
self.assertTrue(expected_result, self.manager.authenticate(
accesskey,
sig,
auth_params,
'GET',
'127.0.0.1',
'/services/Cloud'))
def test_005_can_get_credentials(self): def test_005_can_get_credentials(self):
return return

View File

@@ -310,7 +310,7 @@ class CloudTestCase(test.TestCase):
instance_id = rv['instancesSet'][0]['instanceId'] instance_id = rv['instancesSet'][0]['instanceId']
output = self.cloud.get_console_output(context=self.context, output = self.cloud.get_console_output(context=self.context,
instance_id=[instance_id]) instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT') self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code # TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests. # for unit tests.
greenthread.sleep(0.3) greenthread.sleep(0.3)

View File

@@ -21,6 +21,7 @@ Tests For Compute
import datetime import datetime
import mox import mox
import stubout
from nova import compute from nova import compute
from nova import context from nova import context
@@ -52,6 +53,10 @@ class FakeTime(object):
self.counter += t self.counter += t
def nop_report_driver_status(self):
pass
class ComputeTestCase(test.TestCase): class ComputeTestCase(test.TestCase):
"""Test case for compute""" """Test case for compute"""
def setUp(self): def setUp(self):
@@ -649,6 +654,10 @@ class ComputeTestCase(test.TestCase):
def test_run_kill_vm(self): def test_run_kill_vm(self):
"""Detect when a vm is terminated behind the scenes""" """Detect when a vm is terminated behind the scenes"""
self.stubs = stubout.StubOutForTesting()
self.stubs.Set(compute_manager.ComputeManager,
'_report_driver_status', nop_report_driver_status)
instance_id = self._create_instance() instance_id = self._create_instance()
self.compute.run_instance(self.context, instance_id) self.compute.run_instance(self.context, instance_id)

View File

@@ -0,0 +1,208 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Scheduler Host Filter Drivers.
"""
import json
from nova import exception
from nova import flags
from nova import test
from nova.scheduler import host_filter
FLAGS = flags.FLAGS
class FakeZoneManager:
pass
class HostFilterTestCase(test.TestCase):
"""Test case for host filter drivers."""
def _host_caps(self, multiplier):
# Returns host capabilities in the following way:
# host1 = memory:free 10 (100max)
# disk:available 100 (1000max)
# hostN = memory:free 10 + 10N
# disk:available 100 + 100N
# in other words: hostN has more resources than host0
# which means ... don't go above 10 hosts.
return {'host_name-description': 'XenServer %s' % multiplier,
'host_hostname': 'xs-%s' % multiplier,
'host_memory_total': 100,
'host_memory_overhead': 10,
'host_memory_free': 10 + multiplier * 10,
'host_memory_free-computed': 10 + multiplier * 10,
'host_other-config': {},
'host_ip_address': '192.168.1.%d' % (100 + multiplier),
'host_cpu_info': {},
'disk_available': 100 + multiplier * 100,
'disk_total': 1000,
'disk_used': 0,
'host_uuid': 'xxx-%d' % multiplier,
'host_name-label': 'xs-%s' % multiplier}
def setUp(self):
self.old_flag = FLAGS.default_host_filter_driver
FLAGS.default_host_filter_driver = \
'nova.scheduler.host_filter.AllHostsFilter'
self.instance_type = dict(name='tiny',
memory_mb=50,
vcpus=10,
local_gb=500,
flavorid=1,
swap=500,
rxtx_quota=30000,
rxtx_cap=200)
self.zone_manager = FakeZoneManager()
states = {}
for x in xrange(10):
states['host%02d' % (x + 1)] = {'compute': self._host_caps(x)}
self.zone_manager.service_states = states
def tearDown(self):
FLAGS.default_host_filter_driver = self.old_flag
def test_choose_driver(self):
# Test default driver ...
driver = host_filter.choose_driver()
self.assertEquals(driver._full_name(),
'nova.scheduler.host_filter.AllHostsFilter')
# Test valid driver ...
driver = host_filter.choose_driver(
'nova.scheduler.host_filter.FlavorFilter')
self.assertEquals(driver._full_name(),
'nova.scheduler.host_filter.FlavorFilter')
# Test invalid driver ...
try:
host_filter.choose_driver('does not exist')
self.fail("Should not find driver")
except exception.SchedulerHostFilterDriverNotFound:
pass
def test_all_host_driver(self):
driver = host_filter.AllHostsFilter()
cooked = driver.instance_type_to_filter(self.instance_type)
hosts = driver.filter_hosts(self.zone_manager, cooked)
self.assertEquals(10, len(hosts))
for host, capabilities in hosts:
self.assertTrue(host.startswith('host'))
def test_flavor_driver(self):
driver = host_filter.FlavorFilter()
# filter all hosts that can support 50 ram and 500 disk
name, cooked = driver.instance_type_to_filter(self.instance_type)
self.assertEquals('nova.scheduler.host_filter.FlavorFilter', name)
hosts = driver.filter_hosts(self.zone_manager, cooked)
self.assertEquals(6, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
self.assertEquals('host05', just_hosts[0])
self.assertEquals('host10', just_hosts[5])
def test_json_driver(self):
driver = host_filter.JsonFilter()
# filter all hosts that can support 50 ram and 500 disk
name, cooked = driver.instance_type_to_filter(self.instance_type)
self.assertEquals('nova.scheduler.host_filter.JsonFilter', name)
hosts = driver.filter_hosts(self.zone_manager, cooked)
self.assertEquals(6, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
self.assertEquals('host05', just_hosts[0])
self.assertEquals('host10', just_hosts[5])
# Try some custom queries
raw = ['or',
['and',
['<', '$compute.host_memory_free', 30],
['<', '$compute.disk_available', 300]
],
['and',
['>', '$compute.host_memory_free', 70],
['>', '$compute.disk_available', 700]
]
]
cooked = json.dumps(raw)
hosts = driver.filter_hosts(self.zone_manager, cooked)
self.assertEquals(5, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
for index, host in zip([1, 2, 8, 9, 10], just_hosts):
self.assertEquals('host%02d' % index, host)
raw = ['not',
['=', '$compute.host_memory_free', 30],
]
cooked = json.dumps(raw)
hosts = driver.filter_hosts(self.zone_manager, cooked)
self.assertEquals(9, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
for index, host in zip([1, 2, 4, 5, 6, 7, 8, 9, 10], just_hosts):
self.assertEquals('host%02d' % index, host)
raw = ['in', '$compute.host_memory_free', 20, 40, 60, 80, 100]
cooked = json.dumps(raw)
hosts = driver.filter_hosts(self.zone_manager, cooked)
self.assertEquals(5, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
for index, host in zip([2, 4, 6, 8, 10], just_hosts):
self.assertEquals('host%02d' % index, host)
# Try some bogus input ...
raw = ['unknown command', ]
cooked = json.dumps(raw)
try:
driver.filter_hosts(self.zone_manager, cooked)
self.fail("Should give KeyError")
except KeyError, e:
pass
self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps([])))
self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps({})))
self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps(
['not', True, False, True, False]
)))
try:
driver.filter_hosts(self.zone_manager, json.dumps(
'not', True, False, True, False
))
self.fail("Should give KeyError")
except KeyError, e:
pass
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['=', '$foo', 100]
)))
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['=', '$.....', 100]
)))
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['>', ['and', ['or', ['not', ['<', ['>=', ['<=', ['in', ]]]]]]]]
)))
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['=', {}, ['>', '$missing....foo']]
)))

View File

@@ -75,13 +75,13 @@ class InstanceTypeTestCase(test.TestCase):
def test_invalid_create_args_should_fail(self): def test_invalid_create_args_should_fail(self):
"""Ensures that instance type creation fails with invalid args""" """Ensures that instance type creation fails with invalid args"""
self.assertRaises( self.assertRaises(
exception.InvalidInputException, exception.InvalidInput,
instance_types.create, self.name, 0, 1, 120, self.flavorid) instance_types.create, self.name, 0, 1, 120, self.flavorid)
self.assertRaises( self.assertRaises(
exception.InvalidInputException, exception.InvalidInput,
instance_types.create, self.name, 256, -1, 120, self.flavorid) instance_types.create, self.name, 256, -1, 120, self.flavorid)
self.assertRaises( self.assertRaises(
exception.InvalidInputException, exception.InvalidInput,
instance_types.create, self.name, 256, 1, "aa", self.flavorid) instance_types.create, self.name, 256, 1, "aa", self.flavorid)
def test_non_existant_inst_type_shouldnt_delete(self): def test_non_existant_inst_type_shouldnt_delete(self):

View File

@@ -29,11 +29,12 @@ from nova.utils import parse_mailmap, str_dict_replace
class ProjectTestCase(test.TestCase): class ProjectTestCase(test.TestCase):
def test_authors_up_to_date(self): def test_authors_up_to_date(self):
topdir = os.path.normpath(os.path.dirname(__file__) + '/../../') topdir = os.path.normpath(os.path.dirname(__file__) + '/../../')
missing = set()
contributors = set()
mailmap = parse_mailmap(os.path.join(topdir, '.mailmap'))
authors_file = open(os.path.join(topdir, 'Authors'), 'r').read()
if os.path.exists(os.path.join(topdir, '.bzr')): if os.path.exists(os.path.join(topdir, '.bzr')):
contributors = set()
mailmap = parse_mailmap(os.path.join(topdir, '.mailmap'))
import bzrlib.workingtree import bzrlib.workingtree
tree = bzrlib.workingtree.WorkingTree.open(topdir) tree = bzrlib.workingtree.WorkingTree.open(topdir)
tree.lock_read() tree.lock_read()
@@ -47,23 +48,37 @@ class ProjectTestCase(test.TestCase):
for r in revs: for r in revs:
for author in r.get_apparent_authors(): for author in r.get_apparent_authors():
email = author.split(' ')[-1] email = author.split(' ')[-1]
contributors.add(str_dict_replace(email, mailmap)) contributors.add(str_dict_replace(email,
mailmap))
authors_file = open(os.path.join(topdir, 'Authors'),
'r').read()
missing = set()
for contributor in contributors:
if contributor == 'nova-core':
continue
if not contributor in authors_file:
missing.add(contributor)
self.assertTrue(len(missing) == 0,
'%r not listed in Authors' % missing)
finally: finally:
tree.unlock() tree.unlock()
elif os.path.exists(os.path.join(topdir, '.git')):
import git
repo = git.Repo(topdir)
for commit in repo.head.commit.iter_parents():
email = commit.author.email
if email is None:
email = commit.author.name
if 'nova-core' in email:
continue
if email.split(' ')[-1] == '<>':
email = email.split(' ')[-2]
email = '<' + email + '>'
contributors.add(str_dict_replace(email, mailmap))
else:
return
for contributor in contributors:
if contributor == 'nova-core':
continue
if not contributor in authors_file:
missing.add(contributor)
self.assertTrue(len(missing) == 0,
'%r not listed in Authors' % missing)
class LockTestCase(test.TestCase): class LockTestCase(test.TestCase):
def test_synchronized_wrapped_function_metadata(self): def test_synchronized_wrapped_function_metadata(self):

View File

@@ -120,12 +120,11 @@ class SchedulerTestCase(test.TestCase):
dest = 'dummydest' dest = 'dummydest'
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
try: self.assertRaises(exception.NotFound, scheduler.show_host_resources,
scheduler.show_host_resources(ctxt, dest) ctxt, dest)
except exception.NotFound, e: #TODO(bcwaldon): reimplement this functionality
c1 = (e.message.find(_("does not exist or is not a " #c1 = (e.message.find(_("does not exist or is not a "
"compute node.")) >= 0) # "compute node.")) >= 0)
self.assertTrue(c1)
def _dic_is_equal(self, dic1, dic2, keys=None): def _dic_is_equal(self, dic1, dic2, keys=None):
"""Compares 2 dictionary contents(Helper method)""" """Compares 2 dictionary contents(Helper method)"""
@@ -769,14 +768,10 @@ class SimpleDriverTestCase(test.TestCase):
s_ref = self._create_compute_service(host='somewhere', s_ref = self._create_compute_service(host='somewhere',
memory_mb_used=12) memory_mb_used=12)
try: self.assertRaises(exception.MigrationError,
self.scheduler.driver._live_migration_dest_check(self.context, self.scheduler.driver._live_migration_dest_check,
i_ref, self.context, i_ref, 'somewhere')
'somewhere')
except exception.NotEmpty, e:
c = (e.message.find('Unable to migrate') >= 0)
self.assertTrue(c)
db.instance_destroy(self.context, instance_id) db.instance_destroy(self.context, instance_id)
db.service_destroy(self.context, s_ref['id']) db.service_destroy(self.context, s_ref['id'])
@@ -941,7 +936,7 @@ class FakeRerouteCompute(api.reroute_compute):
def go_boom(self, context, instance): def go_boom(self, context, instance):
raise exception.InstanceNotFound("boom message", instance) raise exception.InstanceNotFound(instance_id=instance)
def found_instance(self, context, instance): def found_instance(self, context, instance):
@@ -990,11 +985,8 @@ class ZoneRedirectTest(test.TestCase):
def test_routing_flags(self): def test_routing_flags(self):
FLAGS.enable_zone_routing = False FLAGS.enable_zone_routing = False
decorator = FakeRerouteCompute("foo") decorator = FakeRerouteCompute("foo")
try: self.assertRaises(exception.InstanceNotFound, decorator(go_boom),
result = decorator(go_boom)(None, None, 1) None, None, 1)
self.assertFail(_("Should have thrown exception."))
except exception.InstanceNotFound, e:
self.assertEquals(e.message, 'boom message')
def test_get_collection_context_and_id(self): def test_get_collection_context_and_id(self):
decorator = api.reroute_compute("foo") decorator = api.reroute_compute("foo")

View File

@@ -31,9 +31,7 @@ from nova import test
from nova import utils from nova import utils
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
from nova.auth import manager from nova.auth import manager
from nova.compute import manager as compute_manager
from nova.compute import power_state from nova.compute import power_state
from nova.db.sqlalchemy import models
from nova.virt import libvirt_conn from nova.virt import libvirt_conn
libvirt = None libvirt = None
@@ -46,6 +44,27 @@ def _concurrency(wait, done, target):
done.send() done.send()
def _create_network_info(count=1, ipv6=None):
if ipv6 is None:
ipv6 = FLAGS.use_ipv6
fake = 'fake'
fake_ip = '0.0.0.0/0'
fake_ip_2 = '0.0.0.1/0'
fake_ip_3 = '0.0.0.1/0'
network = {'gateway': fake,
'gateway_v6': fake,
'bridge': fake,
'cidr': fake_ip,
'cidr_v6': fake_ip}
mapping = {'mac': fake,
'ips': [{'ip': fake_ip}, {'ip': fake_ip}]}
if ipv6:
mapping['ip6s'] = [{'ip': fake_ip},
{'ip': fake_ip_2},
{'ip': fake_ip_3}]
return [(network, mapping) for x in xrange(0, count)]
class CacheConcurrencyTestCase(test.TestCase): class CacheConcurrencyTestCase(test.TestCase):
def setUp(self): def setUp(self):
super(CacheConcurrencyTestCase, self).setUp() super(CacheConcurrencyTestCase, self).setUp()
@@ -194,6 +213,37 @@ class LibvirtConnTestCase(test.TestCase):
return db.service_create(context.get_admin_context(), service_ref) return db.service_create(context.get_admin_context(), service_ref)
def test_preparing_xml_info(self):
conn = libvirt_conn.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, self.test_instance)
result = conn._prepare_xml_info(instance_ref, False)
self.assertFalse(result['nics'])
result = conn._prepare_xml_info(instance_ref, False,
_create_network_info())
self.assertTrue(len(result['nics']) == 1)
result = conn._prepare_xml_info(instance_ref, False,
_create_network_info(2))
self.assertTrue(len(result['nics']) == 2)
def test_get_nic_for_xml_v4(self):
conn = libvirt_conn.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=False)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
self.assertTrue(params.find('PROJNETV6') == -1)
self.assertTrue(params.find('PROJMASKV6') == -1)
def test_get_nic_for_xml_v6(self):
conn = libvirt_conn.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=True)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
self.assertTrue(params.find('PROJNETV6') > -1)
self.assertTrue(params.find('PROJMASKV6') > -1)
def test_xml_and_uri_no_ramdisk_no_kernel(self): def test_xml_and_uri_no_ramdisk_no_kernel(self):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
self._check_xml_and_uri(instance_data, self._check_xml_and_uri(instance_data,
@@ -229,6 +279,22 @@ class LibvirtConnTestCase(test.TestCase):
instance_data = dict(self.test_instance) instance_data = dict(self.test_instance)
self._check_xml_and_container(instance_data) self._check_xml_and_container(instance_data)
def test_multi_nic(self):
instance_data = dict(self.test_instance)
network_info = _create_network_info(2)
conn = libvirt_conn.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, instance_data)
xml = conn.to_xml(instance_ref, False, network_info)
tree = xml_to_tree(xml)
interfaces = tree.findall("./devices/interface")
self.assertEquals(len(interfaces), 2)
parameters = interfaces[0].findall('./filterref/parameter')
self.assertEquals(interfaces[0].get('type'), 'bridge')
self.assertEquals(parameters[0].get('name'), 'IP')
self.assertEquals(parameters[0].get('value'), '0.0.0.0/0')
self.assertEquals(parameters[1].get('name'), 'DHCPSERVER')
self.assertEquals(parameters[1].get('value'), 'fake')
def _check_xml_and_container(self, instance): def _check_xml_and_container(self, instance):
user_context = context.RequestContext(project=self.project, user_context = context.RequestContext(project=self.project,
user=self.user) user=self.user)
@@ -327,19 +393,13 @@ class LibvirtConnTestCase(test.TestCase):
check = (lambda t: t.find('./os/initrd'), None) check = (lambda t: t.find('./os/initrd'), None)
check_list.append(check) check_list.append(check)
parameter = './devices/interface/filterref/parameter'
common_checks = [ common_checks = [
(lambda t: t.find('.').tag, 'domain'), (lambda t: t.find('.').tag, 'domain'),
(lambda t: t.find( (lambda t: t.find(parameter).get('name'), 'IP'),
'./devices/interface/filterref/parameter').get('name'), 'IP'), (lambda t: t.find(parameter).get('value'), '10.11.12.13'),
(lambda t: t.find( (lambda t: t.findall(parameter)[1].get('name'), 'DHCPSERVER'),
'./devices/interface/filterref/parameter').get( (lambda t: t.findall(parameter)[1].get('value'), '10.0.0.1'),
'value'), '10.11.12.13'),
(lambda t: t.findall(
'./devices/interface/filterref/parameter')[1].get(
'name'), 'DHCPSERVER'),
(lambda t: t.findall(
'./devices/interface/filterref/parameter')[1].get(
'value'), '10.0.0.1'),
(lambda t: t.find('./devices/serial/source').get( (lambda t: t.find('./devices/serial/source').get(
'path').split('/')[1], 'console.log'), 'path').split('/')[1], 'console.log'),
(lambda t: t.find('./memory').text, '2097152')] (lambda t: t.find('./memory').text, '2097152')]
@@ -586,6 +646,11 @@ class LibvirtConnTestCase(test.TestCase):
self.assertTrue(count) self.assertTrue(count)
def test_get_host_ip_addr(self):
conn = libvirt_conn.LibvirtConnection(False)
ip = conn.get_host_ip_addr()
self.assertEquals(ip, FLAGS.my_ip)
def tearDown(self): def tearDown(self):
self.manager.delete_project(self.project) self.manager.delete_project(self.project)
self.manager.delete_user(self.user) self.manager.delete_user(self.user)
@@ -651,12 +716,15 @@ class IptablesFirewallTestCase(test.TestCase):
'# Completed on Tue Jan 18 23:47:56 2011', '# Completed on Tue Jan 18 23:47:56 2011',
] ]
def _create_instance_ref(self):
return db.instance_create(self.context,
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '56:12:12:12:12:12',
'instance_type_id': 1})
def test_static_filters(self): def test_static_filters(self):
instance_ref = db.instance_create(self.context, instance_ref = self._create_instance_ref()
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '56:12:12:12:12:12',
'instance_type_id': 1})
ip = '10.11.12.13' ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context,
@@ -767,6 +835,40 @@ class IptablesFirewallTestCase(test.TestCase):
"TCP port 80/81 acceptance rule wasn't added") "TCP port 80/81 acceptance rule wasn't added")
db.instance_destroy(admin_ctxt, instance_ref['id']) db.instance_destroy(admin_ctxt, instance_ref['id'])
def test_filters_for_instance_with_ip_v6(self):
self.flags(use_ipv6=True)
network_info = _create_network_info()
rulesv4, rulesv6 = self.fw._filters_for_instance("fake", network_info)
self.assertEquals(len(rulesv4), 2)
self.assertEquals(len(rulesv6), 3)
def test_filters_for_instance_without_ip_v6(self):
self.flags(use_ipv6=False)
network_info = _create_network_info()
rulesv4, rulesv6 = self.fw._filters_for_instance("fake", network_info)
self.assertEquals(len(rulesv4), 2)
self.assertEquals(len(rulesv6), 0)
def multinic_iptables_test(self):
ipv4_rules_per_network = 2
ipv6_rules_per_network = 3
networks_count = 5
instance_ref = self._create_instance_ref()
network_info = _create_network_info(networks_count)
ipv4_len = len(self.fw.iptables.ipv4['filter'].rules)
ipv6_len = len(self.fw.iptables.ipv6['filter'].rules)
inst_ipv4, inst_ipv6 = self.fw.instance_rules(instance_ref,
network_info)
self.fw.add_filters_for_instance(instance_ref, network_info)
ipv4 = self.fw.iptables.ipv4['filter'].rules
ipv6 = self.fw.iptables.ipv6['filter'].rules
ipv4_network_rules = len(ipv4) - len(inst_ipv4) - ipv4_len
ipv6_network_rules = len(ipv6) - len(inst_ipv6) - ipv6_len
self.assertEquals(ipv4_network_rules,
ipv4_rules_per_network * networks_count)
self.assertEquals(ipv6_network_rules,
ipv6_rules_per_network * networks_count)
class NWFilterTestCase(test.TestCase): class NWFilterTestCase(test.TestCase):
def setUp(self): def setUp(self):
@@ -848,6 +950,28 @@ class NWFilterTestCase(test.TestCase):
return db.security_group_get_by_name(self.context, 'fake', 'testgroup') return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
def _create_instance(self):
return db.instance_create(self.context,
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '00:A0:C9:14:C8:29',
'instance_type_id': 1})
def _create_instance_type(self, params={}):
"""Create a test instance"""
context = self.context.elevated()
inst = {}
inst['name'] = 'm1.small'
inst['memory_mb'] = '1024'
inst['vcpus'] = '1'
inst['local_gb'] = '20'
inst['flavorid'] = '1'
inst['swap'] = '2048'
inst['rxtx_quota'] = 100
inst['rxtx_cap'] = 200
inst.update(params)
return db.instance_type_create(context, inst)['id']
def test_creates_base_rule_first(self): def test_creates_base_rule_first(self):
# These come pre-defined by libvirt # These come pre-defined by libvirt
self.defined_filters = ['no-mac-spoofing', self.defined_filters = ['no-mac-spoofing',
@@ -876,25 +1000,18 @@ class NWFilterTestCase(test.TestCase):
self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
instance_ref = db.instance_create(self.context, instance_ref = self._create_instance()
{'user_id': 'fake',
'project_id': 'fake',
'mac_address': '00:A0:C9:14:C8:29',
'instance_type_id': 1})
inst_id = instance_ref['id'] inst_id = instance_ref['id']
ip = '10.11.12.13' ip = '10.11.12.13'
network_ref = db.project_get_network(self.context, network_ref = db.project_get_network(self.context, 'fake')
'fake') fixed_ip = {'address': ip, 'network_id': network_ref['id']}
fixed_ip = {'address': ip,
'network_id': network_ref['id']}
admin_ctxt = context.get_admin_context() admin_ctxt = context.get_admin_context()
db.fixed_ip_create(admin_ctxt, fixed_ip) db.fixed_ip_create(admin_ctxt, fixed_ip)
db.fixed_ip_update(admin_ctxt, ip, {'allocated': True, db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
'instance_id': instance_ref['id']}) 'instance_id': inst_id})
def _ensure_all_called(): def _ensure_all_called():
instance_filter = 'nova-instance-%s-%s' % (instance_ref['name'], instance_filter = 'nova-instance-%s-%s' % (instance_ref['name'],
@@ -920,3 +1037,11 @@ class NWFilterTestCase(test.TestCase):
_ensure_all_called() _ensure_all_called()
self.teardown_security_group() self.teardown_security_group()
db.instance_destroy(admin_ctxt, instance_ref['id']) db.instance_destroy(admin_ctxt, instance_ref['id'])
def test_create_network_filters(self):
instance_ref = self._create_instance()
network_info = _create_network_info(3)
result = self.fw._create_network_filters(instance_ref,
network_info,
"fake")
self.assertEquals(len(result), 3)

View File

@@ -142,7 +142,7 @@ class VolumeTestCase(test.TestCase):
self.assertEqual(vol['status'], "available") self.assertEqual(vol['status'], "available")
self.volume.delete_volume(self.context, volume_id) self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.Error, self.assertRaises(exception.VolumeNotFound,
db.volume_get, db.volume_get,
self.context, self.context,
volume_id) volume_id)

View File

@@ -17,6 +17,7 @@
"""Test suite for XenAPI.""" """Test suite for XenAPI."""
import functools import functools
import json
import os import os
import re import re
import stubout import stubout
@@ -665,3 +666,52 @@ class XenAPIDetermineDiskImageTestCase(test.TestCase):
self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_VHD self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_VHD
self.fake_instance.kernel_id = None self.fake_instance.kernel_id = None
self.assert_disk_type(vm_utils.ImageType.DISK_VHD) self.assert_disk_type(vm_utils.ImageType.DISK_VHD)
class FakeXenApi(object):
"""Fake XenApi for testing HostState."""
class FakeSR(object):
def get_record(self, ref):
return {'virtual_allocation': 10000,
'physical_utilisation': 20000}
SR = FakeSR()
class FakeSession(object):
"""Fake Session class for HostState testing."""
def async_call_plugin(self, *args):
return None
def wait_for_task(self, *args):
vm = {'total': 10,
'overhead': 20,
'free': 30,
'free-computed': 40}
return json.dumps({'host_memory': vm})
def get_xenapi(self):
return FakeXenApi()
class HostStateTestCase(test.TestCase):
"""Tests HostState, which holds metrics from XenServer that get
reported back to the Schedulers."""
def _fake_safe_find_sr(self, session):
"""None SR ref since we're ignoring it in FakeSR."""
return None
def test_host_state(self):
self.stubs = stubout.StubOutForTesting()
self.stubs.Set(vm_utils, 'safe_find_sr', self._fake_safe_find_sr)
host_state = xenapi_conn.HostState(FakeSession())
stats = host_state._stats
self.assertEquals(stats['disk_total'], 10000)
self.assertEquals(stats['disk_used'], 20000)
self.assertEquals(stats['host_memory_total'], 10)
self.assertEquals(stats['host_memory_overhead'], 20)
self.assertEquals(stats['host_memory_free'], 30)
self.assertEquals(stats['host_memory_free_computed'], 40)

View File

@@ -78,38 +78,32 @@ class ZoneManagerTestCase(test.TestCase):
def test_service_capabilities(self): def test_service_capabilities(self):
zm = zone_manager.ZoneManager() zm = zone_manager.ZoneManager()
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, {}) self.assertEquals(caps, {})
zm.update_service_capabilities("svc1", "host1", dict(a=1, b=2)) zm.update_service_capabilities("svc1", "host1", dict(a=1, b=2))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(1, 1), svc1_b=(2, 2))) self.assertEquals(caps, dict(svc1_a=(1, 1), svc1_b=(2, 2)))
zm.update_service_capabilities("svc1", "host1", dict(a=2, b=3)) zm.update_service_capabilities("svc1", "host1", dict(a=2, b=3))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 2), svc1_b=(3, 3))) self.assertEquals(caps, dict(svc1_a=(2, 2), svc1_b=(3, 3)))
zm.update_service_capabilities("svc1", "host2", dict(a=20, b=30)) zm.update_service_capabilities("svc1", "host2", dict(a=20, b=30))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30))) self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30)))
zm.update_service_capabilities("svc10", "host1", dict(a=99, b=99)) zm.update_service_capabilities("svc10", "host1", dict(a=99, b=99))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30), self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30),
svc10_a=(99, 99), svc10_b=(99, 99))) svc10_a=(99, 99), svc10_b=(99, 99)))
zm.update_service_capabilities("svc1", "host3", dict(c=5)) zm.update_service_capabilities("svc1", "host3", dict(c=5))
caps = zm.get_zone_capabilities(self, None) caps = zm.get_zone_capabilities(None)
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30), self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30),
svc1_c=(5, 5), svc10_a=(99, 99), svc1_c=(5, 5), svc10_a=(99, 99),
svc10_b=(99, 99))) svc10_b=(99, 99)))
caps = zm.get_zone_capabilities(self, 'svc1')
self.assertEquals(caps, dict(svc1_a=(2, 20), svc1_b=(3, 30),
svc1_c=(5, 5)))
caps = zm.get_zone_capabilities(self, 'svc10')
self.assertEquals(caps, dict(svc10_a=(99, 99), svc10_b=(99, 99)))
def test_refresh_from_db_replace_existing(self): def test_refresh_from_db_replace_existing(self):
zm = zone_manager.ZoneManager() zm = zone_manager.ZoneManager()
zone_state = zone_manager.ZoneState() zone_state = zone_manager.ZoneState()