refactoring of imports for fakeldapdriver

This commit is contained in:
Vishvananda Ishaya
2010-07-21 19:56:08 -05:00
parent 1fa4a70100
commit eaf648d7a0
10 changed files with 75 additions and 47 deletions

View File

@@ -0,0 +1,32 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Fake Auth driver for ldap
"""
from nova.auth import ldapdriver
class AuthDriver(ldapdriver.AuthDriver):
"""Ldap Auth driver
Defines enter and exit and therefore supports the with/as syntax.
"""
def __init__(self):
self.ldap = __import__('nova.auth.fakeldap', fromlist=True)

View File

@@ -57,19 +57,18 @@ flags.DEFINE_string('ldap_developer',
'cn=developers,ou=Groups,dc=example,dc=com', 'cn for Developers')
class LdapDriver(object):
class AuthDriver(object):
"""Ldap Auth driver
Defines enter and exit and therefore supports the with/as syntax.
"""
def __init__(self):
"""Imports the LDAP module"""
self.ldap = __import__('ldap')
def __enter__(self):
"""Creates the connection to LDAP"""
global ldap
if FLAGS.fake_users:
from nova.auth import fakeldap as ldap
else:
import ldap
self.conn = ldap.initialize(FLAGS.ldap_url)
self.conn = self.ldap.initialize(FLAGS.ldap_url)
self.conn.simple_bind_s(FLAGS.ldap_user_dn, FLAGS.ldap_password)
return self
@@ -275,8 +274,8 @@ class LdapDriver(object):
def __find_dns(self, dn, query=None):
"""Find dns by query"""
try:
res = self.conn.search_s(dn, ldap.SCOPE_SUBTREE, query)
except ldap.NO_SUCH_OBJECT:
res = self.conn.search_s(dn, self.ldap.SCOPE_SUBTREE, query)
except self.ldap.NO_SUCH_OBJECT:
return []
# just return the DNs
return [dn for dn, attributes in res]
@@ -284,8 +283,8 @@ class LdapDriver(object):
def __find_objects(self, dn, query = None):
"""Find objects by query"""
try:
res = self.conn.search_s(dn, ldap.SCOPE_SUBTREE, query)
except ldap.NO_SUCH_OBJECT:
res = self.conn.search_s(dn, self.ldap.SCOPE_SUBTREE, query)
except self.ldap.NO_SUCH_OBJECT:
return []
# just return the attributes
return [attributes for dn, attributes in res]
@@ -369,7 +368,7 @@ class LdapDriver(object):
raise exception.Duplicate("User %s is already a member of "
"the group %s" % (uid, group_dn))
attr = [
(ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))
(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))
]
self.conn.modify_s(group_dn, attr)
@@ -389,10 +388,10 @@ class LdapDriver(object):
def __safe_remove_from_group(self, uid, group_dn):
"""Remove user from group, deleting group if user is last member"""
# FIXME(vish): what if deleted user is a project manager?
attr = [(ldap.MOD_DELETE, 'member', self.__uid_to_dn(uid))]
attr = [(self.ldap.MOD_DELETE, 'member', self.__uid_to_dn(uid))]
try:
self.conn.modify_s(group_dn, attr)
except ldap.OBJECT_CLASS_VIOLATION:
except self.ldap.OBJECT_CLASS_VIOLATION:
logging.debug("Attempted to remove the last member of a group. "
"Deleting the group at %s instead." % group_dn )
self.__delete_group(group_dn)

View File

@@ -34,7 +34,6 @@ from nova import exception
from nova import flags
from nova import objectstore # for flags
from nova import utils
from nova.auth import ldapdriver
from nova.auth import signer
FLAGS = flags.FLAGS
@@ -76,6 +75,8 @@ flags.DEFINE_string('credential_cert_subject',
flags.DEFINE_string('vpn_ip', '127.0.0.1',
'Public IP for the cloudpipe VPN servers')
flags.DEFINE_string('auth_driver', 'fakeldapdriver',
'Driver that auth manager uses')
class AuthBase(object):
"""Base class for objects relating to auth
@@ -312,7 +313,7 @@ class AuthManager(object):
Methods accept objects or ids.
AuthManager uses a driver object to make requests to the data backend.
See ldapdriver.LdapDriver for reference.
See ldapdriver for reference.
AuthManager also manages associated data related to Auth objects that
need to be more accessible, such as vpn ips and ports.
@@ -325,7 +326,9 @@ class AuthManager(object):
return cls._instance
def __init__(self, *args, **kwargs):
self.driver_class = kwargs.get('driver_class', ldapdriver.LdapDriver)
"""Imports the driver module and saves the Driver class"""
mod = __import__(FLAGS.auth_driver, fromlist=True)
self.driver = mod.AuthDriver
def authenticate(self, access, signature, params, verb='GET',
server_string='127.0.0.1:8773', path='/',
@@ -451,7 +454,7 @@ class AuthManager(object):
@rtype: bool
@return: True if the user has the role.
"""
with self.driver_class() as drv:
with self.driver() as drv:
if role == 'projectmanager':
if not project:
raise exception.Error("Must specify project")
@@ -487,7 +490,7 @@ class AuthManager(object):
@type project: Project or project_id
@param project: Project in which to add local role.
"""
with self.driver_class() as drv:
with self.driver() as drv:
drv.add_role(User.safe_id(user), role, Project.safe_id(project))
def remove_role(self, user, role, project=None):
@@ -507,19 +510,19 @@ class AuthManager(object):
@type project: Project or project_id
@param project: Project in which to remove local role.
"""
with self.driver_class() as drv:
with self.driver() as drv:
drv.remove_role(User.safe_id(user), role, Project.safe_id(project))
def get_project(self, pid):
"""Get project object by id"""
with self.driver_class() as drv:
with self.driver() as drv:
project_dict = drv.get_project(pid)
if project_dict:
return Project(**project_dict)
def get_projects(self):
"""Retrieves list of all projects"""
with self.driver_class() as drv:
with self.driver() as drv:
project_list = drv.get_projects()
if not project_list:
return []
@@ -549,7 +552,7 @@ class AuthManager(object):
"""
if member_users:
member_users = [User.safe_id(u) for u in member_users]
with self.driver_class() as drv:
with self.driver() as drv:
project_dict = drv.create_project(name,
User.safe_id(manager_user),
description,
@@ -561,7 +564,7 @@ class AuthManager(object):
def add_to_project(self, user, project):
"""Add user to project"""
with self.driver_class() as drv:
with self.driver() as drv:
return drv.add_to_project(User.safe_id(user),
Project.safe_id(project))
@@ -579,7 +582,7 @@ class AuthManager(object):
def remove_from_project(self, user, project):
"""Removes a user from a project"""
with self.driver_class() as drv:
with self.driver() as drv:
return drv.remove_from_project(User.safe_id(user),
Project.safe_id(project))
@@ -600,26 +603,26 @@ class AuthManager(object):
def delete_project(self, project):
"""Deletes a project"""
with self.driver_class() as drv:
with self.driver() as drv:
return drv.delete_project(Project.safe_id(project))
def get_user(self, uid):
"""Retrieves a user by id"""
with self.driver_class() as drv:
with self.driver() as drv:
user_dict = drv.get_user(uid)
if user_dict:
return User(**user_dict)
def get_user_from_access_key(self, access_key):
"""Retrieves a user by access key"""
with self.driver_class() as drv:
with self.driver() as drv:
user_dict = drv.get_user_from_access_key(access_key)
if user_dict:
return User(**user_dict)
def get_users(self):
"""Retrieves a list of all users"""
with self.driver_class() as drv:
with self.driver() as drv:
user_list = drv.get_users()
if not user_list:
return []
@@ -649,14 +652,14 @@ class AuthManager(object):
"""
if access == None: access = str(uuid.uuid4())
if secret == None: secret = str(uuid.uuid4())
with self.driver_class() as drv:
with self.driver() as drv:
user_dict = drv.create_user(name, access, secret, admin)
if user_dict:
return User(**user_dict)
def delete_user(self, user):
"""Deletes a user"""
with self.driver_class() as drv:
with self.driver() as drv:
drv.delete_user(User.safe_id(user))
def generate_key_pair(self, user, key_name):
@@ -677,7 +680,7 @@ class AuthManager(object):
# NOTE(vish): generating key pair is slow so check for legal
# creation before creating keypair
uid = User.safe_id(user)
with self.driver_class() as drv:
with self.driver() as drv:
if not drv.get_user(uid):
raise exception.NotFound("User %s doesn't exist" % user)
if drv.get_key_pair(uid, key_name):
@@ -689,7 +692,7 @@ class AuthManager(object):
def create_key_pair(self, user, key_name, public_key, fingerprint):
"""Creates a key pair for user"""
with self.driver_class() as drv:
with self.driver() as drv:
kp_dict = drv.create_key_pair(User.safe_id(user),
key_name,
public_key,
@@ -699,14 +702,14 @@ class AuthManager(object):
def get_key_pair(self, user, key_name):
"""Retrieves a key pair for user"""
with self.driver_class() as drv:
with self.driver() as drv:
kp_dict = drv.get_key_pair(User.safe_id(user), key_name)
if kp_dict:
return KeyPair(**kp_dict)
def get_key_pairs(self, user):
"""Retrieves all key pairs for user"""
with self.driver_class() as drv:
with self.driver() as drv:
kp_list = drv.get_key_pairs(User.safe_id(user))
if not kp_list:
return []
@@ -714,7 +717,7 @@ class AuthManager(object):
def delete_key_pair(self, user, key_name):
"""Deletes a key pair for user"""
with self.driver_class() as drv:
with self.driver() as drv:
drv.delete_key_pair(User.safe_id(user), key_name)
def get_credentials(self, user, project=None):

View File

@@ -46,7 +46,6 @@ DEFINE_bool('fake_libvirt', False,
DEFINE_bool('verbose', False, 'show debug output')
DEFINE_boolean('fake_rabbit', False, 'use a fake rabbit')
DEFINE_bool('fake_network', False, 'should we use fake network devices and addresses')
DEFINE_bool('fake_users', False, 'use fake users')
DEFINE_string('rabbit_host', 'localhost', 'rabbit host')
DEFINE_integer('rabbit_port', 5672, 'rabbit port')
DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')

View File

@@ -40,8 +40,7 @@ class CloudTestCase(test.BaseTestCase):
def setUp(self):
super(CloudTestCase, self).setUp()
self.flags(fake_libvirt=True,
fake_storage=True,
fake_users=True)
fake_storage=True)
self.conn = rpc.Connection.instance()
logging.getLogger().setLevel(logging.DEBUG)

View File

@@ -24,5 +24,5 @@ FLAGS.fake_libvirt = True
FLAGS.fake_storage = True
FLAGS.fake_rabbit = True
FLAGS.fake_network = True
FLAGS.fake_users = True
FLAGS.auth_driver = 'nova.auth.fakeldapdriver'
FLAGS.verbose = True

View File

@@ -35,8 +35,7 @@ class ModelTestCase(test.TrialTestCase):
def setUp(self):
super(ModelTestCase, self).setUp()
self.flags(fake_libvirt=True,
fake_storage=True,
fake_users=True)
fake_storage=True)
def tearDown(self):
model.Instance('i-test').destroy()

View File

@@ -58,8 +58,7 @@ class NodeConnectionTestCase(test.TrialTestCase):
logging.getLogger().setLevel(logging.DEBUG)
super(NodeConnectionTestCase, self).setUp()
self.flags(fake_libvirt=True,
fake_storage=True,
fake_users=True)
fake_storage=True)
self.node = node.Node()
def create_instance(self):

View File

@@ -51,8 +51,7 @@ os.makedirs(os.path.join(oss_tempdir, 'buckets'))
class ObjectStoreTestCase(test.BaseTestCase):
def setUp(self):
super(ObjectStoreTestCase, self).setUp()
self.flags(fake_users=True,
buckets_path=os.path.join(oss_tempdir, 'buckets'),
self.flags(buckets_path=os.path.join(oss_tempdir, 'buckets'),
images_path=os.path.join(oss_tempdir, 'images'),
ca_path=os.path.join(os.path.dirname(__file__), 'CA'))
logging.getLogger().setLevel(logging.DEBUG)

View File

@@ -24,5 +24,4 @@ FLAGS.fake_libvirt = False
FLAGS.fake_storage = False
FLAGS.fake_rabbit = False
FLAGS.fake_network = False
FLAGS.fake_users = False
FLAGS.verbose = False