move keypair generation out of auth and fix tests

This commit is contained in:
Vishvananda Ishaya
2010-09-10 22:13:36 -07:00
parent c88569728e
commit ad0de7c9b2
5 changed files with 83 additions and 126 deletions

View File

@@ -128,24 +128,6 @@ class User(AuthBase):
def is_project_manager(self, project):
return AuthManager().is_project_manager(self, project)
def generate_key_pair(self, name):
return AuthManager().generate_key_pair(self.id, name)
def create_key_pair(self, name, public_key, fingerprint):
return AuthManager().create_key_pair(self.id,
name,
public_key,
fingerprint)
def get_key_pair(self, name):
return AuthManager().get_key_pair(self.id, name)
def delete_key_pair(self, name):
return AuthManager().delete_key_pair(self.id, name)
def get_key_pairs(self):
return AuthManager().get_key_pairs(self.id)
def __repr__(self):
return "User('%s', '%s', '%s', '%s', %s)" % (self.id,
self.name,
@@ -628,58 +610,6 @@ class AuthManager(object):
with self.driver() as drv:
drv.delete_user(uid)
def generate_key_pair(self, user, key_name):
"""Generates a key pair for a user
Generates a public and private key, stores the public key using the
key_name, and returns the private key and fingerprint.
@type user: User or uid
@param user: User for which to create key pair.
@type key_name: str
@param key_name: Name to use for the generated KeyPair.
@rtype: tuple (private_key, fingerprint)
@return: A tuple containing the private_key and fingerprint.
"""
# NOTE(vish): generating key pair is slow so check for legal
# creation before creating keypair
uid = User.safe_id(user)
with self.driver() as drv:
if not drv.get_user(uid):
raise exception.NotFound("User %s doesn't exist" % user)
try:
db.keypair_get(None, uid, key_name)
raise exception.Duplicate("The keypair %s already exists"
% key_name)
except exception.NotFound:
pass
private_key, public_key, fingerprint = crypto.generate_key_pair()
self.create_key_pair(uid, key_name, public_key, fingerprint)
return private_key, fingerprint
def create_key_pair(self, user, key_name, public_key, fingerprint):
"""Creates a key pair for user"""
key = {}
key['user_id'] = User.safe_id(user)
key['name'] = key_name
key['public_key'] = public_key
key['fingerprint'] = fingerprint
return db.keypair_create(None, key)
def get_key_pair(self, user, key_name):
"""Retrieves a key pair for user"""
return db.keypair_get(None, User.safe_id(user), key_name)
def get_key_pairs(self, user):
"""Retrieves all key pairs for user"""
return db.keypair_get_all_by_user(None, User.safe_id(user))
def delete_key_pair(self, user, key_name):
"""Deletes a key pair for user"""
return db.keypair_destroy(None, User.safe_id(user), key_name)
def get_credentials(self, user, project=None):
"""Get credential zip for user in project"""
if not isinstance(user, User):

View File

@@ -29,13 +29,13 @@ import time
from twisted.internet import defer
from nova import crypto
from nova import db
from nova import exception
from nova import flags
from nova import rpc
from nova import utils
from nova.auth import rbac
from nova.auth import manager
from nova.compute.instance_types import INSTANCE_TYPES
from nova.endpoint import images
@@ -44,14 +44,30 @@ FLAGS = flags.FLAGS
flags.DECLARE('storage_availability_zone', 'nova.volume.manager')
def _gen_key(user_id, key_name):
""" Tuck this into AuthManager """
def _gen_key(context, user_id, key_name):
"""Generate a key
This is a module level method because it is slow and we need to defer
it into a process pool."""
try:
mgr = manager.AuthManager()
private_key, fingerprint = mgr.generate_key_pair(user_id, key_name)
# NOTE(vish): generating key pair is slow so check for legal
# creation before creating keypair
try:
db.keypair_get(context, user_id, key_name)
raise exception.Duplicate("The keypair %s already exists"
% key_name)
except exception.NotFound:
pass
private_key, public_key, fingerprint = crypto.generate_key_pair()
key = {}
key['user_id'] = user_id
key['name'] = key_name
key['public_key'] = public_key
key['fingerprint'] = fingerprint
db.keypair_create(context, key)
return {'private_key': private_key, 'fingerprint': fingerprint}
except Exception as ex:
return {'exception': ex}
return {'private_key': private_key, 'fingerprint': fingerprint}
class CloudController(object):
@@ -177,18 +193,18 @@ class CloudController(object):
@rbac.allow('all')
def describe_key_pairs(self, context, key_name=None, **kwargs):
key_pairs = context.user.get_key_pairs()
key_pairs = db.keypair_get_all_by_user(context, context.user.id)
if not key_name is None:
key_pairs = [x for x in key_pairs if x.name in key_name]
key_pairs = [x for x in key_pairs if x['name'] in key_name]
result = []
for key_pair in key_pairs:
# filter out the vpn keys
suffix = FLAGS.vpn_key_suffix
if context.user.is_admin() or not key_pair.name.endswith(suffix):
if context.user.is_admin() or not key_pair['name'].endswith(suffix):
result.append({
'keyName': key_pair.name,
'keyFingerprint': key_pair.fingerprint,
'keyName': key_pair['name'],
'keyFingerprint': key_pair['fingerprint'],
})
return {'keypairsSet': result}
@@ -204,14 +220,18 @@ class CloudController(object):
dcall.callback({'keyName': key_name,
'keyFingerprint': kwargs['fingerprint'],
'keyMaterial': kwargs['private_key']})
pool.apply_async(_gen_key, [context.user.id, key_name],
# TODO(vish): when context is no longer an object, pass it here
pool.apply_async(_gen_key, [None, context.user.id, key_name],
callback=_complete)
return dcall
@rbac.allow('all')
def delete_key_pair(self, context, key_name, **kwargs):
context.user.delete_key_pair(key_name)
# aws returns true even if the key doens't exist
try:
db.keypair_destroy(context, context.user.id, key_name)
except exception.NotFound:
# aws returns true even if the key doesn't exist
pass
return True
@rbac.allow('all')

View File

@@ -224,7 +224,8 @@ class ApiEc2TestCase(test.BaseTestCase):
for x in range(random.randint(4, 8)))
user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake')
self.manager.generate_key_pair(user.id, keyname)
# NOTE(vish): create depends on pool, so call helper directly
cloud._gen_key(None, user.id, keyname)
rv = self.ec2.get_all_key_pairs()
results = [k for k in rv if k.name == keyname]

View File

@@ -17,8 +17,6 @@
# under the License.
import logging
from M2Crypto import BIO
from M2Crypto import RSA
from M2Crypto import X509
import unittest
@@ -65,35 +63,6 @@ class AuthTestCase(test.BaseTestCase):
'export S3_URL="http://127.0.0.1:3333/"\n' +
'export EC2_USER_ID="test1"\n')
def test_006_test_key_storage(self):
user = self.manager.get_user('test1')
user.create_key_pair('public', 'key', 'fingerprint')
key = user.get_key_pair('public')
self.assertEqual('key', key.public_key)
self.assertEqual('fingerprint', key.fingerprint)
def test_007_test_key_generation(self):
user = self.manager.get_user('test1')
private_key, fingerprint = user.generate_key_pair('public2')
key = RSA.load_key_string(private_key, callback=lambda: None)
bio = BIO.MemoryBuffer()
public_key = user.get_key_pair('public2').public_key
key.save_pub_key_bio(bio)
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
# assert key fields are equal
self.assertEqual(public_key.split(" ")[1].strip(),
converted.split(" ")[1].strip())
def test_008_can_list_key_pairs(self):
keys = self.manager.get_user('test1').get_key_pairs()
self.assertTrue(filter(lambda k: k.name == 'public', keys))
self.assertTrue(filter(lambda k: k.name == 'public2', keys))
def test_009_can_delete_key_pair(self):
self.manager.get_user('test1').delete_key_pair('public')
keys = self.manager.get_user('test1').get_key_pairs()
self.assertFalse(filter(lambda k: k.name == 'public', keys))
def test_010_can_list_users(self):
users = self.manager.get_users()
logging.warn(users)

View File

@@ -17,13 +17,18 @@
# under the License.
import logging
from M2Crypto import BIO
from M2Crypto import RSA
import StringIO
import time
from tornado import ioloop
from twisted.internet import defer
import unittest
from xml.etree import ElementTree
from nova import crypto
from nova import db
from nova import flags
from nova import rpc
from nova import test
@@ -54,16 +59,21 @@ class CloudTestCase(test.BaseTestCase):
proxy=self.compute)
self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop))
try:
manager.AuthManager().create_user('admin', 'admin', 'admin')
except: pass
admin = manager.AuthManager().get_user('admin')
project = manager.AuthManager().create_project('proj', 'admin', 'proj')
self.context = api.APIRequestContext(handler=None,project=project,user=admin)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('admin', 'admin', 'admin', True)
self.project = self.manager.create_project('proj', 'admin', 'proj')
self.context = api.APIRequestContext(handler=None,
user=self.user,
project=self.project)
def tearDown(self):
manager.AuthManager().delete_project('proj')
manager.AuthManager().delete_user('admin')
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
super(CloudTestCase, self).setUp()
def _create_key(self, name):
# NOTE(vish): create depends on pool, so just call helper directly
return cloud._gen_key(self.context, self.context.user.id, name)
def test_console_output(self):
if FLAGS.connection_type == 'fake':
@@ -76,6 +86,33 @@ class CloudTestCase(test.BaseTestCase):
self.assert_(output)
rv = yield self.compute.terminate_instance(instance_id)
def test_key_generation(self):
result = self._create_key('test')
private_key = result['private_key']
key = RSA.load_key_string(private_key, callback=lambda: None)
bio = BIO.MemoryBuffer()
public_key = db.keypair_get(self.context,
self.context.user.id,
'test')['public_key']
key.save_pub_key_bio(bio)
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
# assert key fields are equal
self.assertEqual(public_key.split(" ")[1].strip(),
converted.split(" ")[1].strip())
def test_describe_key_pairs(self):
self._create_key('test1')
self._create_key('test2')
result = self.cloud.describe_key_pairs(self.context)
keys = result["keypairsSet"]
self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
def test_delete_key_pair(self):
self._create_key('test')
self.cloud.delete_key_pair(self.context, 'test')
def test_run_instances(self):
if FLAGS.connection_type == 'fake':
logging.debug("Can't test instances without a real virtual env.")