Moves keypairs out of ldap and into the common datastore.

This commit is contained in:
Vishvananda Ishaya
2010-09-21 05:08:31 +00:00
committed by Tarmac
6 changed files with 91 additions and 220 deletions

View File

@@ -99,13 +99,6 @@ class LdapDriver(object):
dn = FLAGS.ldap_user_subtree dn = FLAGS.ldap_user_subtree
return self.__to_user(self.__find_object(dn, query)) return self.__to_user(self.__find_object(dn, query))
def get_key_pair(self, uid, key_name):
"""Retrieve key pair by uid and key name"""
dn = 'cn=%s,%s' % (key_name,
self.__uid_to_dn(uid))
attr = self.__find_object(dn, '(objectclass=novaKeyPair)')
return self.__to_key_pair(uid, attr)
def get_project(self, pid): def get_project(self, pid):
"""Retrieve project by id""" """Retrieve project by id"""
dn = 'cn=%s,%s' % (pid, dn = 'cn=%s,%s' % (pid,
@@ -119,12 +112,6 @@ class LdapDriver(object):
'(objectclass=novaUser)') '(objectclass=novaUser)')
return [self.__to_user(attr) for attr in attrs] return [self.__to_user(attr) for attr in attrs]
def get_key_pairs(self, uid):
"""Retrieve list of key pairs"""
attrs = self.__find_objects(self.__uid_to_dn(uid),
'(objectclass=novaKeyPair)')
return [self.__to_key_pair(uid, attr) for attr in attrs]
def get_projects(self, uid=None): def get_projects(self, uid=None):
"""Retrieve list of projects""" """Retrieve list of projects"""
pattern = '(objectclass=novaProject)' pattern = '(objectclass=novaProject)'
@@ -154,21 +141,6 @@ class LdapDriver(object):
self.conn.add_s(self.__uid_to_dn(name), attr) self.conn.add_s(self.__uid_to_dn(name), attr)
return self.__to_user(dict(attr)) return self.__to_user(dict(attr))
def create_key_pair(self, uid, key_name, public_key, fingerprint):
"""Create a key pair"""
# TODO(vish): possibly refactor this to store keys in their own ou
# and put dn reference in the user object
attr = [
('objectclass', ['novaKeyPair']),
('cn', [key_name]),
('sshPublicKey', [public_key]),
('keyFingerprint', [fingerprint]),
]
self.conn.add_s('cn=%s,%s' % (key_name,
self.__uid_to_dn(uid)),
attr)
return self.__to_key_pair(uid, dict(attr))
def create_project(self, name, manager_uid, def create_project(self, name, manager_uid,
description=None, member_uids=None): description=None, member_uids=None):
"""Create a project""" """Create a project"""
@@ -283,19 +255,10 @@ class LdapDriver(object):
"""Delete a user""" """Delete a user"""
if not self.__user_exists(uid): if not self.__user_exists(uid):
raise exception.NotFound("User %s doesn't exist" % uid) raise exception.NotFound("User %s doesn't exist" % uid)
self.__delete_key_pairs(uid)
self.__remove_from_all(uid) self.__remove_from_all(uid)
self.conn.delete_s('uid=%s,%s' % (uid, self.conn.delete_s('uid=%s,%s' % (uid,
FLAGS.ldap_user_subtree)) FLAGS.ldap_user_subtree))
def delete_key_pair(self, uid, key_name):
"""Delete a key pair"""
if not self.__key_pair_exists(uid, key_name):
raise exception.NotFound("Key Pair %s doesn't exist for user %s" %
(key_name, uid))
self.conn.delete_s('cn=%s,uid=%s,%s' % (key_name, uid,
FLAGS.ldap_user_subtree))
def delete_project(self, project_id): def delete_project(self, project_id):
"""Delete a project""" """Delete a project"""
project_dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree) project_dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree)
@@ -306,10 +269,6 @@ class LdapDriver(object):
"""Check if user exists""" """Check if user exists"""
return self.get_user(uid) != None return self.get_user(uid) != None
def __key_pair_exists(self, uid, key_name):
"""Check if key pair exists"""
return self.get_key_pair(uid, key_name) != None
def __project_exists(self, project_id): def __project_exists(self, project_id):
"""Check if project exists""" """Check if project exists"""
return self.get_project(project_id) != None return self.get_project(project_id) != None
@@ -359,13 +318,6 @@ class LdapDriver(object):
"""Check if group exists""" """Check if group exists"""
return self.__find_object(dn, '(objectclass=groupOfNames)') != None return self.__find_object(dn, '(objectclass=groupOfNames)') != None
def __delete_key_pairs(self, uid):
"""Delete all key pairs for user"""
keys = self.get_key_pairs(uid)
if keys != None:
for key in keys:
self.delete_key_pair(uid, key['name'])
@staticmethod @staticmethod
def __role_to_dn(role, project_id=None): def __role_to_dn(role, project_id=None):
"""Convert role to corresponding dn""" """Convert role to corresponding dn"""
@@ -490,18 +442,6 @@ class LdapDriver(object):
'secret': attr['secretKey'][0], 'secret': attr['secretKey'][0],
'admin': (attr['isAdmin'][0] == 'TRUE')} 'admin': (attr['isAdmin'][0] == 'TRUE')}
@staticmethod
def __to_key_pair(owner, attr):
"""Convert ldap attributes to KeyPair object"""
if attr == None:
return None
return {
'id': attr['cn'][0],
'name': attr['cn'][0],
'owner_id': owner,
'public_key': attr['sshPublicKey'][0],
'fingerprint': attr['keyFingerprint'][0]}
def __to_project(self, attr): def __to_project(self, attr):
"""Convert ldap attributes to Project object""" """Convert ldap attributes to Project object"""
if attr == None: if attr == None:

View File

@@ -128,24 +128,6 @@ class User(AuthBase):
def is_project_manager(self, project): def is_project_manager(self, project):
return AuthManager().is_project_manager(self, project) return AuthManager().is_project_manager(self, project)
def generate_key_pair(self, name):
return AuthManager().generate_key_pair(self.id, name)
def create_key_pair(self, name, public_key, fingerprint):
return AuthManager().create_key_pair(self.id,
name,
public_key,
fingerprint)
def get_key_pair(self, name):
return AuthManager().get_key_pair(self.id, name)
def delete_key_pair(self, name):
return AuthManager().delete_key_pair(self.id, name)
def get_key_pairs(self):
return AuthManager().get_key_pairs(self.id)
def __repr__(self): def __repr__(self):
return "User('%s', '%s', '%s', '%s', %s)" % (self.id, return "User('%s', '%s', '%s', '%s', %s)" % (self.id,
self.name, self.name,
@@ -154,29 +136,6 @@ class User(AuthBase):
self.admin) self.admin)
class KeyPair(AuthBase):
"""Represents an ssh key returned from the datastore
Even though this object is named KeyPair, only the public key and
fingerprint is stored. The user's private key is not saved.
"""
def __init__(self, id, name, owner_id, public_key, fingerprint):
AuthBase.__init__(self)
self.id = id
self.name = name
self.owner_id = owner_id
self.public_key = public_key
self.fingerprint = fingerprint
def __repr__(self):
return "KeyPair('%s', '%s', '%s', '%s', '%s')" % (self.id,
self.name,
self.owner_id,
self.public_key,
self.fingerprint)
class Project(AuthBase): class Project(AuthBase):
"""Represents a Project returned from the datastore""" """Represents a Project returned from the datastore"""
@@ -663,67 +622,13 @@ class AuthManager(object):
return User(**user_dict) return User(**user_dict)
def delete_user(self, user): def delete_user(self, user):
"""Deletes a user""" """Deletes a user
with self.driver() as drv:
drv.delete_user(User.safe_id(user))
def generate_key_pair(self, user, key_name): Additionally deletes all users key_pairs"""
"""Generates a key pair for a user
Generates a public and private key, stores the public key using the
key_name, and returns the private key and fingerprint.
@type user: User or uid
@param user: User for which to create key pair.
@type key_name: str
@param key_name: Name to use for the generated KeyPair.
@rtype: tuple (private_key, fingerprint)
@return: A tuple containing the private_key and fingerprint.
"""
# NOTE(vish): generating key pair is slow so check for legal
# creation before creating keypair
uid = User.safe_id(user) uid = User.safe_id(user)
db.key_pair_destroy_all_by_user(None, uid)
with self.driver() as drv: with self.driver() as drv:
if not drv.get_user(uid): drv.delete_user(uid)
raise exception.NotFound("User %s doesn't exist" % user)
if drv.get_key_pair(uid, key_name):
raise exception.Duplicate("The keypair %s already exists"
% key_name)
private_key, public_key, fingerprint = crypto.generate_key_pair()
self.create_key_pair(uid, key_name, public_key, fingerprint)
return private_key, fingerprint
def create_key_pair(self, user, key_name, public_key, fingerprint):
"""Creates a key pair for user"""
with self.driver() as drv:
kp_dict = drv.create_key_pair(User.safe_id(user),
key_name,
public_key,
fingerprint)
if kp_dict:
return KeyPair(**kp_dict)
def get_key_pair(self, user, key_name):
"""Retrieves a key pair for user"""
with self.driver() as drv:
kp_dict = drv.get_key_pair(User.safe_id(user), key_name)
if kp_dict:
return KeyPair(**kp_dict)
def get_key_pairs(self, user):
"""Retrieves all key pairs for user"""
with self.driver() as drv:
kp_list = drv.get_key_pairs(User.safe_id(user))
if not kp_list:
return []
return [KeyPair(**kp_dict) for kp_dict in kp_list]
def delete_key_pair(self, user, key_name):
"""Deletes a key pair for user"""
with self.driver() as drv:
drv.delete_key_pair(User.safe_id(user), key_name)
def get_credentials(self, user, project=None): def get_credentials(self, user, project=None):
"""Get credential zip for user in project""" """Get credential zip for user in project"""

View File

@@ -30,6 +30,7 @@ import time
from twisted.internet import defer from twisted.internet import defer
from nova import crypto
from nova import db from nova import db
from nova import exception from nova import exception
from nova import flags from nova import flags
@@ -37,7 +38,6 @@ from nova import quota
from nova import rpc from nova import rpc
from nova import utils from nova import utils
from nova.auth import rbac from nova.auth import rbac
from nova.auth import manager
from nova.compute.instance_types import INSTANCE_TYPES from nova.compute.instance_types import INSTANCE_TYPES
from nova.endpoint import images from nova.endpoint import images
@@ -51,14 +51,30 @@ class QuotaError(exception.ApiError):
pass pass
def _gen_key(user_id, key_name): def _gen_key(context, user_id, key_name):
""" Tuck this into AuthManager """ """Generate a key
This is a module level method because it is slow and we need to defer
it into a process pool."""
try: try:
mgr = manager.AuthManager() # NOTE(vish): generating key pair is slow so check for legal
private_key, fingerprint = mgr.generate_key_pair(user_id, key_name) # creation before creating key_pair
try:
db.key_pair_get(context, user_id, key_name)
raise exception.Duplicate("The key_pair %s already exists"
% key_name)
except exception.NotFound:
pass
private_key, public_key, fingerprint = crypto.generate_key_pair()
key = {}
key['user_id'] = user_id
key['name'] = key_name
key['public_key'] = public_key
key['fingerprint'] = fingerprint
db.key_pair_create(context, key)
return {'private_key': private_key, 'fingerprint': fingerprint}
except Exception as ex: except Exception as ex:
return {'exception': ex} return {'exception': ex}
return {'private_key': private_key, 'fingerprint': fingerprint}
class CloudController(object): class CloudController(object):
@@ -194,18 +210,18 @@ class CloudController(object):
@rbac.allow('all') @rbac.allow('all')
def describe_key_pairs(self, context, key_name=None, **kwargs): def describe_key_pairs(self, context, key_name=None, **kwargs):
key_pairs = context.user.get_key_pairs() key_pairs = db.key_pair_get_all_by_user(context, context.user.id)
if not key_name is None: if not key_name is None:
key_pairs = [x for x in key_pairs if x.name in key_name] key_pairs = [x for x in key_pairs if x['name'] in key_name]
result = [] result = []
for key_pair in key_pairs: for key_pair in key_pairs:
# filter out the vpn keys # filter out the vpn keys
suffix = FLAGS.vpn_key_suffix suffix = FLAGS.vpn_key_suffix
if context.user.is_admin() or not key_pair.name.endswith(suffix): if context.user.is_admin() or not key_pair['name'].endswith(suffix):
result.append({ result.append({
'keyName': key_pair.name, 'keyName': key_pair['name'],
'keyFingerprint': key_pair.fingerprint, 'keyFingerprint': key_pair['fingerprint'],
}) })
return {'keypairsSet': result} return {'keypairsSet': result}
@@ -221,14 +237,18 @@ class CloudController(object):
dcall.callback({'keyName': key_name, dcall.callback({'keyName': key_name,
'keyFingerprint': kwargs['fingerprint'], 'keyFingerprint': kwargs['fingerprint'],
'keyMaterial': kwargs['private_key']}) 'keyMaterial': kwargs['private_key']})
pool.apply_async(_gen_key, [context.user.id, key_name], # TODO(vish): when context is no longer an object, pass it here
pool.apply_async(_gen_key, [None, context.user.id, key_name],
callback=_complete) callback=_complete)
return dcall return dcall
@rbac.allow('all') @rbac.allow('all')
def delete_key_pair(self, context, key_name, **kwargs): def delete_key_pair(self, context, key_name, **kwargs):
context.user.delete_key_pair(key_name) try:
# aws returns true even if the key doens't exist db.key_pair_destroy(context, context.user.id, key_name)
except exception.NotFound:
# aws returns true even if the key doesn't exist
pass
return True return True
@rbac.allow('all') @rbac.allow('all')
@@ -576,11 +596,10 @@ class CloudController(object):
launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()) launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
key_data = None key_data = None
if kwargs.has_key('key_name'): if kwargs.has_key('key_name'):
key_pair = context.user.get_key_pair(kwargs['key_name']) key_pair_ref = db.key_pair_get(context,
if not key_pair: context.user.id,
raise exception.ApiError('Key Pair %s not found' % kwargs['key_name'])
kwargs['key_name']) key_data = key_pair_ref['public_key']
key_data = key_pair.public_key
# TODO: Get the real security group of launch in here # TODO: Get the real security group of launch in here
security_group = "default" security_group = "default"

View File

@@ -41,8 +41,8 @@ FLAGS = flags.FLAGS
# it's pretty damn circuitous so apologies if you have to fix # it's pretty damn circuitous so apologies if you have to fix
# a bug in it # a bug in it
# NOTE(jaypipes) The pylint disables here are for R0913 (too many args) which # NOTE(jaypipes) The pylint disables here are for R0913 (too many args) which
# isn't controllable since boto's HTTPRequest needs that many # isn't controllable since boto's HTTPRequest needs that many
# args, and for the version-differentiated import of tornado's # args, and for the version-differentiated import of tornado's
# httputil. # httputil.
# NOTE(jaypipes): The disable-msg=E1101 and E1103 below is because pylint is # NOTE(jaypipes): The disable-msg=E1101 and E1103 below is because pylint is
# unable to introspect the deferred's return value properly # unable to introspect the deferred's return value properly
@@ -224,7 +224,8 @@ class ApiEc2TestCase(test.BaseTestCase):
for x in range(random.randint(4, 8))) for x in range(random.randint(4, 8)))
user = self.manager.create_user('fake', 'fake', 'fake') user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake') project = self.manager.create_project('fake', 'fake', 'fake')
self.manager.generate_key_pair(user.id, keyname) # NOTE(vish): create depends on pool, so call helper directly
cloud._gen_key(None, user.id, keyname)
rv = self.ec2.get_all_key_pairs() rv = self.ec2.get_all_key_pairs()
results = [k for k in rv if k.name == keyname] results = [k for k in rv if k.name == keyname]

View File

@@ -17,8 +17,6 @@
# under the License. # under the License.
import logging import logging
from M2Crypto import BIO
from M2Crypto import RSA
from M2Crypto import X509 from M2Crypto import X509
import unittest import unittest
@@ -65,35 +63,6 @@ class AuthTestCase(test.BaseTestCase):
'export S3_URL="http://127.0.0.1:3333/"\n' + 'export S3_URL="http://127.0.0.1:3333/"\n' +
'export EC2_USER_ID="test1"\n') 'export EC2_USER_ID="test1"\n')
def test_006_test_key_storage(self):
user = self.manager.get_user('test1')
user.create_key_pair('public', 'key', 'fingerprint')
key = user.get_key_pair('public')
self.assertEqual('key', key.public_key)
self.assertEqual('fingerprint', key.fingerprint)
def test_007_test_key_generation(self):
user = self.manager.get_user('test1')
private_key, fingerprint = user.generate_key_pair('public2')
key = RSA.load_key_string(private_key, callback=lambda: None)
bio = BIO.MemoryBuffer()
public_key = user.get_key_pair('public2').public_key
key.save_pub_key_bio(bio)
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
# assert key fields are equal
self.assertEqual(public_key.split(" ")[1].strip(),
converted.split(" ")[1].strip())
def test_008_can_list_key_pairs(self):
keys = self.manager.get_user('test1').get_key_pairs()
self.assertTrue(filter(lambda k: k.name == 'public', keys))
self.assertTrue(filter(lambda k: k.name == 'public2', keys))
def test_009_can_delete_key_pair(self):
self.manager.get_user('test1').delete_key_pair('public')
keys = self.manager.get_user('test1').get_key_pairs()
self.assertFalse(filter(lambda k: k.name == 'public', keys))
def test_010_can_list_users(self): def test_010_can_list_users(self):
users = self.manager.get_users() users = self.manager.get_users()
logging.warn(users) logging.warn(users)

View File

@@ -17,13 +17,18 @@
# under the License. # under the License.
import logging import logging
from M2Crypto import BIO
from M2Crypto import RSA
import StringIO import StringIO
import time import time
from tornado import ioloop from tornado import ioloop
from twisted.internet import defer from twisted.internet import defer
import unittest import unittest
from xml.etree import ElementTree from xml.etree import ElementTree
from nova import crypto
from nova import db
from nova import flags from nova import flags
from nova import rpc from nova import rpc
from nova import test from nova import test
@@ -55,16 +60,21 @@ class CloudTestCase(test.BaseTestCase):
proxy=self.compute) proxy=self.compute)
self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop)) self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop))
try: self.manager = manager.AuthManager()
manager.AuthManager().create_user('admin', 'admin', 'admin') self.user = self.manager.create_user('admin', 'admin', 'admin', True)
except: pass self.project = self.manager.create_project('proj', 'admin', 'proj')
admin = manager.AuthManager().get_user('admin') self.context = api.APIRequestContext(handler=None,
project = manager.AuthManager().create_project('proj', 'admin', 'proj') user=self.user,
self.context = api.APIRequestContext(handler=None,project=project,user=admin) project=self.project)
def tearDown(self): def tearDown(self):
manager.AuthManager().delete_project('proj') self.manager.delete_project(self.project)
manager.AuthManager().delete_user('admin') self.manager.delete_user(self.user)
super(CloudTestCase, self).setUp()
def _create_key(self, name):
# NOTE(vish): create depends on pool, so just call helper directly
return cloud._gen_key(self.context, self.context.user.id, name)
def test_console_output(self): def test_console_output(self):
if FLAGS.connection_type == 'fake': if FLAGS.connection_type == 'fake':
@@ -77,6 +87,33 @@ class CloudTestCase(test.BaseTestCase):
self.assert_(output) self.assert_(output)
rv = yield self.compute.terminate_instance(instance_id) rv = yield self.compute.terminate_instance(instance_id)
def test_key_generation(self):
result = self._create_key('test')
private_key = result['private_key']
key = RSA.load_key_string(private_key, callback=lambda: None)
bio = BIO.MemoryBuffer()
public_key = db.key_pair_get(self.context,
self.context.user.id,
'test')['public_key']
key.save_pub_key_bio(bio)
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
# assert key fields are equal
self.assertEqual(public_key.split(" ")[1].strip(),
converted.split(" ")[1].strip())
def test_describe_key_pairs(self):
self._create_key('test1')
self._create_key('test2')
result = self.cloud.describe_key_pairs(self.context)
keys = result["keypairsSet"]
self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
def test_delete_key_pair(self):
self._create_key('test')
self.cloud.delete_key_pair(self.context, 'test')
def test_run_instances(self): def test_run_instances(self):
if FLAGS.connection_type == 'fake': if FLAGS.connection_type == 'fake':
logging.debug("Can't test instances without a real virtual env.") logging.debug("Can't test instances without a real virtual env.")