Make compute_api use KeyPair objects

This makes compute_api's KeypairAPI use KeyPair objects to do
its work. Most of the change here is to tests to account for the
change.

Related to blueprint compute-api-objects

Change-Id: If516a88eac107670d9ae37fcfaf56f2add9dd0c5
This commit is contained in:
Dan Smith
2013-08-14 11:30:30 -07:00
parent e91c6ffbea
commit 28e937cd2b
11 changed files with 113 additions and 79 deletions

View File

@@ -444,15 +444,14 @@ class CloudController(object):
LOG.audit(_("Create key pair %s"), key_name, context=context) LOG.audit(_("Create key pair %s"), key_name, context=context)
try: try:
keypair = self.keypair_api.create_key_pair(context, keypair, private_key = self.keypair_api.create_key_pair(
context.user_id, context, context.user_id, key_name)
key_name)
except exception.KeypairLimitExceeded: except exception.KeypairLimitExceeded:
msg = _("Quota exceeded, too many key pairs.") msg = _("Quota exceeded, too many key pairs.")
raise exception.EC2APIError(msg, code='ResourceLimitExceeded') raise exception.EC2APIError(msg, code='ResourceLimitExceeded')
return {'keyName': key_name, return {'keyName': key_name,
'keyFingerprint': keypair['fingerprint'], 'keyFingerprint': keypair['fingerprint'],
'keyMaterial': keypair['private_key']} 'keyMaterial': private_key}
# TODO(vish): when context is no longer an object, pass it here # TODO(vish): when context is no longer an object, pass it here
def import_key_pair(self, context, key_name, public_key_material, def import_key_pair(self, context, key_name, public_key_material,

View File

@@ -54,6 +54,16 @@ class KeypairController(object):
def __init__(self): def __init__(self):
self.api = compute_api.KeypairAPI() self.api = compute_api.KeypairAPI()
def _filter_keypair(self, keypair, **attrs):
clean = {
'name': keypair.name,
'public_key': keypair.public_key,
'fingerprint': keypair.fingerprint,
}
for attr in attrs:
clean[attr] = keypair[attr]
return clean
@wsgi.serializers(xml=KeypairTemplate) @wsgi.serializers(xml=KeypairTemplate)
def create(self, req, body): def create(self, req, body):
""" """
@@ -84,9 +94,12 @@ class KeypairController(object):
keypair = self.api.import_key_pair(context, keypair = self.api.import_key_pair(context,
context.user_id, name, context.user_id, name,
params['public_key']) params['public_key'])
keypair = self._filter_keypair(keypair, user_id=True)
else: else:
keypair = self.api.create_key_pair(context, context.user_id, keypair, private_key = self.api.create_key_pair(
name) context, context.user_id, name)
keypair = self._filter_keypair(keypair, user_id=True)
keypair['private_key'] = private_key
return {'keypair': keypair} return {'keypair': keypair}
@@ -134,11 +147,7 @@ class KeypairController(object):
key_pairs = self.api.get_key_pairs(context, context.user_id) key_pairs = self.api.get_key_pairs(context, context.user_id)
rval = [] rval = []
for key_pair in key_pairs: for key_pair in key_pairs:
rval.append({'keypair': { rval.append({'keypair': self._filter_keypair(key_pair)})
'name': key_pair['name'],
'public_key': key_pair['public_key'],
'fingerprint': key_pair['fingerprint'],
}})
return {'keypairs': rval} return {'keypairs': rval}

View File

@@ -55,6 +55,16 @@ class KeypairController(object):
def __init__(self): def __init__(self):
self.api = compute_api.KeypairAPI() self.api = compute_api.KeypairAPI()
def _filter_keypair(self, keypair, **attrs):
clean = {
'name': keypair.name,
'public_key': keypair.public_key,
'fingerprint': keypair.fingerprint,
}
for attr in attrs:
clean[attr] = keypair[attr]
return clean
@wsgi.serializers(xml=KeypairTemplate) @wsgi.serializers(xml=KeypairTemplate)
@extensions.expected_errors((400, 409, 413)) @extensions.expected_errors((400, 409, 413))
def create(self, req, body): def create(self, req, body):
@@ -86,9 +96,12 @@ class KeypairController(object):
keypair = self.api.import_key_pair(context, keypair = self.api.import_key_pair(context,
context.user_id, name, context.user_id, name,
params['public_key']) params['public_key'])
keypair = self._filter_keypair(keypair, user_id=True)
else: else:
keypair = self.api.create_key_pair(context, context.user_id, keypair, private_key = self.api.create_key_pair(
name) context, context.user_id, name)
keypair = self._filter_keypair(keypair, user_id=True)
keypair['private_key'] = private_key
return {'keypair': keypair} return {'keypair': keypair}
@@ -126,7 +139,7 @@ class KeypairController(object):
keypair = self.api.get_key_pair(context, context.user_id, id) keypair = self.api.get_key_pair(context, context.user_id, id)
except exception.KeypairNotFound: except exception.KeypairNotFound:
raise webob.exc.HTTPNotFound() raise webob.exc.HTTPNotFound()
return {'keypair': keypair} return {'keypair': self._filter_keypair(keypair)}
@extensions.expected_errors(()) @extensions.expected_errors(())
@wsgi.serializers(xml=KeypairsTemplate) @wsgi.serializers(xml=KeypairsTemplate)
@@ -139,11 +152,7 @@ class KeypairController(object):
key_pairs = self.api.get_key_pairs(context, context.user_id) key_pairs = self.api.get_key_pairs(context, context.user_id)
rval = [] rval = []
for key_pair in key_pairs: for key_pair in key_pairs:
rval.append({'keypair': { rval.append({'keypair': self._filter_keypair(key_pair)})
'name': key_pair['name'],
'public_key': key_pair['public_key'],
'fingerprint': key_pair['fingerprint'],
}})
return {'keypairs': rval} return {'keypairs': rval}

View File

@@ -171,10 +171,9 @@ class CloudPipe(object):
key_name = '%s%s' % (context.project_id, CONF.vpn_key_suffix) key_name = '%s%s' % (context.project_id, CONF.vpn_key_suffix)
try: try:
keypair_api = compute.api.KeypairAPI() keypair_api = compute.api.KeypairAPI()
result = keypair_api.create_key_pair(context, result, private_key = keypair_api.create_key_pair(context,
context.user_id, context.user_id,
key_name) key_name)
private_key = result['private_key']
key_dir = os.path.join(CONF.keys_path, context.user_id) key_dir = os.path.join(CONF.keys_path, context.user_id)
fileutils.ensure_tree(key_dir) fileutils.ensure_tree(key_dir)
key_path = os.path.join(key_dir, '%s.pem' % key_name) key_path = os.path.join(key_dir, '%s.pem' % key_name)

View File

@@ -55,6 +55,7 @@ from nova.objects import base as obj_base
from nova.objects import instance as instance_obj from nova.objects import instance as instance_obj
from nova.objects import instance_action from nova.objects import instance_action
from nova.objects import instance_info_cache from nova.objects import instance_info_cache
from nova.objects import keypair as keypair_obj
from nova.objects import security_group from nova.objects import security_group
from nova.openstack.common import excutils from nova.openstack.common import excutils
from nova.openstack.common.gettextutils import _ from nova.openstack.common.gettextutils import _
@@ -669,9 +670,10 @@ class API(base.Base):
config_drive = self._check_config_drive(config_drive) config_drive = self._check_config_drive(config_drive)
if key_data is None and key_name: if key_data is None and key_name:
key_pair = self.db.key_pair_get(context, context.user_id, key_pair = keypair_obj.KeyPair.get_by_name(context,
key_name) context.user_id,
key_data = key_pair['public_key'] key_name)
key_data = key_pair.public_key
root_device_name = block_device.properties_root_device_name( root_device_name = block_device.properties_root_device_name(
boot_meta.get('properties', {})) boot_meta.get('properties', {}))
@@ -3188,12 +3190,13 @@ class KeypairAPI(base.Base):
fingerprint = crypto.generate_fingerprint(public_key) fingerprint = crypto.generate_fingerprint(public_key)
keypair = {'user_id': user_id, keypair = keypair_obj.KeyPair()
'name': key_name, keypair.user_id = user_id
'fingerprint': fingerprint, keypair.name = key_name
'public_key': public_key} keypair.fingerprint = fingerprint
keypair.public_key = public_key
keypair.create(context)
self.db.key_pair_create(context, keypair)
return keypair return keypair
def create_key_pair(self, context, user_id, key_name): def create_key_pair(self, context, user_id, key_name):
@@ -3202,33 +3205,26 @@ class KeypairAPI(base.Base):
private_key, public_key, fingerprint = crypto.generate_key_pair() private_key, public_key, fingerprint = crypto.generate_key_pair()
keypair = {'user_id': user_id, keypair = keypair_obj.KeyPair()
'name': key_name, keypair.user_id = user_id
'fingerprint': fingerprint, keypair.name = key_name
'public_key': public_key, keypair.fingerprint = fingerprint
'private_key': private_key} keypair.public_key = public_key
keypair.create(context)
self.db.key_pair_create(context, keypair) return keypair, private_key
return keypair
def delete_key_pair(self, context, user_id, key_name): def delete_key_pair(self, context, user_id, key_name):
"""Delete a keypair by name.""" """Delete a keypair by name."""
self.db.key_pair_destroy(context, user_id, key_name) keypair_obj.KeyPair.destroy_by_name(context, user_id, key_name)
def _get_key_pair(self, key_pair):
return {'name': key_pair['name'],
'public_key': key_pair['public_key'],
'fingerprint': key_pair['fingerprint']}
def get_key_pairs(self, context, user_id): def get_key_pairs(self, context, user_id):
"""List key pairs.""" """List key pairs."""
key_pairs = self.db.key_pair_get_all_by_user(context, user_id) return keypair_obj.KeyPairList.get_by_user(context, user_id)
return [self._get_key_pair(k) for k in key_pairs]
def get_key_pair(self, context, user_id, key_name): def get_key_pair(self, context, user_id, key_name):
"""Get a keypair by name.""" """Get a keypair by name."""
key_pair = self.db.key_pair_get(context, user_id, key_name) return keypair_obj.KeyPair.get_by_name(context, user_id, key_name)
return self._get_key_pair(key_pair)
class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase): class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase):

View File

@@ -1678,8 +1678,7 @@ class CloudTestCase(test.TestCase):
rv = self.cloud.terminate_instances(self.context, [instance_id]) rv = self.cloud.terminate_instances(self.context, [instance_id])
def test_key_generation(self): def test_key_generation(self):
result = self._create_key('test') result, private_key = self._create_key('test')
private_key = result['private_key']
expected = db.key_pair_get(self.context, expected = db.key_pair_get(self.context,
self.context.user_id, self.context.user_id,

View File

@@ -25,15 +25,21 @@ from nova.openstack.common import policy
from nova import quota from nova import quota
from nova import test from nova import test
from nova.tests.api.openstack import fakes from nova.tests.api.openstack import fakes
from nova.tests.objects import test_keypair
QUOTAS = quota.QUOTAS QUOTAS = quota.QUOTAS
keypair_data = {
'public_key': 'FAKE_KEY',
'fingerprint': 'FAKE_FINGERPRINT',
}
def fake_keypair(name): def fake_keypair(name):
return {'public_key': 'FAKE_KEY', return dict(test_keypair.fake_keypair,
'fingerprint': 'FAKE_FINGERPRINT', name=name, **keypair_data)
'name': name}
def db_key_pair_get_all_by_user(self, user_id): def db_key_pair_get_all_by_user(self, user_id):
@@ -41,7 +47,7 @@ def db_key_pair_get_all_by_user(self, user_id):
def db_key_pair_create(self, keypair): def db_key_pair_create(self, keypair):
return keypair return fake_keypair(name=keypair['name'])
def db_key_pair_destroy(context, user_id, name): def db_key_pair_destroy(context, user_id, name):
@@ -78,7 +84,7 @@ class KeypairsTest(test.TestCase):
res = req.get_response(self.app) res = req.get_response(self.app)
self.assertEqual(res.status_int, 200) self.assertEqual(res.status_int, 200)
res_dict = jsonutils.loads(res.body) res_dict = jsonutils.loads(res.body)
response = {'keypairs': [{'keypair': fake_keypair('FAKE')}]} response = {'keypairs': [{'keypair': dict(keypair_data, name='FAKE')}]}
self.assertEqual(res_dict, response) self.assertEqual(res_dict, response)
def test_keypair_create(self): def test_keypair_create(self):
@@ -284,7 +290,8 @@ class KeypairsTest(test.TestCase):
def test_keypair_show(self): def test_keypair_show(self):
def _db_key_pair_get(context, user_id, name): def _db_key_pair_get(context, user_id, name):
return {'name': 'foo', 'public_key': 'XXX', 'fingerprint': 'YYY'} return dict(test_keypair.fake_keypair,
name='foo', public_key='XXX', fingerprint='YYY')
self.stubs.Set(db, "key_pair_get", _db_key_pair_get) self.stubs.Set(db, "key_pair_get", _db_key_pair_get)
@@ -356,7 +363,8 @@ class KeypairPolicyTest(test.TestCase):
self.KeyPairController = keypairs.KeypairController() self.KeyPairController = keypairs.KeypairController()
def _db_key_pair_get(context, user_id, name): def _db_key_pair_get(context, user_id, name):
return {'name': 'foo', 'public_key': 'XXX', 'fingerprint': 'YYY'} return dict(test_keypair.fake_keypair,
name='foo', public_key='XXX', fingerprint='YYY')
self.stubs.Set(db, "key_pair_get", self.stubs.Set(db, "key_pair_get",
_db_key_pair_get) _db_key_pair_get)

View File

@@ -27,15 +27,21 @@ from nova.openstack.common import policy
from nova import quota from nova import quota
from nova import test from nova import test
from nova.tests.api.openstack import fakes from nova.tests.api.openstack import fakes
from nova.tests.objects import test_keypair
QUOTAS = quota.QUOTAS QUOTAS = quota.QUOTAS
keypair_data = {
'public_key': 'FAKE_KEY',
'fingerprint': 'FAKE_FINGERPRINT',
}
def fake_keypair(name): def fake_keypair(name):
return {'public_key': 'FAKE_KEY', return dict(test_keypair.fake_keypair,
'fingerprint': 'FAKE_FINGERPRINT', name=name, **keypair_data)
'name': name}
def db_key_pair_get_all_by_user(self, user_id): def db_key_pair_get_all_by_user(self, user_id):
@@ -43,7 +49,7 @@ def db_key_pair_get_all_by_user(self, user_id):
def db_key_pair_create(self, keypair): def db_key_pair_create(self, keypair):
return keypair return fake_keypair(name=keypair['name'])
def db_key_pair_destroy(context, user_id, name): def db_key_pair_destroy(context, user_id, name):
@@ -80,7 +86,7 @@ class KeypairsTest(test.TestCase):
res = req.get_response(self.app) res = req.get_response(self.app)
self.assertEqual(res.status_int, 200) self.assertEqual(res.status_int, 200)
res_dict = jsonutils.loads(res.body) res_dict = jsonutils.loads(res.body)
response = {'keypairs': [{'keypair': fake_keypair('FAKE')}]} response = {'keypairs': [{'keypair': dict(keypair_data, name='FAKE')}]}
self.assertEqual(res_dict, response) self.assertEqual(res_dict, response)
def test_keypair_create(self): def test_keypair_create(self):
@@ -285,7 +291,8 @@ class KeypairsTest(test.TestCase):
def test_keypair_show(self): def test_keypair_show(self):
def _db_key_pair_get(context, user_id, name): def _db_key_pair_get(context, user_id, name):
return {'name': 'foo', 'public_key': 'XXX', 'fingerprint': 'YYY'} return dict(test_keypair.fake_keypair,
name='foo', public_key='XXX', fingerprint='YYY')
self.stubs.Set(db, "key_pair_get", _db_key_pair_get) self.stubs.Set(db, "key_pair_get", _db_key_pair_get)
@@ -358,7 +365,8 @@ class KeypairPolicyTest(test.TestCase):
self.KeyPairController = keypairs.KeypairController() self.KeyPairController = keypairs.KeypairController()
def _db_key_pair_get(context, user_id, name): def _db_key_pair_get(context, user_id, name):
return {'name': 'foo', 'public_key': 'XXX', 'fingerprint': 'YYY'} return dict(test_keypair.fake_keypair,
name='foo', public_key='XXX', fingerprint='YYY')
self.stubs.Set(db, "key_pair_get", self.stubs.Set(db, "key_pair_get",
_db_key_pair_get) _db_key_pair_get)

View File

@@ -58,6 +58,7 @@ from nova.tests import fake_instance
from nova.tests import fake_network from nova.tests import fake_network
from nova.tests.image import fake from nova.tests.image import fake
from nova.tests import matchers from nova.tests import matchers
from nova.tests.objects import test_keypair
from nova.tests import utils from nova.tests import utils
from nova import utils as nova_utils from nova import utils as nova_utils
@@ -2838,9 +2839,10 @@ class ServersControllerCreateTest(test.TestCase):
# NOTE(sdague): key pair goes back to the database, # NOTE(sdague): key pair goes back to the database,
# so we need to stub it out for tests # so we need to stub it out for tests
def key_pair_get(context, user_id, name): def key_pair_get(context, user_id, name):
return {'public_key': 'FAKE_KEY', return dict(test_keypair.fake_keypair,
'fingerprint': 'FAKE_FINGERPRINT', public_key='FAKE_KEY',
'name': name} fingerprint='FAKE_FINGERPRINT',
name=name)
def create(*args, **kwargs): def create(*args, **kwargs):
self.assertEqual(kwargs['key_name'], key_name) self.assertEqual(kwargs['key_name'], key_name)

View File

@@ -45,6 +45,7 @@ from nova.openstack.common import timeutils
from nova import quota from nova import quota
from nova.tests import fake_network from nova.tests import fake_network
from nova.tests.glance import stubs as glance_stubs from nova.tests.glance import stubs as glance_stubs
from nova.tests.objects import test_keypair
from nova import utils from nova import utils
from nova import wsgi from nova import wsgi
@@ -126,11 +127,13 @@ def wsgi_app_v3(inner_app_v3=None, fake_auth_context=None,
def stub_out_key_pair_funcs(stubs, have_key_pair=True): def stub_out_key_pair_funcs(stubs, have_key_pair=True):
def key_pair(context, user_id): def key_pair(context, user_id):
return [dict(name='key', public_key='public_key')] return [dict(test_keypair.fake_keypair,
name='key', public_key='public_key')]
def one_key_pair(context, user_id, name): def one_key_pair(context, user_id, name):
if name == 'key': if name == 'key':
return dict(name='key', public_key='public_key') return dict(test_keypair.fake_keypair,
name='key', public_key='public_key')
else: else:
raise exc.KeypairNotFound(user_id=user_id, name=name) raise exc.KeypairNotFound(user_id=user_id, name=name)

View File

@@ -25,7 +25,7 @@ from nova import exception
from nova.openstack.common.gettextutils import _ from nova.openstack.common.gettextutils import _
from nova import quota from nova import quota
from nova.tests.compute import test_compute from nova.tests.compute import test_compute
from nova.tests.objects import test_keypair
CONF = cfg.CONF CONF = cfg.CONF
QUOTAS = quota.QUOTAS QUOTAS = quota.QUOTAS
@@ -51,21 +51,23 @@ class KeypairAPITestCase(test_compute.BaseTestCase):
def _keypair_db_call_stubs(self): def _keypair_db_call_stubs(self):
def db_key_pair_get_all_by_user(context, user_id): def db_key_pair_get_all_by_user(context, user_id):
return [{'name': self.existing_key_name, return [dict(test_keypair.fake_keypair,
'public_key': self.pub_key, name=self.existing_key_name,
'fingerprint': self.fingerprint}] public_key=self.pub_key,
fingerprint=self.fingerprint)]
def db_key_pair_create(context, keypair): def db_key_pair_create(context, keypair):
pass return dict(test_keypair.fake_keypair, **keypair)
def db_key_pair_destroy(context, user_id, name): def db_key_pair_destroy(context, user_id, name):
pass pass
def db_key_pair_get(context, user_id, name): def db_key_pair_get(context, user_id, name):
if name == self.existing_key_name: if name == self.existing_key_name:
return {'name': self.existing_key_name, return dict(test_keypair.fake_keypair,
'public_key': self.pub_key, name=self.existing_key_name,
'fingerprint': self.fingerprint} public_key=self.pub_key,
fingerprint=self.fingerprint)
else: else:
raise exception.KeypairNotFound(user_id=user_id, name=name) raise exception.KeypairNotFound(user_id=user_id, name=name)
@@ -135,8 +137,8 @@ class CreateKeypairTestCase(KeypairAPITestCase, CreateImportSharedTestMixIn):
func_name = 'create_key_pair' func_name = 'create_key_pair'
def test_success(self): def test_success(self):
keypair = self.keypair_api.create_key_pair(self.ctxt, keypair, private_key = self.keypair_api.create_key_pair(
self.ctxt.user_id, 'foo') self.ctxt, self.ctxt.user_id, 'foo')
self.assertEqual('foo', keypair['name']) self.assertEqual('foo', keypair['name'])