move keypair generation out of auth and fix tests
This commit is contained in:
@@ -128,24 +128,6 @@ class User(AuthBase):
|
||||
def is_project_manager(self, project):
|
||||
return AuthManager().is_project_manager(self, project)
|
||||
|
||||
def generate_key_pair(self, name):
|
||||
return AuthManager().generate_key_pair(self.id, name)
|
||||
|
||||
def create_key_pair(self, name, public_key, fingerprint):
|
||||
return AuthManager().create_key_pair(self.id,
|
||||
name,
|
||||
public_key,
|
||||
fingerprint)
|
||||
|
||||
def get_key_pair(self, name):
|
||||
return AuthManager().get_key_pair(self.id, name)
|
||||
|
||||
def delete_key_pair(self, name):
|
||||
return AuthManager().delete_key_pair(self.id, name)
|
||||
|
||||
def get_key_pairs(self):
|
||||
return AuthManager().get_key_pairs(self.id)
|
||||
|
||||
def __repr__(self):
|
||||
return "User('%s', '%s', '%s', '%s', %s)" % (self.id,
|
||||
self.name,
|
||||
@@ -628,58 +610,6 @@ class AuthManager(object):
|
||||
with self.driver() as drv:
|
||||
drv.delete_user(uid)
|
||||
|
||||
def generate_key_pair(self, user, key_name):
|
||||
"""Generates a key pair for a user
|
||||
|
||||
Generates a public and private key, stores the public key using the
|
||||
key_name, and returns the private key and fingerprint.
|
||||
|
||||
@type user: User or uid
|
||||
@param user: User for which to create key pair.
|
||||
|
||||
@type key_name: str
|
||||
@param key_name: Name to use for the generated KeyPair.
|
||||
|
||||
@rtype: tuple (private_key, fingerprint)
|
||||
@return: A tuple containing the private_key and fingerprint.
|
||||
"""
|
||||
# NOTE(vish): generating key pair is slow so check for legal
|
||||
# creation before creating keypair
|
||||
uid = User.safe_id(user)
|
||||
with self.driver() as drv:
|
||||
if not drv.get_user(uid):
|
||||
raise exception.NotFound("User %s doesn't exist" % user)
|
||||
try:
|
||||
db.keypair_get(None, uid, key_name)
|
||||
raise exception.Duplicate("The keypair %s already exists"
|
||||
% key_name)
|
||||
except exception.NotFound:
|
||||
pass
|
||||
private_key, public_key, fingerprint = crypto.generate_key_pair()
|
||||
self.create_key_pair(uid, key_name, public_key, fingerprint)
|
||||
return private_key, fingerprint
|
||||
|
||||
def create_key_pair(self, user, key_name, public_key, fingerprint):
|
||||
"""Creates a key pair for user"""
|
||||
key = {}
|
||||
key['user_id'] = User.safe_id(user)
|
||||
key['name'] = key_name
|
||||
key['public_key'] = public_key
|
||||
key['fingerprint'] = fingerprint
|
||||
return db.keypair_create(None, key)
|
||||
|
||||
def get_key_pair(self, user, key_name):
|
||||
"""Retrieves a key pair for user"""
|
||||
return db.keypair_get(None, User.safe_id(user), key_name)
|
||||
|
||||
def get_key_pairs(self, user):
|
||||
"""Retrieves all key pairs for user"""
|
||||
return db.keypair_get_all_by_user(None, User.safe_id(user))
|
||||
|
||||
def delete_key_pair(self, user, key_name):
|
||||
"""Deletes a key pair for user"""
|
||||
return db.keypair_destroy(None, User.safe_id(user), key_name)
|
||||
|
||||
def get_credentials(self, user, project=None):
|
||||
"""Get credential zip for user in project"""
|
||||
if not isinstance(user, User):
|
||||
|
||||
@@ -29,13 +29,13 @@ import time
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from nova import crypto
|
||||
from nova import db
|
||||
from nova import exception
|
||||
from nova import flags
|
||||
from nova import rpc
|
||||
from nova import utils
|
||||
from nova.auth import rbac
|
||||
from nova.auth import manager
|
||||
from nova.compute.instance_types import INSTANCE_TYPES
|
||||
from nova.endpoint import images
|
||||
|
||||
@@ -44,14 +44,30 @@ FLAGS = flags.FLAGS
|
||||
flags.DECLARE('storage_availability_zone', 'nova.volume.manager')
|
||||
|
||||
|
||||
def _gen_key(user_id, key_name):
|
||||
""" Tuck this into AuthManager """
|
||||
def _gen_key(context, user_id, key_name):
|
||||
"""Generate a key
|
||||
|
||||
This is a module level method because it is slow and we need to defer
|
||||
it into a process pool."""
|
||||
try:
|
||||
mgr = manager.AuthManager()
|
||||
private_key, fingerprint = mgr.generate_key_pair(user_id, key_name)
|
||||
# NOTE(vish): generating key pair is slow so check for legal
|
||||
# creation before creating keypair
|
||||
try:
|
||||
db.keypair_get(context, user_id, key_name)
|
||||
raise exception.Duplicate("The keypair %s already exists"
|
||||
% key_name)
|
||||
except exception.NotFound:
|
||||
pass
|
||||
private_key, public_key, fingerprint = crypto.generate_key_pair()
|
||||
key = {}
|
||||
key['user_id'] = user_id
|
||||
key['name'] = key_name
|
||||
key['public_key'] = public_key
|
||||
key['fingerprint'] = fingerprint
|
||||
db.keypair_create(context, key)
|
||||
return {'private_key': private_key, 'fingerprint': fingerprint}
|
||||
except Exception as ex:
|
||||
return {'exception': ex}
|
||||
return {'private_key': private_key, 'fingerprint': fingerprint}
|
||||
|
||||
|
||||
class CloudController(object):
|
||||
@@ -177,18 +193,18 @@ class CloudController(object):
|
||||
|
||||
@rbac.allow('all')
|
||||
def describe_key_pairs(self, context, key_name=None, **kwargs):
|
||||
key_pairs = context.user.get_key_pairs()
|
||||
key_pairs = db.keypair_get_all_by_user(context, context.user.id)
|
||||
if not key_name is None:
|
||||
key_pairs = [x for x in key_pairs if x.name in key_name]
|
||||
key_pairs = [x for x in key_pairs if x['name'] in key_name]
|
||||
|
||||
result = []
|
||||
for key_pair in key_pairs:
|
||||
# filter out the vpn keys
|
||||
suffix = FLAGS.vpn_key_suffix
|
||||
if context.user.is_admin() or not key_pair.name.endswith(suffix):
|
||||
if context.user.is_admin() or not key_pair['name'].endswith(suffix):
|
||||
result.append({
|
||||
'keyName': key_pair.name,
|
||||
'keyFingerprint': key_pair.fingerprint,
|
||||
'keyName': key_pair['name'],
|
||||
'keyFingerprint': key_pair['fingerprint'],
|
||||
})
|
||||
|
||||
return {'keypairsSet': result}
|
||||
@@ -204,14 +220,18 @@ class CloudController(object):
|
||||
dcall.callback({'keyName': key_name,
|
||||
'keyFingerprint': kwargs['fingerprint'],
|
||||
'keyMaterial': kwargs['private_key']})
|
||||
pool.apply_async(_gen_key, [context.user.id, key_name],
|
||||
# TODO(vish): when context is no longer an object, pass it here
|
||||
pool.apply_async(_gen_key, [None, context.user.id, key_name],
|
||||
callback=_complete)
|
||||
return dcall
|
||||
|
||||
@rbac.allow('all')
|
||||
def delete_key_pair(self, context, key_name, **kwargs):
|
||||
context.user.delete_key_pair(key_name)
|
||||
# aws returns true even if the key doens't exist
|
||||
try:
|
||||
db.keypair_destroy(context, context.user.id, key_name)
|
||||
except exception.NotFound:
|
||||
# aws returns true even if the key doesn't exist
|
||||
pass
|
||||
return True
|
||||
|
||||
@rbac.allow('all')
|
||||
|
||||
@@ -224,7 +224,8 @@ class ApiEc2TestCase(test.BaseTestCase):
|
||||
for x in range(random.randint(4, 8)))
|
||||
user = self.manager.create_user('fake', 'fake', 'fake')
|
||||
project = self.manager.create_project('fake', 'fake', 'fake')
|
||||
self.manager.generate_key_pair(user.id, keyname)
|
||||
# NOTE(vish): create depends on pool, so call helper directly
|
||||
cloud._gen_key(None, user.id, keyname)
|
||||
|
||||
rv = self.ec2.get_all_key_pairs()
|
||||
results = [k for k in rv if k.name == keyname]
|
||||
|
||||
@@ -17,8 +17,6 @@
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from M2Crypto import BIO
|
||||
from M2Crypto import RSA
|
||||
from M2Crypto import X509
|
||||
import unittest
|
||||
|
||||
@@ -65,35 +63,6 @@ class AuthTestCase(test.BaseTestCase):
|
||||
'export S3_URL="http://127.0.0.1:3333/"\n' +
|
||||
'export EC2_USER_ID="test1"\n')
|
||||
|
||||
def test_006_test_key_storage(self):
|
||||
user = self.manager.get_user('test1')
|
||||
user.create_key_pair('public', 'key', 'fingerprint')
|
||||
key = user.get_key_pair('public')
|
||||
self.assertEqual('key', key.public_key)
|
||||
self.assertEqual('fingerprint', key.fingerprint)
|
||||
|
||||
def test_007_test_key_generation(self):
|
||||
user = self.manager.get_user('test1')
|
||||
private_key, fingerprint = user.generate_key_pair('public2')
|
||||
key = RSA.load_key_string(private_key, callback=lambda: None)
|
||||
bio = BIO.MemoryBuffer()
|
||||
public_key = user.get_key_pair('public2').public_key
|
||||
key.save_pub_key_bio(bio)
|
||||
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
|
||||
# assert key fields are equal
|
||||
self.assertEqual(public_key.split(" ")[1].strip(),
|
||||
converted.split(" ")[1].strip())
|
||||
|
||||
def test_008_can_list_key_pairs(self):
|
||||
keys = self.manager.get_user('test1').get_key_pairs()
|
||||
self.assertTrue(filter(lambda k: k.name == 'public', keys))
|
||||
self.assertTrue(filter(lambda k: k.name == 'public2', keys))
|
||||
|
||||
def test_009_can_delete_key_pair(self):
|
||||
self.manager.get_user('test1').delete_key_pair('public')
|
||||
keys = self.manager.get_user('test1').get_key_pairs()
|
||||
self.assertFalse(filter(lambda k: k.name == 'public', keys))
|
||||
|
||||
def test_010_can_list_users(self):
|
||||
users = self.manager.get_users()
|
||||
logging.warn(users)
|
||||
|
||||
@@ -17,13 +17,18 @@
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from M2Crypto import BIO
|
||||
from M2Crypto import RSA
|
||||
import StringIO
|
||||
import time
|
||||
|
||||
from tornado import ioloop
|
||||
from twisted.internet import defer
|
||||
import unittest
|
||||
from xml.etree import ElementTree
|
||||
|
||||
from nova import crypto
|
||||
from nova import db
|
||||
from nova import flags
|
||||
from nova import rpc
|
||||
from nova import test
|
||||
@@ -54,16 +59,21 @@ class CloudTestCase(test.BaseTestCase):
|
||||
proxy=self.compute)
|
||||
self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop))
|
||||
|
||||
try:
|
||||
manager.AuthManager().create_user('admin', 'admin', 'admin')
|
||||
except: pass
|
||||
admin = manager.AuthManager().get_user('admin')
|
||||
project = manager.AuthManager().create_project('proj', 'admin', 'proj')
|
||||
self.context = api.APIRequestContext(handler=None,project=project,user=admin)
|
||||
self.manager = manager.AuthManager()
|
||||
self.user = self.manager.create_user('admin', 'admin', 'admin', True)
|
||||
self.project = self.manager.create_project('proj', 'admin', 'proj')
|
||||
self.context = api.APIRequestContext(handler=None,
|
||||
user=self.user,
|
||||
project=self.project)
|
||||
|
||||
def tearDown(self):
|
||||
manager.AuthManager().delete_project('proj')
|
||||
manager.AuthManager().delete_user('admin')
|
||||
self.manager.delete_project(self.project)
|
||||
self.manager.delete_user(self.user)
|
||||
super(CloudTestCase, self).setUp()
|
||||
|
||||
def _create_key(self, name):
|
||||
# NOTE(vish): create depends on pool, so just call helper directly
|
||||
return cloud._gen_key(self.context, self.context.user.id, name)
|
||||
|
||||
def test_console_output(self):
|
||||
if FLAGS.connection_type == 'fake':
|
||||
@@ -76,6 +86,33 @@ class CloudTestCase(test.BaseTestCase):
|
||||
self.assert_(output)
|
||||
rv = yield self.compute.terminate_instance(instance_id)
|
||||
|
||||
|
||||
def test_key_generation(self):
|
||||
result = self._create_key('test')
|
||||
private_key = result['private_key']
|
||||
key = RSA.load_key_string(private_key, callback=lambda: None)
|
||||
bio = BIO.MemoryBuffer()
|
||||
public_key = db.keypair_get(self.context,
|
||||
self.context.user.id,
|
||||
'test')['public_key']
|
||||
key.save_pub_key_bio(bio)
|
||||
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
|
||||
# assert key fields are equal
|
||||
self.assertEqual(public_key.split(" ")[1].strip(),
|
||||
converted.split(" ")[1].strip())
|
||||
|
||||
def test_describe_key_pairs(self):
|
||||
self._create_key('test1')
|
||||
self._create_key('test2')
|
||||
result = self.cloud.describe_key_pairs(self.context)
|
||||
keys = result["keypairsSet"]
|
||||
self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
|
||||
self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
|
||||
|
||||
def test_delete_key_pair(self):
|
||||
self._create_key('test')
|
||||
self.cloud.delete_key_pair(self.context, 'test')
|
||||
|
||||
def test_run_instances(self):
|
||||
if FLAGS.connection_type == 'fake':
|
||||
logging.debug("Can't test instances without a real virtual env.")
|
||||
|
||||
Reference in New Issue
Block a user