Create an internal key pair API.

Creates an internal key pair API and update the EC2 and OS API's to
use it. This de-duplicates some of the code used to manage keypairs
across the APIs.

Fixes LP Bug #998059.

Change-Id: I10d58d7ce68cc2b993c72b6639f66c72def3bfbc
This commit is contained in:
Dan Prince 2012-05-11 10:07:06 -04:00
parent 2c7e0d1e63
commit ec0a65d81f
7 changed files with 257 additions and 130 deletions

View File

@ -34,7 +34,6 @@ from nova import block_device
from nova import compute from nova import compute
from nova.compute import instance_types from nova.compute import instance_types
from nova.compute import vm_states from nova.compute import vm_states
from nova import crypto
from nova import db from nova import db
from nova import exception from nova import exception
from nova import flags from nova import flags
@ -61,33 +60,6 @@ def validate_ec2_id(val):
raise exception.InvalidInstanceIDMalformed(val) raise exception.InvalidInstanceIDMalformed(val)
def _gen_key(context, user_id, key_name):
"""Generate a key
This is a module level method because it is slow and we need to defer
it into a process pool."""
# NOTE(vish): generating key pair is slow so check for legal
# creation before creating key_pair
try:
db.key_pair_get(context, user_id, key_name)
raise exception.KeyPairExists(key_name=key_name)
except exception.NotFound:
pass
if quota.allowed_key_pairs(context, 1) < 1:
msg = _("Quota exceeded, too many key pairs.")
raise exception.EC2APIError(msg)
private_key, public_key, fingerprint = crypto.generate_key_pair()
key = {}
key['user_id'] = user_id
key['name'] = key_name
key['public_key'] = public_key
key['fingerprint'] = fingerprint
db.key_pair_create(context, key)
return {'private_key': private_key, 'fingerprint': fingerprint}
# EC2 API can return the following values as documented in the EC2 API # EC2 API can return the following values as documented in the EC2 API
# http://docs.amazonwebservices.com/AWSEC2/latest/APIReference/ # http://docs.amazonwebservices.com/AWSEC2/latest/APIReference/
# ApiReference-ItemType-InstanceStateType.html # ApiReference-ItemType-InstanceStateType.html
@ -217,6 +189,7 @@ class CloudController(object):
self.volume_api = volume.API() self.volume_api = volume.API()
self.compute_api = compute.API(network_api=self.network_api, self.compute_api = compute.API(network_api=self.network_api,
volume_api=self.volume_api) volume_api=self.volume_api)
self.keypair_api = compute.api.KeypairAPI()
self.sgh = importutils.import_object(FLAGS.security_group_handler) self.sgh = importutils.import_object(FLAGS.security_group_handler)
def __str__(self): def __str__(self):
@ -357,7 +330,7 @@ class CloudController(object):
return True return True
def describe_key_pairs(self, context, key_name=None, **kwargs): def describe_key_pairs(self, context, key_name=None, **kwargs):
key_pairs = db.key_pair_get_all_by_user(context, context.user_id) key_pairs = self.keypair_api.get_key_pairs(context, context.user_id)
if not key_name is None: if not key_name is None:
key_pairs = [x for x in key_pairs if x['name'] in key_name] key_pairs = [x for x in key_pairs if x['name'] in key_name]
@ -374,52 +347,55 @@ class CloudController(object):
return {'keySet': result} return {'keySet': result}
def create_key_pair(self, context, key_name, **kwargs): def create_key_pair(self, context, key_name, **kwargs):
if not re.match('^[a-zA-Z0-9_\- ]+$', str(key_name)):
err = _("Value (%s) for KeyName is invalid."
" Content limited to Alphanumeric character, "
"spaces, dashes, and underscore.") % key_name
raise exception.EC2APIError(err)
if len(str(key_name)) > 255:
err = _("Value (%s) for Keyname is invalid."
" Length exceeds maximum of 255.") % key_name
raise exception.EC2APIError(err)
LOG.audit(_("Create key pair %s"), key_name, context=context) LOG.audit(_("Create key pair %s"), key_name, context=context)
data = _gen_key(context, context.user_id, key_name)
try:
keypair = self.keypair_api.create_key_pair(context,
context.user_id,
key_name)
except exception.KeypairLimitExceeded:
msg = _("Quota exceeded, too many key pairs.")
raise exception.EC2APIError(msg)
except exception.InvalidKeypair:
msg = _("Keypair data is invalid")
raise exception.EC2APIError(msg)
except exception.KeyPairExists:
msg = _("Key pair '%s' already exists.") % key_name
raise exception.KeyPairExists(msg)
return {'keyName': key_name, return {'keyName': key_name,
'keyFingerprint': data['fingerprint'], 'keyFingerprint': keypair['fingerprint'],
'keyMaterial': data['private_key']} 'keyMaterial': keypair['private_key']}
# TODO(vish): when context is no longer an object, pass it here # TODO(vish): when context is no longer an object, pass it here
def import_key_pair(self, context, key_name, public_key_material, def import_key_pair(self, context, key_name, public_key_material,
**kwargs): **kwargs):
LOG.audit(_("Import key %s"), key_name, context=context) LOG.audit(_("Import key %s"), key_name, context=context)
try:
db.key_pair_get(context, context.user_id, key_name)
raise exception.KeyPairExists(key_name=key_name)
except exception.NotFound:
pass
if quota.allowed_key_pairs(context, 1) < 1:
msg = _("Quota exceeded, too many key pairs.")
raise exception.EC2APIError(msg)
public_key = base64.b64decode(public_key_material) public_key = base64.b64decode(public_key_material)
fingerprint = crypto.generate_fingerprint(public_key)
key = {} try:
key['user_id'] = context.user_id keypair = self.keypair_api.import_key_pair(context,
key['name'] = key_name context.user_id,
key['public_key'] = public_key key_name,
key['fingerprint'] = fingerprint public_key)
db.key_pair_create(context, key) except exception.KeypairLimitExceeded:
msg = _("Quota exceeded, too many key pairs.")
raise exception.EC2APIError(msg)
except exception.InvalidKeypair:
msg = _("Keypair data is invalid")
raise exception.EC2APIError(msg)
except exception.KeyPairExists:
msg = _("Key pair '%s' already exists.") % key_name
raise exception.EC2APIError(msg)
return {'keyName': key_name, return {'keyName': key_name,
'keyFingerprint': fingerprint} 'keyFingerprint': keypair['fingerprint']}
def delete_key_pair(self, context, key_name, **kwargs): def delete_key_pair(self, context, key_name, **kwargs):
LOG.audit(_("Delete key pair %s"), key_name, context=context) LOG.audit(_("Delete key pair %s"), key_name, context=context)
try: try:
db.key_pair_destroy(context, context.user_id, key_name) self.keypair_api.delete_key_pair(context, context.user_id,
key_name)
except exception.NotFound: except exception.NotFound:
# aws returns true even if the key doesn't exist # aws returns true even if the key doesn't exist
pass pass

View File

@ -17,18 +17,14 @@
""" Keypair management extension""" """ Keypair management extension"""
import string
import webob import webob
import webob.exc import webob.exc
from nova.api.openstack import wsgi from nova.api.openstack import wsgi
from nova.api.openstack import xmlutil from nova.api.openstack import xmlutil
from nova.api.openstack import extensions from nova.api.openstack import extensions
from nova import crypto from nova.compute import api as compute_api
from nova import db
from nova import exception from nova import exception
from nova import quota
authorize = extensions.extension_authorizer('compute', 'keypairs') authorize = extensions.extension_authorizer('compute', 'keypairs')
@ -50,26 +46,10 @@ class KeypairsTemplate(xmlutil.TemplateBuilder):
class KeypairController(object): class KeypairController(object):
""" Keypair API controller for the OpenStack API """ """ Keypair API controller for the OpenStack API """
def __init__(self):
# TODO(ja): both this file and nova.api.ec2.cloud.py have similar logic. self.api = compute_api.KeypairAPI()
# move the common keypair logic to nova.compute.API?
def _gen_key(self):
"""
Generate a key
"""
private_key, public_key, fingerprint = crypto.generate_key_pair()
return {'private_key': private_key,
'public_key': public_key,
'fingerprint': fingerprint}
def _validate_keypair_name(self, value):
safechars = "_-" + string.digits + string.ascii_letters
clean_value = "".join(x for x in value if x in safechars)
if clean_value != value:
msg = _("Keypair name contains unsafe characters")
raise webob.exc.HTTPBadRequest(explanation=msg)
@wsgi.serializers(xml=KeypairTemplate) @wsgi.serializers(xml=KeypairTemplate)
def create(self, req, body): def create(self, req, body):
@ -90,45 +70,29 @@ class KeypairController(object):
authorize(context) authorize(context)
params = body['keypair'] params = body['keypair']
name = params['name'] name = params['name']
self._validate_keypair_name(name)
if not 0 < len(name) < 256:
msg = _('Keypair name must be between 1 and 255 characters long')
raise webob.exc.HTTPBadRequest(explanation=msg)
# NOTE(ja): generation is slow, so shortcut invalid name exception
try: try:
db.key_pair_get(context, context.user_id, name) if 'public_key' in params:
msg = _("Key pair '%s' already exists.") % name keypair = self.api.import_key_pair(context,
raise webob.exc.HTTPConflict(explanation=msg) context.user_id, name,
except exception.NotFound: params['public_key'])
pass else:
keypair = self.api.create_key_pair(context, context.user_id,
name)
keypair = {'user_id': context.user_id, return {'keypair': keypair}
'name': name}
if quota.allowed_key_pairs(context, 1) < 1: except exception.KeypairLimitExceeded:
msg = _("Quota exceeded, too many key pairs.") msg = _("Quota exceeded, too many key pairs.")
raise webob.exc.HTTPRequestEntityTooLarge( raise webob.exc.HTTPRequestEntityTooLarge(
explanation=msg, explanation=msg,
headers={'Retry-After': 0}) headers={'Retry-After': 0})
# import if public_key is sent except exception.InvalidKeypair:
if 'public_key' in params: msg = _("Keypair data is invalid")
try: raise webob.exc.HTTPBadRequest(explanation=msg)
fingerprint = crypto.generate_fingerprint(params['public_key']) except exception.KeyPairExists:
except exception.InvalidKeypair: msg = _("Key pair '%s' already exists.") % name
msg = _("Keypair data is invalid") raise webob.exc.HTTPConflict(explanation=msg)
raise webob.exc.HTTPBadRequest(explanation=msg)
keypair['public_key'] = params['public_key']
keypair['fingerprint'] = fingerprint
else:
generated_key = self._gen_key()
keypair['private_key'] = generated_key['private_key']
keypair['public_key'] = generated_key['public_key']
keypair['fingerprint'] = generated_key['fingerprint']
db.key_pair_create(context, keypair)
return {'keypair': keypair}
def delete(self, req, id): def delete(self, req, id):
""" """
@ -137,7 +101,7 @@ class KeypairController(object):
context = req.environ['nova.context'] context = req.environ['nova.context']
authorize(context) authorize(context)
try: try:
db.key_pair_destroy(context, context.user_id, id) self.api.delete_key_pair(context, context.user_id, id)
except exception.KeypairNotFound: except exception.KeypairNotFound:
raise webob.exc.HTTPNotFound() raise webob.exc.HTTPNotFound()
return webob.Response(status_int=202) return webob.Response(status_int=202)
@ -149,7 +113,7 @@ class KeypairController(object):
""" """
context = req.environ['nova.context'] context = req.environ['nova.context']
authorize(context) authorize(context)
key_pairs = db.key_pair_get_all_by_user(context, context.user_id) key_pairs = self.api.get_key_pairs(context, context.user_id)
rval = [] rval = []
for key_pair in key_pairs: for key_pair in key_pairs:
rval.append({'keypair': { rval.append({'keypair': {

View File

@ -3,6 +3,7 @@
# Copyright 2010 United States Government as represented by the # Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration. # Administrator of the National Aeronautics and Space Administration.
# Copyright 2011 Piston Cloud Computing, Inc. # Copyright 2011 Piston Cloud Computing, Inc.
# Copyright 2012 Red Hat, Inc.
# All Rights Reserved. # All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
@ -23,6 +24,7 @@ networking and storage of vms, and compute hosts on which they run)."""
import functools import functools
import re import re
import time import time
import string
from nova import block_device from nova import block_device
from nova.compute import aggregate_states from nova.compute import aggregate_states
@ -30,6 +32,7 @@ from nova.compute import instance_types
from nova.compute import power_state from nova.compute import power_state
from nova.compute import task_states from nova.compute import task_states
from nova.compute import vm_states from nova.compute import vm_states
from nova import crypto
from nova.db import base from nova.db import base
from nova import exception from nova import exception
from nova import flags from nova import flags
@ -1908,3 +1911,83 @@ class AggregateAPI(base.Base):
result["metadata"] = metadata result["metadata"] = metadata
result["hosts"] = hosts result["hosts"] = hosts
return result return result
class KeypairAPI(base.Base):
"""Sub-set of the Compute Manager API for managing key pairs."""
def __init__(self, **kwargs):
super(KeypairAPI, self).__init__(**kwargs)
def _validate_keypair_name(self, context, user_id, key_name):
safechars = "_- " + string.digits + string.ascii_letters
clean_value = "".join(x for x in key_name if x in safechars)
if clean_value != key_name:
msg = _("Keypair name contains unsafe characters")
raise exception.InvalidKeypair(explanation=msg)
if not 0 < len(key_name) < 256:
msg = _('Keypair name must be between 1 and 255 characters long')
raise exception.InvalidKeypair(explanation=msg)
# NOTE: check for existing keypairs of same name
try:
self.db.key_pair_get(context, user_id, key_name)
msg = _("Key pair '%s' already exists.") % key_name
raise exception.KeyPairExists(explanation=msg)
except exception.NotFound:
pass
def import_key_pair(self, context, user_id, key_name, public_key):
"""Import a key pair using an existing public key."""
self._validate_keypair_name(context, user_id, key_name)
if quota.allowed_key_pairs(context, 1) < 1:
raise exception.KeypairLimitExceeded()
try:
fingerprint = crypto.generate_fingerprint(public_key)
except exception.InvalidKeypair:
msg = _("Keypair data is invalid")
raise exception.InvalidKeypair(explanation=msg)
keypair = {'user_id': user_id,
'name': key_name,
'fingerprint': fingerprint,
'public_key': public_key}
self.db.key_pair_create(context, keypair)
return keypair
def create_key_pair(self, context, user_id, key_name):
"""Create a new key pair."""
self._validate_keypair_name(context, user_id, key_name)
if quota.allowed_key_pairs(context, 1) < 1:
raise exception.KeypairLimitExceeded()
private_key, public_key, fingerprint = crypto.generate_key_pair()
keypair = {'user_id': user_id,
'name': key_name,
'fingerprint': fingerprint,
'public_key': public_key,
'private_key': private_key}
self.db.key_pair_create(context, keypair)
return keypair
def delete_key_pair(self, context, user_id, key_name):
"""Delete a keypair by name."""
self.db.key_pair_destroy(context, user_id, key_name)
def get_key_pairs(self, context, user_id):
"""List key pairs."""
key_pairs = self.db.key_pair_get_all_by_user(context, user_id)
rval = []
for key_pair in key_pairs:
rval.append({
'name': key_pair['name'],
'public_key': key_pair['public_key'],
'fingerprint': key_pair['fingerprint'],
})
return rval

View File

@ -993,6 +993,10 @@ class OnsetFileContentLimitExceeded(QuotaError):
message = _("Personality file content too long") message = _("Personality file content too long")
class KeypairLimitExceeded(QuotaError):
message = _("Maximum number of key pairs exceeded")
class AggregateError(NovaException): class AggregateError(NovaException):
message = _("Aggregate %(aggregate_id)s: action '%(action)s' " message = _("Aggregate %(aggregate_id)s: action '%(action)s' "
"caused an error: %(reason)s.") "caused an error: %(reason)s.")

View File

@ -28,6 +28,7 @@ import tempfile
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils from nova.api.ec2 import ec2utils
from nova.api.ec2 import inst_state from nova.api.ec2 import inst_state
from nova.compute import api as compute_api
from nova.compute import power_state from nova.compute import power_state
from nova.compute import utils as compute_utils from nova.compute import utils as compute_utils
from nova.compute import vm_states from nova.compute import vm_states
@ -144,7 +145,9 @@ class CloudTestCase(test.TestCase):
def _create_key(self, name): def _create_key(self, name):
# NOTE(vish): create depends on pool, so just call helper directly # NOTE(vish): create depends on pool, so just call helper directly
return cloud._gen_key(self.context, self.context.user_id, name) keypair_api = compute_api.KeypairAPI()
return keypair_api.create_key_pair(self.context, self.context.user_id,
name)
def test_describe_regions(self): def test_describe_regions(self):
"""Makes sure describe regions runs without raising an exception""" """Makes sure describe regions runs without raising an exception"""

View File

@ -35,8 +35,8 @@ from nova import test
from nova.api import auth from nova.api import auth
from nova.api import ec2 from nova.api import ec2
from nova.api.ec2 import apirequest from nova.api.ec2 import apirequest
from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils from nova.api.ec2 import ec2utils
from nova.compute import api as compute_api
class FakeHttplibSocket(object): class FakeHttplibSocket(object):
@ -290,13 +290,11 @@ class ApiEc2TestCase(test.TestCase):
def test_get_all_key_pairs(self): def test_get_all_key_pairs(self):
"""Test that, after creating a user and project and generating """Test that, after creating a user and project and generating
a key pair, that the API call to list key pairs works properly""" a key pair, that the API call to list key pairs works properly"""
self.expect_http()
self.mox.ReplayAll()
keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd")
for x in range(random.randint(4, 8))) for x in range(random.randint(4, 8)))
# NOTE(vish): create depends on pool, so call helper directly self.expect_http()
cloud._gen_key(context.get_admin_context(), 'fake', keyname) self.mox.ReplayAll()
self.ec2.create_key_pair(keyname)
rv = self.ec2.get_all_key_pairs() rv = self.ec2.get_all_key_pairs()
results = [k for k in rv if k.name == keyname] results = [k for k in rv if k.name == keyname]
self.assertEquals(len(results), 1) self.assertEquals(len(results), 1)
@ -306,9 +304,6 @@ class ApiEc2TestCase(test.TestCase):
requesting a second keypair with the same name fails sanely""" requesting a second keypair with the same name fails sanely"""
self.expect_http() self.expect_http()
self.mox.ReplayAll() self.mox.ReplayAll()
keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd")
for x in range(random.randint(4, 8)))
# NOTE(vish): create depends on pool, so call helper directly
self.ec2.create_key_pair('test') self.ec2.create_key_pair('test')
try: try:

View File

@ -3872,3 +3872,105 @@ class ComputeHostAPITestCase(BaseTestCase):
self.assertEqual(call_info['msg'], self.assertEqual(call_info['msg'],
{'method': 'host_maintenance_mode', {'method': 'host_maintenance_mode',
'args': {'host': 'fake_host', 'mode': 'fake_mode'}}) 'args': {'host': 'fake_host', 'mode': 'fake_mode'}})
class KeypairAPITestCase(BaseTestCase):
def setUp(self):
super(KeypairAPITestCase, self).setUp()
self.keypair_api = compute_api.KeypairAPI()
self.ctxt = context.RequestContext('fake', 'fake')
self._keypair_db_call_stubs()
self.pub_key = 'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDLnVkqJu9WVf' \
'/5StU3JCrBR2r1s1j8K1tux+5XeSvdqaM8lMFNorzbY5iyoBbRS56gy' \
'1jmm43QsMPJsrpfUZKcJpRENSe3OxIIwWXRoiapZe78u/a9xKwj0avF' \
'YMcws9Rk9iAB7W4K1nEJbyCPl5lRBoyqeHBqrnnuXWEgGxJCK0Ah6wc' \
'OzwlEiVjdf4kxzXrwPHyi7Ea1qvnNXTziF8yYmUlH4C8UXfpTQckwSw' \
'pDyxZUc63P8q+vPbs3Q2kw+/7vvkCKHJAXVI+oCiyMMfffoTq16M1xf' \
'V58JstgtTqAXG+ZFpicGajREUE/E3hO5MGgcHmyzIrWHKpe1n3oEGuz'
self.fingerprint = '4e:48:c6:a0:4a:f9:dd:b5:4c:85:54:5a:af:43:47:5a'
def _keypair_db_call_stubs(self):
def db_key_pair_get_all_by_user(self, user_id):
return []
def db_key_pair_create(self, keypair):
pass
def db_key_pair_destroy(context, user_id, name):
pass
self.stubs.Set(db, "key_pair_get_all_by_user",
db_key_pair_get_all_by_user)
self.stubs.Set(db, "key_pair_create",
db_key_pair_create)
self.stubs.Set(db, "key_pair_destroy",
db_key_pair_destroy)
def test_create_keypair(self):
keypair = self.keypair_api.create_key_pair(self.ctxt,
self.ctxt.user_id, 'foo')
self.assertEqual('foo', keypair['name'])
def test_create_keypair_name_too_long(self):
self.assertRaises(exception.InvalidKeypair,
self.keypair_api.create_key_pair,
self.ctxt, self.ctxt.user_id, 'x' * 256)
def test_create_keypair_invalid_chars(self):
self.assertRaises(exception.InvalidKeypair,
self.keypair_api.create_key_pair,
self.ctxt, self.ctxt.user_id, '* BAD CHARACTERS! *')
def test_create_keypair_already_exists(self):
def db_key_pair_get(context, user_id, name):
pass
self.stubs.Set(db, "key_pair_get",
db_key_pair_get)
self.assertRaises(exception.KeyPairExists,
self.keypair_api.create_key_pair,
self.ctxt, self.ctxt.user_id, 'foo')
def test_create_keypair_quota_limit(self):
def db_key_pair_count_by_user_max(self, user_id):
return FLAGS.quota_key_pairs
self.stubs.Set(db, "key_pair_count_by_user",
db_key_pair_count_by_user_max)
self.assertRaises(exception.KeypairLimitExceeded,
self.keypair_api.create_key_pair,
self.ctxt, self.ctxt.user_id, 'foo')
def test_import_keypair(self):
keypair = self.keypair_api.import_key_pair(self.ctxt,
self.ctxt.user_id,
'foo',
self.pub_key)
self.assertEqual('foo', keypair['name'])
self.assertEqual(self.fingerprint, keypair['fingerprint'])
self.assertEqual(self.pub_key, keypair['public_key'])
def test_import_keypair_bad_public_key(self):
self.assertRaises(exception.InvalidKeypair,
self.keypair_api.import_key_pair,
self.ctxt, self.ctxt.user_id, 'foo', 'bad key data')
def test_import_keypair_name_too_long(self):
self.assertRaises(exception.InvalidKeypair,
self.keypair_api.import_key_pair,
self.ctxt, self.ctxt.user_id, 'x' * 256,
self.pub_key)
def test_import_keypair_invalid_chars(self):
self.assertRaises(exception.InvalidKeypair,
self.keypair_api.import_key_pair,
self.ctxt, self.ctxt.user_id,
'* BAD CHARACTERS! *', self.pub_key)
def test_import_keypair_quota_limit(self):
def db_key_pair_count_by_user_max(self, user_id):
return FLAGS.quota_key_pairs
self.stubs.Set(db, "key_pair_count_by_user",
db_key_pair_count_by_user_max)
self.assertRaises(exception.KeypairLimitExceeded,
self.keypair_api.import_key_pair,
self.ctxt, self.ctxt.user_id, 'foo', self.pub_key)