refactoring of imports for fakeldapdriver

This commit is contained in:
Vishvananda Ishaya
2010-07-21 19:56:08 -05:00
parent 1fa4a70100
commit eaf648d7a0
10 changed files with 75 additions and 47 deletions

View File

@@ -0,0 +1,32 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Fake Auth driver for ldap
"""
from nova.auth import ldapdriver
class AuthDriver(ldapdriver.AuthDriver):
"""Ldap Auth driver
Defines enter and exit and therefore supports the with/as syntax.
"""
def __init__(self):
self.ldap = __import__('nova.auth.fakeldap', fromlist=True)

View File

@@ -57,19 +57,18 @@ flags.DEFINE_string('ldap_developer',
'cn=developers,ou=Groups,dc=example,dc=com', 'cn for Developers') 'cn=developers,ou=Groups,dc=example,dc=com', 'cn for Developers')
class LdapDriver(object): class AuthDriver(object):
"""Ldap Auth driver """Ldap Auth driver
Defines enter and exit and therefore supports the with/as syntax. Defines enter and exit and therefore supports the with/as syntax.
""" """
def __init__(self):
"""Imports the LDAP module"""
self.ldap = __import__('ldap')
def __enter__(self): def __enter__(self):
"""Creates the connection to LDAP""" """Creates the connection to LDAP"""
global ldap self.conn = self.ldap.initialize(FLAGS.ldap_url)
if FLAGS.fake_users:
from nova.auth import fakeldap as ldap
else:
import ldap
self.conn = ldap.initialize(FLAGS.ldap_url)
self.conn.simple_bind_s(FLAGS.ldap_user_dn, FLAGS.ldap_password) self.conn.simple_bind_s(FLAGS.ldap_user_dn, FLAGS.ldap_password)
return self return self
@@ -275,8 +274,8 @@ class LdapDriver(object):
def __find_dns(self, dn, query=None): def __find_dns(self, dn, query=None):
"""Find dns by query""" """Find dns by query"""
try: try:
res = self.conn.search_s(dn, ldap.SCOPE_SUBTREE, query) res = self.conn.search_s(dn, self.ldap.SCOPE_SUBTREE, query)
except ldap.NO_SUCH_OBJECT: except self.ldap.NO_SUCH_OBJECT:
return [] return []
# just return the DNs # just return the DNs
return [dn for dn, attributes in res] return [dn for dn, attributes in res]
@@ -284,8 +283,8 @@ class LdapDriver(object):
def __find_objects(self, dn, query = None): def __find_objects(self, dn, query = None):
"""Find objects by query""" """Find objects by query"""
try: try:
res = self.conn.search_s(dn, ldap.SCOPE_SUBTREE, query) res = self.conn.search_s(dn, self.ldap.SCOPE_SUBTREE, query)
except ldap.NO_SUCH_OBJECT: except self.ldap.NO_SUCH_OBJECT:
return [] return []
# just return the attributes # just return the attributes
return [attributes for dn, attributes in res] return [attributes for dn, attributes in res]
@@ -369,7 +368,7 @@ class LdapDriver(object):
raise exception.Duplicate("User %s is already a member of " raise exception.Duplicate("User %s is already a member of "
"the group %s" % (uid, group_dn)) "the group %s" % (uid, group_dn))
attr = [ attr = [
(ldap.MOD_ADD, 'member', self.__uid_to_dn(uid)) (self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))
] ]
self.conn.modify_s(group_dn, attr) self.conn.modify_s(group_dn, attr)
@@ -389,10 +388,10 @@ class LdapDriver(object):
def __safe_remove_from_group(self, uid, group_dn): def __safe_remove_from_group(self, uid, group_dn):
"""Remove user from group, deleting group if user is last member""" """Remove user from group, deleting group if user is last member"""
# FIXME(vish): what if deleted user is a project manager? # FIXME(vish): what if deleted user is a project manager?
attr = [(ldap.MOD_DELETE, 'member', self.__uid_to_dn(uid))] attr = [(self.ldap.MOD_DELETE, 'member', self.__uid_to_dn(uid))]
try: try:
self.conn.modify_s(group_dn, attr) self.conn.modify_s(group_dn, attr)
except ldap.OBJECT_CLASS_VIOLATION: except self.ldap.OBJECT_CLASS_VIOLATION:
logging.debug("Attempted to remove the last member of a group. " logging.debug("Attempted to remove the last member of a group. "
"Deleting the group at %s instead." % group_dn ) "Deleting the group at %s instead." % group_dn )
self.__delete_group(group_dn) self.__delete_group(group_dn)

View File

@@ -34,7 +34,6 @@ from nova import exception
from nova import flags from nova import flags
from nova import objectstore # for flags from nova import objectstore # for flags
from nova import utils from nova import utils
from nova.auth import ldapdriver
from nova.auth import signer from nova.auth import signer
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@@ -76,6 +75,8 @@ flags.DEFINE_string('credential_cert_subject',
flags.DEFINE_string('vpn_ip', '127.0.0.1', flags.DEFINE_string('vpn_ip', '127.0.0.1',
'Public IP for the cloudpipe VPN servers') 'Public IP for the cloudpipe VPN servers')
flags.DEFINE_string('auth_driver', 'fakeldapdriver',
'Driver that auth manager uses')
class AuthBase(object): class AuthBase(object):
"""Base class for objects relating to auth """Base class for objects relating to auth
@@ -312,7 +313,7 @@ class AuthManager(object):
Methods accept objects or ids. Methods accept objects or ids.
AuthManager uses a driver object to make requests to the data backend. AuthManager uses a driver object to make requests to the data backend.
See ldapdriver.LdapDriver for reference. See ldapdriver for reference.
AuthManager also manages associated data related to Auth objects that AuthManager also manages associated data related to Auth objects that
need to be more accessible, such as vpn ips and ports. need to be more accessible, such as vpn ips and ports.
@@ -325,7 +326,9 @@ class AuthManager(object):
return cls._instance return cls._instance
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.driver_class = kwargs.get('driver_class', ldapdriver.LdapDriver) """Imports the driver module and saves the Driver class"""
mod = __import__(FLAGS.auth_driver, fromlist=True)
self.driver = mod.AuthDriver
def authenticate(self, access, signature, params, verb='GET', def authenticate(self, access, signature, params, verb='GET',
server_string='127.0.0.1:8773', path='/', server_string='127.0.0.1:8773', path='/',
@@ -451,7 +454,7 @@ class AuthManager(object):
@rtype: bool @rtype: bool
@return: True if the user has the role. @return: True if the user has the role.
""" """
with self.driver_class() as drv: with self.driver() as drv:
if role == 'projectmanager': if role == 'projectmanager':
if not project: if not project:
raise exception.Error("Must specify project") raise exception.Error("Must specify project")
@@ -487,7 +490,7 @@ class AuthManager(object):
@type project: Project or project_id @type project: Project or project_id
@param project: Project in which to add local role. @param project: Project in which to add local role.
""" """
with self.driver_class() as drv: with self.driver() as drv:
drv.add_role(User.safe_id(user), role, Project.safe_id(project)) drv.add_role(User.safe_id(user), role, Project.safe_id(project))
def remove_role(self, user, role, project=None): def remove_role(self, user, role, project=None):
@@ -507,19 +510,19 @@ class AuthManager(object):
@type project: Project or project_id @type project: Project or project_id
@param project: Project in which to remove local role. @param project: Project in which to remove local role.
""" """
with self.driver_class() as drv: with self.driver() as drv:
drv.remove_role(User.safe_id(user), role, Project.safe_id(project)) drv.remove_role(User.safe_id(user), role, Project.safe_id(project))
def get_project(self, pid): def get_project(self, pid):
"""Get project object by id""" """Get project object by id"""
with self.driver_class() as drv: with self.driver() as drv:
project_dict = drv.get_project(pid) project_dict = drv.get_project(pid)
if project_dict: if project_dict:
return Project(**project_dict) return Project(**project_dict)
def get_projects(self): def get_projects(self):
"""Retrieves list of all projects""" """Retrieves list of all projects"""
with self.driver_class() as drv: with self.driver() as drv:
project_list = drv.get_projects() project_list = drv.get_projects()
if not project_list: if not project_list:
return [] return []
@@ -549,7 +552,7 @@ class AuthManager(object):
""" """
if member_users: if member_users:
member_users = [User.safe_id(u) for u in member_users] member_users = [User.safe_id(u) for u in member_users]
with self.driver_class() as drv: with self.driver() as drv:
project_dict = drv.create_project(name, project_dict = drv.create_project(name,
User.safe_id(manager_user), User.safe_id(manager_user),
description, description,
@@ -561,7 +564,7 @@ class AuthManager(object):
def add_to_project(self, user, project): def add_to_project(self, user, project):
"""Add user to project""" """Add user to project"""
with self.driver_class() as drv: with self.driver() as drv:
return drv.add_to_project(User.safe_id(user), return drv.add_to_project(User.safe_id(user),
Project.safe_id(project)) Project.safe_id(project))
@@ -579,7 +582,7 @@ class AuthManager(object):
def remove_from_project(self, user, project): def remove_from_project(self, user, project):
"""Removes a user from a project""" """Removes a user from a project"""
with self.driver_class() as drv: with self.driver() as drv:
return drv.remove_from_project(User.safe_id(user), return drv.remove_from_project(User.safe_id(user),
Project.safe_id(project)) Project.safe_id(project))
@@ -600,26 +603,26 @@ class AuthManager(object):
def delete_project(self, project): def delete_project(self, project):
"""Deletes a project""" """Deletes a project"""
with self.driver_class() as drv: with self.driver() as drv:
return drv.delete_project(Project.safe_id(project)) return drv.delete_project(Project.safe_id(project))
def get_user(self, uid): def get_user(self, uid):
"""Retrieves a user by id""" """Retrieves a user by id"""
with self.driver_class() as drv: with self.driver() as drv:
user_dict = drv.get_user(uid) user_dict = drv.get_user(uid)
if user_dict: if user_dict:
return User(**user_dict) return User(**user_dict)
def get_user_from_access_key(self, access_key): def get_user_from_access_key(self, access_key):
"""Retrieves a user by access key""" """Retrieves a user by access key"""
with self.driver_class() as drv: with self.driver() as drv:
user_dict = drv.get_user_from_access_key(access_key) user_dict = drv.get_user_from_access_key(access_key)
if user_dict: if user_dict:
return User(**user_dict) return User(**user_dict)
def get_users(self): def get_users(self):
"""Retrieves a list of all users""" """Retrieves a list of all users"""
with self.driver_class() as drv: with self.driver() as drv:
user_list = drv.get_users() user_list = drv.get_users()
if not user_list: if not user_list:
return [] return []
@@ -649,14 +652,14 @@ class AuthManager(object):
""" """
if access == None: access = str(uuid.uuid4()) if access == None: access = str(uuid.uuid4())
if secret == None: secret = str(uuid.uuid4()) if secret == None: secret = str(uuid.uuid4())
with self.driver_class() as drv: with self.driver() as drv:
user_dict = drv.create_user(name, access, secret, admin) user_dict = drv.create_user(name, access, secret, admin)
if user_dict: if user_dict:
return User(**user_dict) return User(**user_dict)
def delete_user(self, user): def delete_user(self, user):
"""Deletes a user""" """Deletes a user"""
with self.driver_class() as drv: with self.driver() as drv:
drv.delete_user(User.safe_id(user)) drv.delete_user(User.safe_id(user))
def generate_key_pair(self, user, key_name): def generate_key_pair(self, user, key_name):
@@ -677,7 +680,7 @@ class AuthManager(object):
# NOTE(vish): generating key pair is slow so check for legal # NOTE(vish): generating key pair is slow so check for legal
# creation before creating keypair # creation before creating keypair
uid = User.safe_id(user) uid = User.safe_id(user)
with self.driver_class() as drv: with self.driver() as drv:
if not drv.get_user(uid): if not drv.get_user(uid):
raise exception.NotFound("User %s doesn't exist" % user) raise exception.NotFound("User %s doesn't exist" % user)
if drv.get_key_pair(uid, key_name): if drv.get_key_pair(uid, key_name):
@@ -689,7 +692,7 @@ class AuthManager(object):
def create_key_pair(self, user, key_name, public_key, fingerprint): def create_key_pair(self, user, key_name, public_key, fingerprint):
"""Creates a key pair for user""" """Creates a key pair for user"""
with self.driver_class() as drv: with self.driver() as drv:
kp_dict = drv.create_key_pair(User.safe_id(user), kp_dict = drv.create_key_pair(User.safe_id(user),
key_name, key_name,
public_key, public_key,
@@ -699,14 +702,14 @@ class AuthManager(object):
def get_key_pair(self, user, key_name): def get_key_pair(self, user, key_name):
"""Retrieves a key pair for user""" """Retrieves a key pair for user"""
with self.driver_class() as drv: with self.driver() as drv:
kp_dict = drv.get_key_pair(User.safe_id(user), key_name) kp_dict = drv.get_key_pair(User.safe_id(user), key_name)
if kp_dict: if kp_dict:
return KeyPair(**kp_dict) return KeyPair(**kp_dict)
def get_key_pairs(self, user): def get_key_pairs(self, user):
"""Retrieves all key pairs for user""" """Retrieves all key pairs for user"""
with self.driver_class() as drv: with self.driver() as drv:
kp_list = drv.get_key_pairs(User.safe_id(user)) kp_list = drv.get_key_pairs(User.safe_id(user))
if not kp_list: if not kp_list:
return [] return []
@@ -714,7 +717,7 @@ class AuthManager(object):
def delete_key_pair(self, user, key_name): def delete_key_pair(self, user, key_name):
"""Deletes a key pair for user""" """Deletes a key pair for user"""
with self.driver_class() as drv: with self.driver() as drv:
drv.delete_key_pair(User.safe_id(user), key_name) drv.delete_key_pair(User.safe_id(user), key_name)
def get_credentials(self, user, project=None): def get_credentials(self, user, project=None):

View File

@@ -46,7 +46,6 @@ DEFINE_bool('fake_libvirt', False,
DEFINE_bool('verbose', False, 'show debug output') DEFINE_bool('verbose', False, 'show debug output')
DEFINE_boolean('fake_rabbit', False, 'use a fake rabbit') DEFINE_boolean('fake_rabbit', False, 'use a fake rabbit')
DEFINE_bool('fake_network', False, 'should we use fake network devices and addresses') DEFINE_bool('fake_network', False, 'should we use fake network devices and addresses')
DEFINE_bool('fake_users', False, 'use fake users')
DEFINE_string('rabbit_host', 'localhost', 'rabbit host') DEFINE_string('rabbit_host', 'localhost', 'rabbit host')
DEFINE_integer('rabbit_port', 5672, 'rabbit port') DEFINE_integer('rabbit_port', 5672, 'rabbit port')
DEFINE_string('rabbit_userid', 'guest', 'rabbit userid') DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')

View File

@@ -40,8 +40,7 @@ class CloudTestCase(test.BaseTestCase):
def setUp(self): def setUp(self):
super(CloudTestCase, self).setUp() super(CloudTestCase, self).setUp()
self.flags(fake_libvirt=True, self.flags(fake_libvirt=True,
fake_storage=True, fake_storage=True)
fake_users=True)
self.conn = rpc.Connection.instance() self.conn = rpc.Connection.instance()
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)

View File

@@ -24,5 +24,5 @@ FLAGS.fake_libvirt = True
FLAGS.fake_storage = True FLAGS.fake_storage = True
FLAGS.fake_rabbit = True FLAGS.fake_rabbit = True
FLAGS.fake_network = True FLAGS.fake_network = True
FLAGS.fake_users = True FLAGS.auth_driver = 'nova.auth.fakeldapdriver'
FLAGS.verbose = True FLAGS.verbose = True

View File

@@ -35,8 +35,7 @@ class ModelTestCase(test.TrialTestCase):
def setUp(self): def setUp(self):
super(ModelTestCase, self).setUp() super(ModelTestCase, self).setUp()
self.flags(fake_libvirt=True, self.flags(fake_libvirt=True,
fake_storage=True, fake_storage=True)
fake_users=True)
def tearDown(self): def tearDown(self):
model.Instance('i-test').destroy() model.Instance('i-test').destroy()

View File

@@ -58,8 +58,7 @@ class NodeConnectionTestCase(test.TrialTestCase):
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)
super(NodeConnectionTestCase, self).setUp() super(NodeConnectionTestCase, self).setUp()
self.flags(fake_libvirt=True, self.flags(fake_libvirt=True,
fake_storage=True, fake_storage=True)
fake_users=True)
self.node = node.Node() self.node = node.Node()
def create_instance(self): def create_instance(self):

View File

@@ -51,8 +51,7 @@ os.makedirs(os.path.join(oss_tempdir, 'buckets'))
class ObjectStoreTestCase(test.BaseTestCase): class ObjectStoreTestCase(test.BaseTestCase):
def setUp(self): def setUp(self):
super(ObjectStoreTestCase, self).setUp() super(ObjectStoreTestCase, self).setUp()
self.flags(fake_users=True, self.flags(buckets_path=os.path.join(oss_tempdir, 'buckets'),
buckets_path=os.path.join(oss_tempdir, 'buckets'),
images_path=os.path.join(oss_tempdir, 'images'), images_path=os.path.join(oss_tempdir, 'images'),
ca_path=os.path.join(os.path.dirname(__file__), 'CA')) ca_path=os.path.join(os.path.dirname(__file__), 'CA'))
logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().setLevel(logging.DEBUG)

View File

@@ -24,5 +24,4 @@ FLAGS.fake_libvirt = False
FLAGS.fake_storage = False FLAGS.fake_storage = False
FLAGS.fake_rabbit = False FLAGS.fake_rabbit = False
FLAGS.fake_network = False FLAGS.fake_network = False
FLAGS.fake_users = False
FLAGS.verbose = False FLAGS.verbose = False