objects: Stop querying the main DB for keypairs

This was migrated to the API DB during the 14.0.0 (Newton) release [1]
and the 345 migration introduced during the 15.0.0 (Ocata) release [2]
ensures we should no longer have any entries left in the main database.

The actual model isn't removed yet. That will be done separately.

[1] I5f6d88fee47dd87de2867d3947d65b04f0b21e8f
[2] Iab714d9e752c334cc1cc14a0d524cc9cf5d115dc

Change-Id: I15efc38258685375284d8a97004777385023c6e8
Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
Stephen Finucane 2021-09-27 14:12:55 +01:00
parent 1f648b4f77
commit 944033061c
8 changed files with 149 additions and 573 deletions

View File

@ -2520,78 +2520,6 @@ def instance_extra_get_by_instance_uuid(
###################
@require_context
@pick_context_manager_writer
def key_pair_create(context, values):
"""Create a key_pair from the values dictionary."""
try:
key_pair_ref = models.KeyPair()
key_pair_ref.update(values)
key_pair_ref.save(context.session)
return key_pair_ref
except db_exc.DBDuplicateEntry:
raise exception.KeyPairExists(key_name=values['name'])
@require_context
@pick_context_manager_writer
def key_pair_destroy(context, user_id, name):
"""Destroy the key_pair or raise if it does not exist."""
result = model_query(context, models.KeyPair).\
filter_by(user_id=user_id).\
filter_by(name=name).\
soft_delete()
if not result:
raise exception.KeypairNotFound(user_id=user_id, name=name)
@require_context
@pick_context_manager_reader
def key_pair_get(context, user_id, name):
"""Get a key_pair or raise if it does not exist."""
result = model_query(context, models.KeyPair).\
filter_by(user_id=user_id).\
filter_by(name=name).\
first()
if not result:
raise exception.KeypairNotFound(user_id=user_id, name=name)
return result
@require_context
@pick_context_manager_reader
def key_pair_get_all_by_user(context, user_id, limit=None, marker=None):
"""Get all key_pairs by user."""
marker_row = None
if marker is not None:
marker_row = model_query(context, models.KeyPair, read_deleted="no").\
filter_by(name=marker).filter_by(user_id=user_id).first()
if not marker_row:
raise exception.MarkerNotFound(marker=marker)
query = model_query(context, models.KeyPair, read_deleted="no").\
filter_by(user_id=user_id)
query = sqlalchemyutils.paginate_query(
query, models.KeyPair, limit, ['name'], marker=marker_row)
return query.all()
@require_context
@pick_context_manager_reader
def key_pair_count_by_user(context, user_id):
"""Count number of key pairs for the given user ID."""
return model_query(context, models.KeyPair, read_deleted="no").\
filter_by(user_id=user_id).\
count()
###################
@require_context
@pick_context_manager_reader
def quota_get(context, project_id, resource, user_id=None):

View File

@ -19,7 +19,6 @@ from oslo_utils import versionutils
from nova.db.api import api as api_db_api
from nova.db.api import models as api_models
from nova.db.main import api as main_db_api
from nova import exception
from nova import objects
from nova.objects import base
@ -134,40 +133,27 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject,
def _create_in_db(context, values):
return _create_in_db(context, values)
# TODO(stephenfin): Remove the 'localonly' parameter in v2.0
@base.remotable_classmethod
def get_by_name(cls, context, user_id, name,
localonly=False):
db_keypair = None
if not localonly:
try:
db_keypair = cls._get_from_db(context, user_id, name)
except exception.KeypairNotFound:
pass
if db_keypair is None:
db_keypair = main_db_api.key_pair_get(context, user_id, name)
def get_by_name(cls, context, user_id, name, localonly=False):
if localonly:
# There is no longer a "local" (main) table for keypairs, so this
# will always return nothing now
raise exception.KeypairNotFound(user_id=user_id, name=name)
db_keypair = cls._get_from_db(context, user_id, name)
return cls._from_db_object(context, cls(), db_keypair)
@base.remotable_classmethod
def destroy_by_name(cls, context, user_id, name):
try:
cls._destroy_in_db(context, user_id, name)
except exception.KeypairNotFound:
main_db_api.key_pair_destroy(context, user_id, name)
cls._destroy_in_db(context, user_id, name)
@base.remotable
def create(self):
if self.obj_attr_is_set('id'):
raise exception.ObjectActionError(action='create',
reason='already created')
# NOTE(danms): Check to see if it exists in the old DB before
# letting them create in the API DB, since we won't get protection
# from the UC.
try:
main_db_api.key_pair_get(self._context, self.user_id, self.name)
raise exception.KeyPairExists(key_name=self.name)
except exception.KeypairNotFound:
pass
raise exception.ObjectActionError(
action='create', reason='already created',
)
self._create()
@ -178,11 +164,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject,
@base.remotable
def destroy(self):
try:
self._destroy_in_db(self._context, self.user_id, self.name)
except exception.KeypairNotFound:
main_db_api.key_pair_destroy(
self._context, self.user_id, self.name)
self._destroy_in_db(self._context, self.user_id, self.name)
@base.NovaObjectRegistry.register
@ -208,31 +190,13 @@ class KeyPairList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod
def get_by_user(cls, context, user_id, limit=None, marker=None):
try:
api_db_keypairs = cls._get_from_db(
context, user_id, limit=limit, marker=marker)
# NOTE(pkholkin): If we were asked for a marker and found it in
# results from the API DB, we must continue our pagination with
# just the limit (if any) to the main DB.
marker = None
except exception.MarkerNotFound:
api_db_keypairs = []
api_db_keypairs = cls._get_from_db(
context, user_id, limit=limit, marker=marker)
if limit is not None:
limit_more = limit - len(api_db_keypairs)
else:
limit_more = None
if limit_more is None or limit_more > 0:
main_db_keypairs = main_db_api.key_pair_get_all_by_user(
context, user_id, limit=limit_more, marker=marker)
else:
main_db_keypairs = []
return base.obj_make_list(context, cls(context), objects.KeyPair,
api_db_keypairs + main_db_keypairs)
return base.obj_make_list(
context, cls(context), objects.KeyPair, api_db_keypairs,
)
@base.remotable_classmethod
def get_count_by_user(cls, context, user_id):
return (cls._get_count_from_db(context, user_id) +
main_db_api.key_pair_count_by_user(context, user_id))
return cls._get_count_from_db(context, user_id)

View File

@ -11,7 +11,6 @@
# under the License.
from nova import context
from nova.db.main import api as db_api
from nova import exception
from nova import objects
from nova.objects import keypair
@ -21,10 +20,10 @@ from nova import test
class KeyPairObjectTestCase(test.TestCase):
def setUp(self):
super(KeyPairObjectTestCase, self).setUp()
super().setUp()
self.context = context.RequestContext('fake-user', 'fake-project')
def _api_kp(self, **values):
def _create_keypair(self, **values):
kp = objects.KeyPair(context=self.context,
user_id=self.context.user_id,
name='fookey',
@ -35,152 +34,107 @@ class KeyPairObjectTestCase(test.TestCase):
kp.create()
return kp
def _main_kp(self, **values):
vals = {
'user_id': self.context.user_id,
'name': 'fookey',
'fingerprint': 'fp',
'public_key': 'keydata',
'type': 'ssh',
}
vals.update(values)
return db_api.key_pair_create(self.context, vals)
def test_create_in_api(self):
kp = self._api_kp()
def test_create(self):
kp = self._create_keypair()
keypair.KeyPair._get_from_db(self.context, kp.user_id, kp.name)
self.assertRaises(exception.KeypairNotFound,
db_api.key_pair_get,
self.context, kp.user_id, kp.name)
def test_create_in_api_duplicate(self):
self._api_kp()
self.assertRaises(exception.KeyPairExists, self._api_kp)
def test_create_duplicate(self):
self._create_keypair()
self.assertRaises(exception.KeyPairExists, self._create_keypair)
def test_create_in_api_duplicate_in_main(self):
self._main_kp()
self.assertRaises(exception.KeyPairExists, self._api_kp)
def test_get_from_api(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
def test_get(self):
self._create_keypair(name='key')
kp = objects.KeyPair.get_by_name(self.context, self.context.user_id,
'apikey')
self.assertEqual('apikey', kp.name)
def test_get_from_main(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
kp = objects.KeyPair.get_by_name(self.context, self.context.user_id,
'mainkey')
self.assertEqual('mainkey', kp.name)
'key')
self.assertEqual('key', kp.name)
def test_get_not_found(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
self._create_keypair(name='key')
self.assertRaises(exception.KeypairNotFound,
objects.KeyPair.get_by_name,
self.context, self.context.user_id, 'nokey')
def test_destroy_in_api(self):
kp = self._api_kp(name='apikey')
self._main_kp(name='mainkey')
def test_destroy(self):
kp = self._create_keypair(name='key')
kp.destroy()
self.assertRaises(exception.KeypairNotFound,
objects.KeyPair.get_by_name,
self.context, self.context.user_id, 'apikey')
self.context, self.context.user_id, 'key')
def test_destroy_by_name_in_api(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
def test_destroy_by_name(self):
self._create_keypair(name='key')
objects.KeyPair.destroy_by_name(self.context, self.context.user_id,
'apikey')
'key')
self.assertRaises(exception.KeypairNotFound,
objects.KeyPair.get_by_name,
self.context, self.context.user_id, 'apikey')
def test_destroy_in_main(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
kp = objects.KeyPair.get_by_name(self.context, self.context.user_id,
'mainkey')
kp.destroy()
self.assertRaises(exception.KeypairNotFound,
objects.KeyPair.get_by_name,
self.context, self.context.user_id, 'mainkey')
def test_destroy_by_name_in_main(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
objects.KeyPair.destroy_by_name(self.context, self.context.user_id,
'mainkey')
self.context, self.context.user_id, 'key')
def test_get_by_user(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
self._create_keypair(name='key1')
self._create_keypair(name='key2')
kpl = objects.KeyPairList.get_by_user(self.context,
self.context.user_id)
self.assertEqual(2, len(kpl))
self.assertEqual(set(['apikey', 'mainkey']),
self.assertEqual(set(['key1', 'key2']),
set([x.name for x in kpl]))
def test_get_count_by_user(self):
self._api_kp(name='apikey')
self._main_kp(name='mainkey')
self._create_keypair(name='key1')
self._create_keypair(name='key2')
count = objects.KeyPairList.get_count_by_user(self.context,
self.context.user_id)
self.assertEqual(2, count)
def test_get_by_user_limit_and_marker(self):
self._api_kp(name='apikey1')
self._api_kp(name='apikey2')
self._main_kp(name='mainkey1')
self._main_kp(name='mainkey2')
self._create_keypair(name='key1')
self._create_keypair(name='key2')
self._create_keypair(name='key3')
self._create_keypair(name='key4')
# check all 4 keypairs (2 api and 2 main)
# check all 4 keypairs
kpl = objects.KeyPairList.get_by_user(self.context,
self.context.user_id)
self.assertEqual(4, len(kpl))
self.assertEqual(set(['apikey1', 'apikey2', 'mainkey1', 'mainkey2']),
self.assertEqual(set(['key1', 'key2', 'key3', 'key4']),
set([x.name for x in kpl]))
# check only 1 keypair (1 api)
# check only 1 keypair
kpl = objects.KeyPairList.get_by_user(self.context,
self.context.user_id,
limit=1)
self.assertEqual(1, len(kpl))
self.assertEqual(set(['apikey1']),
self.assertEqual(set(['key1']),
set([x.name for x in kpl]))
# check only 3 keypairs (2 api and 1 main)
# check only 3 keypairs
kpl = objects.KeyPairList.get_by_user(self.context,
self.context.user_id,
limit=3)
self.assertEqual(3, len(kpl))
self.assertEqual(set(['apikey1', 'apikey2', 'mainkey1']),
self.assertEqual(set(['key1', 'key2', 'key3']),
set([x.name for x in kpl]))
# check keypairs after 'apikey1' (1 api and 2 main)
# check keypairs after 'key1' (3 keypairs)
kpl = objects.KeyPairList.get_by_user(self.context,
self.context.user_id,
marker='apikey1')
marker='key1')
self.assertEqual(3, len(kpl))
self.assertEqual(set(['apikey2', 'mainkey1', 'mainkey2']),
self.assertEqual(set(['key2', 'key3', 'key4']),
set([x.name for x in kpl]))
# check keypairs after 'mainkey2' (no keypairs)
# check keypairs after 'key4' (no keypairs)
kpl = objects.KeyPairList.get_by_user(self.context,
self.context.user_id,
marker='mainkey2')
marker='key4')
self.assertEqual(0, len(kpl))
# check only 2 keypairs after 'apikey1' (1 api and 1 main)
# check only 2 keypairs after 'key1' (2 keypairs)
kpl = objects.KeyPairList.get_by_user(self.context,
self.context.user_id,
limit=2,
marker='apikey1')
marker='key1')
self.assertEqual(2, len(kpl))
self.assertEqual(set(['apikey2', 'mainkey1']),
self.assertEqual(set(['key2', 'key3']),
set([x.name for x in kpl]))
# check non-existing keypair
@ -191,43 +145,43 @@ class KeyPairObjectTestCase(test.TestCase):
def test_get_by_user_different_users(self):
# create keypairs for two users
self._api_kp(name='apikey', user_id='user1')
self._api_kp(name='apikey', user_id='user2')
self._main_kp(name='mainkey', user_id='user1')
self._main_kp(name='mainkey', user_id='user2')
self._create_keypair(name='key1', user_id='user1')
self._create_keypair(name='key2', user_id='user1')
self._create_keypair(name='key1', user_id='user2')
self._create_keypair(name='key2', user_id='user2')
# check all 2 keypairs for user1 (1 api and 1 main)
# check all 2 keypairs for user1
kpl = objects.KeyPairList.get_by_user(self.context, 'user1')
self.assertEqual(2, len(kpl))
self.assertEqual(set(['apikey', 'mainkey']),
self.assertEqual(set(['key1', 'key2']),
set([x.name for x in kpl]))
# check all 2 keypairs for user2 (1 api and 1 main)
# check all 2 keypairs for user2
kpl = objects.KeyPairList.get_by_user(self.context, 'user2')
self.assertEqual(2, len(kpl))
self.assertEqual(set(['apikey', 'mainkey']),
self.assertEqual(set(['key1', 'key2']),
set([x.name for x in kpl]))
# check only 1 keypair for user1 (1 api)
# check only 1 keypair for user1
kpl = objects.KeyPairList.get_by_user(self.context, 'user1', limit=1)
self.assertEqual(1, len(kpl))
self.assertEqual(set(['apikey']),
self.assertEqual(set(['key1']),
set([x.name for x in kpl]))
# check keypairs after 'apikey' for user2 (1 main)
# check keypairs after 'key1' for user2 (1 keypair)
kpl = objects.KeyPairList.get_by_user(self.context, 'user2',
marker='apikey')
marker='key1')
self.assertEqual(1, len(kpl))
self.assertEqual(set(['mainkey']),
self.assertEqual(set(['key2']),
set([x.name for x in kpl]))
# check only 2 keypairs after 'apikey' for user1 (1 main)
# check only 2 keypairs after 'key1' for user1 (1 keypair)
kpl = objects.KeyPairList.get_by_user(self.context,
'user1',
limit=2,
marker='apikey')
marker='key1')
self.assertEqual(1, len(kpl))
self.assertEqual(set(['mainkey']),
self.assertEqual(set(['key2']),
set([x.name for x in kpl]))
# check non-existing keypair for user2

View File

@ -43,18 +43,29 @@ def fake_keypair(name):
name=name, **keypair_data)
def db_key_pair_get_all_by_user(self, user_id, limit, marker):
def _fake_get_from_db(context, user_id, name=None, limit=None, marker=None):
if name:
if name != 'FAKE':
raise exception.KeypairNotFound(user_id=user_id, name=name)
return fake_keypair('FAKE')
return [fake_keypair('FAKE')]
def db_key_pair_create(self, keypair):
return fake_keypair(name=keypair['name'])
def _fake_get_count_from_db(context, user_id):
return 1
def db_key_pair_destroy(context, user_id, name):
def _fake_create_in_db(context, values):
return fake_keypair(name=values['name'])
def _fake_destroy_in_db(context, user_id, name):
if not (user_id and name):
raise Exception()
if name != 'FAKE':
raise exception.KeypairNotFound(user_id=user_id, name=name)
def db_key_pair_create_duplicate(context):
raise exception.KeyPairExists(key_name='create_duplicate')
@ -74,12 +85,15 @@ class KeypairsTestV21(test.TestCase):
fakes.stub_out_networking(self)
fakes.stub_out_secgroup_api(self)
self.stub_out("nova.db.main.api.key_pair_get_all_by_user",
db_key_pair_get_all_by_user)
self.stub_out("nova.db.main.api.key_pair_create",
db_key_pair_create)
self.stub_out("nova.db.main.api.key_pair_destroy",
db_key_pair_destroy)
self.stub_out(
'nova.objects.keypair._create_in_db', _fake_create_in_db)
self.stub_out(
'nova.objects.keypair._destroy_in_db', _fake_destroy_in_db)
self.stub_out(
'nova.objects.keypair._get_from_db', _fake_get_from_db)
self.stub_out(
'nova.objects.keypair._get_count_from_db', _fake_get_count_from_db)
self._setup_app_and_controller()
self.req = fakes.HTTPRequest.blank('', version=self.wsgi_api_version)
@ -222,7 +236,7 @@ class KeypairsTestV21(test.TestCase):
mock_check.side_effect = [None, exc]
body = {
'keypair': {
'name': 'create_test',
'name': 'FAKE',
},
}
@ -278,39 +292,20 @@ class KeypairsTestV21(test.TestCase):
self.controller.show, self.req, 'DOESNOTEXIST')
def test_keypair_delete_not_found(self):
def db_key_pair_get_not_found(context, user_id, name):
raise exception.KeypairNotFound(user_id=user_id, name=name)
self.stub_out("nova.db.main.api.key_pair_destroy",
db_key_pair_get_not_found)
self.assertRaises(webob.exc.HTTPNotFound,
self.controller.delete, self.req, 'FAKE')
self.controller.delete, self.req, 'DOESNOTEXIST')
def test_keypair_show(self):
def _db_key_pair_get(context, user_id, name):
return dict(test_keypair.fake_keypair,
name='foo', public_key='XXX', fingerprint='YYY',
type='ssh')
self.stub_out("nova.db.main.api.key_pair_get", _db_key_pair_get)
res_dict = self.controller.show(self.req, 'FAKE')
self.assertEqual('foo', res_dict['keypair']['name'])
self.assertEqual('XXX', res_dict['keypair']['public_key'])
self.assertEqual('YYY', res_dict['keypair']['fingerprint'])
self.assertEqual('FAKE', res_dict['keypair']['name'])
self.assertEqual('FAKE_KEY', res_dict['keypair']['public_key'])
self.assertEqual(
'FAKE_FINGERPRINT', res_dict['keypair']['fingerprint'])
self._assert_keypair_type(res_dict)
def test_keypair_show_not_found(self):
def _db_key_pair_get(context, user_id, name):
raise exception.KeypairNotFound(user_id=user_id, name=name)
self.stub_out("nova.db.main.api.key_pair_get", _db_key_pair_get)
self.assertRaises(webob.exc.HTTPNotFound,
self.controller.show, self.req, 'FAKE')
self.controller.show, self.req, 'DOESNOTEXIST')
def _assert_keypair_type(self, res_dict):
self.assertNotIn('type', res_dict['keypair'])
@ -432,9 +427,9 @@ class KeypairsTestV235(test.TestCase):
super(KeypairsTestV235, self).setUp()
self._setup_app_and_controller()
@mock.patch("nova.db.main.api.key_pair_get_all_by_user")
@mock.patch('nova.objects.keypair._get_from_db')
def test_keypair_list_limit_and_marker(self, mock_kp_get):
mock_kp_get.side_effect = db_key_pair_get_all_by_user
mock_kp_get.side_effect = _fake_get_from_db
req = fakes.HTTPRequest.blank(
self.base_url + '/os-keypairs?limit=3&marker=fake_marker',
@ -467,10 +462,11 @@ class KeypairsTestV235(test.TestCase):
self.assertRaises(exception.ValidationError, self.controller.index,
req)
@mock.patch("nova.db.main.api.key_pair_get_all_by_user")
@mock.patch('nova.objects.keypair._get_from_db')
def test_keypair_list_limit_and_marker_invalid_in_old_microversion(
self, mock_kp_get):
mock_kp_get.side_effect = db_key_pair_get_all_by_user
self, mock_kp_get,
):
mock_kp_get.side_effect = _fake_get_from_db
req = fakes.HTTPRequest.blank(
self.base_url + '/os-keypairs?limit=3&marker=fake_marker',
@ -488,17 +484,14 @@ class KeypairsTestV275(test.TestCase):
super(KeypairsTestV275, self).setUp()
self.controller = keypairs_v21.KeypairController()
@mock.patch("nova.db.main.api.key_pair_get_all_by_user")
@mock.patch('nova.objects.KeyPair.get_by_name')
def test_keypair_list_additional_param_old_version(self, mock_get_by_name,
mock_kp_get):
def test_keypair_list_additional_param_old_version(self, mock_get_by_name):
req = fakes.HTTPRequest.blank(
'/os-keypairs?unknown=3',
version='2.74', use_admin_context=True)
self.controller.index(req)
self.controller.show(req, 1)
with mock.patch.object(self.controller.api,
'delete_key_pair'):
with mock.patch.object(self.controller.api, 'delete_key_pair'):
self.controller.delete(req, 1)
def test_keypair_list_additional_param(self):

View File

@ -105,12 +105,10 @@ def stub_out_key_pair_funcs(testcase, have_key_pair=True, **kwargs):
return []
if have_key_pair:
testcase.stub_out(
'nova.db.main.api.key_pair_get_all_by_user', key_pair)
testcase.stub_out('nova.db.main.api.key_pair_get', one_key_pair)
testcase.stub_out('nova.objects.KeyPairList._get_from_db', key_pair)
testcase.stub_out('nova.objects.KeyPair._get_from_db', one_key_pair)
else:
testcase.stub_out(
'nova.db.main.api.key_pair_get_all_by_user', no_key_pair)
testcase.stub_out('nova.objects.KeyPairList._get_from_db', key_pair)
def stub_out_instance_quota(test, allowed, quota, resource='instances'):

View File

@ -74,11 +74,13 @@ class KeypairAPITestCase(test_compute.BaseTestCase):
else:
raise exception.KeypairNotFound(user_id=user_id, name=name)
self.stub_out("nova.db.main.api.key_pair_get_all_by_user",
db_key_pair_get_all_by_user)
self.stub_out("nova.db.main.api.key_pair_create", db_key_pair_create)
self.stub_out("nova.db.main.api.key_pair_destroy", db_key_pair_destroy)
self.stub_out("nova.db.main.api.key_pair_get", db_key_pair_get)
self.stub_out(
'nova.objects.KeyPairList._get_from_db',
db_key_pair_get_all_by_user)
self.stub_out('nova.objects.KeyPair._create_in_db', db_key_pair_create)
self.stub_out(
'nova.objects.KeyPair._destroy_in_db', db_key_pair_destroy)
self.stub_out('nova.objects.KeyPair._get_from_db', db_key_pair_get)
def _check_notifications(self, action='create', key_name='foo'):
self.assertEqual(2, len(self.notifier.notifications))
@ -141,8 +143,8 @@ class CreateImportSharedTestMixIn(object):
def db_key_pair_create_duplicate(context, keypair):
raise exception.KeyPairExists(key_name=keypair.get('name', ''))
self.stub_out("nova.db.main.api.key_pair_create",
db_key_pair_create_duplicate)
self.stub_out(
'nova.objects.KeyPair._create_in_db', db_key_pair_create_duplicate)
msg = ("Key pair '%(key_name)s' already exists." %
{'key_name': self.existing_key_name})

View File

@ -4752,202 +4752,6 @@ class VirtualInterfaceTestCase(test.TestCase, ModelsObjectComparatorMixin):
self._assertEqualObjects(updated, updated_vif, ignored_keys)
class KeyPairTestCase(test.TestCase, ModelsObjectComparatorMixin):
def setUp(self):
super(KeyPairTestCase, self).setUp()
self.ctxt = context.get_admin_context()
def _create_key_pair(self, values):
return db.key_pair_create(self.ctxt, values)
def test_key_pair_create(self):
param = {
'name': 'test_1',
'type': 'ssh',
'user_id': 'test_user_id_1',
'public_key': 'test_public_key_1',
'fingerprint': 'test_fingerprint_1'
}
key_pair = self._create_key_pair(param)
self.assertIsNotNone(key_pair['id'])
ignored_keys = ['deleted', 'created_at', 'updated_at',
'deleted_at', 'id']
self._assertEqualObjects(key_pair, param, ignored_keys)
def test_key_pair_create_with_duplicate_name(self):
params = {'name': 'test_name', 'user_id': 'test_user_id',
'type': 'ssh'}
self._create_key_pair(params)
self.assertRaises(exception.KeyPairExists, self._create_key_pair,
params)
def test_key_pair_get(self):
params = [
{'name': 'test_1', 'user_id': 'test_user_id_1', 'type': 'ssh'},
{'name': 'test_2', 'user_id': 'test_user_id_2', 'type': 'ssh'},
{'name': 'test_3', 'user_id': 'test_user_id_3', 'type': 'ssh'}
]
key_pairs = [self._create_key_pair(p) for p in params]
for key in key_pairs:
real_key = db.key_pair_get(self.ctxt, key['user_id'], key['name'])
self._assertEqualObjects(key, real_key)
def test_key_pair_get_no_results(self):
param = {'name': 'test_1', 'user_id': 'test_user_id_1'}
self.assertRaises(exception.KeypairNotFound, db.key_pair_get,
self.ctxt, param['user_id'], param['name'])
def test_key_pair_get_deleted(self):
param = {'name': 'test_1', 'user_id': 'test_user_id_1', 'type': 'ssh'}
key_pair_created = self._create_key_pair(param)
db.key_pair_destroy(self.ctxt, param['user_id'], param['name'])
self.assertRaises(exception.KeypairNotFound, db.key_pair_get,
self.ctxt, param['user_id'], param['name'])
ctxt = self.ctxt.elevated(read_deleted='yes')
key_pair_deleted = db.key_pair_get(ctxt, param['user_id'],
param['name'])
ignored_keys = ['deleted', 'created_at', 'updated_at', 'deleted_at']
self._assertEqualObjects(key_pair_deleted, key_pair_created,
ignored_keys)
self.assertEqual(key_pair_deleted['deleted'], key_pair_deleted['id'])
def test_key_pair_get_all_by_user(self):
params = [
{'name': 'test_1', 'user_id': 'test_user_id_1', 'type': 'ssh'},
{'name': 'test_2', 'user_id': 'test_user_id_1', 'type': 'ssh'},
{'name': 'test_3', 'user_id': 'test_user_id_2', 'type': 'ssh'}
]
key_pairs_user_1 = [self._create_key_pair(p) for p in params
if p['user_id'] == 'test_user_id_1']
key_pairs_user_2 = [self._create_key_pair(p) for p in params
if p['user_id'] == 'test_user_id_2']
real_keys_1 = db.key_pair_get_all_by_user(self.ctxt, 'test_user_id_1')
real_keys_2 = db.key_pair_get_all_by_user(self.ctxt, 'test_user_id_2')
self._assertEqualListsOfObjects(key_pairs_user_1, real_keys_1)
self._assertEqualListsOfObjects(key_pairs_user_2, real_keys_2)
def test_key_pair_get_all_by_user_limit_and_marker(self):
params = [
{'name': 'test_1', 'user_id': 'test_user_id', 'type': 'ssh'},
{'name': 'test_2', 'user_id': 'test_user_id', 'type': 'ssh'},
{'name': 'test_3', 'user_id': 'test_user_id', 'type': 'ssh'}
]
# check all 3 keypairs
keys = [self._create_key_pair(p) for p in params]
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_id')
self._assertEqualListsOfObjects(keys, db_keys)
# check only 1 keypair
expected_keys = [keys[0]]
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_id',
limit=1)
self._assertEqualListsOfObjects(expected_keys, db_keys)
# check keypairs after 'test_1'
expected_keys = [keys[1], keys[2]]
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_id',
marker='test_1')
self._assertEqualListsOfObjects(expected_keys, db_keys)
# check only 1 keypairs after 'test_1'
expected_keys = [keys[1]]
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_id',
limit=1,
marker='test_1')
self._assertEqualListsOfObjects(expected_keys, db_keys)
# check non-existing keypair
self.assertRaises(exception.MarkerNotFound,
db.key_pair_get_all_by_user,
self.ctxt, 'test_user_id',
limit=1, marker='unknown_kp')
def test_key_pair_get_all_by_user_different_users(self):
params1 = [
{'name': 'test_1', 'user_id': 'test_user_1', 'type': 'ssh'},
{'name': 'test_2', 'user_id': 'test_user_1', 'type': 'ssh'},
{'name': 'test_3', 'user_id': 'test_user_1', 'type': 'ssh'}
]
params2 = [
{'name': 'test_1', 'user_id': 'test_user_2', 'type': 'ssh'},
{'name': 'test_2', 'user_id': 'test_user_2', 'type': 'ssh'},
{'name': 'test_3', 'user_id': 'test_user_2', 'type': 'ssh'}
]
# create keypairs for two users
keys1 = [self._create_key_pair(p) for p in params1]
keys2 = [self._create_key_pair(p) for p in params2]
# check all 2 keypairs for test_user_1
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_1')
self._assertEqualListsOfObjects(keys1, db_keys)
# check all 2 keypairs for test_user_2
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_2')
self._assertEqualListsOfObjects(keys2, db_keys)
# check only 1 keypair for test_user_1
expected_keys = [keys1[0]]
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_1',
limit=1)
self._assertEqualListsOfObjects(expected_keys, db_keys)
# check keypairs after 'test_1' for test_user_2
expected_keys = [keys2[1], keys2[2]]
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_2',
marker='test_1')
self._assertEqualListsOfObjects(expected_keys, db_keys)
# check only 1 keypairs after 'test_1' for test_user_1
expected_keys = [keys1[1]]
db_keys = db.key_pair_get_all_by_user(self.ctxt, 'test_user_1',
limit=1,
marker='test_1')
self._assertEqualListsOfObjects(expected_keys, db_keys)
# check non-existing keypair for test_user_2
self.assertRaises(exception.MarkerNotFound,
db.key_pair_get_all_by_user,
self.ctxt, 'test_user_2',
limit=1, marker='unknown_kp')
def test_key_pair_count_by_user(self):
params = [
{'name': 'test_1', 'user_id': 'test_user_id_1', 'type': 'ssh'},
{'name': 'test_2', 'user_id': 'test_user_id_1', 'type': 'ssh'},
{'name': 'test_3', 'user_id': 'test_user_id_2', 'type': 'ssh'}
]
for p in params:
self._create_key_pair(p)
count_1 = db.key_pair_count_by_user(self.ctxt, 'test_user_id_1')
self.assertEqual(count_1, 2)
count_2 = db.key_pair_count_by_user(self.ctxt, 'test_user_id_2')
self.assertEqual(count_2, 1)
def test_key_pair_destroy(self):
param = {'name': 'test_1', 'user_id': 'test_user_id_1', 'type': 'ssh'}
self._create_key_pair(param)
db.key_pair_destroy(self.ctxt, param['user_id'], param['name'])
self.assertRaises(exception.KeypairNotFound, db.key_pair_get,
self.ctxt, param['user_id'], param['name'])
def test_key_pair_destroy_no_such_key(self):
param = {'name': 'test_1', 'user_id': 'test_user_id_1'}
self.assertRaises(exception.KeypairNotFound,
db.key_pair_destroy, self.ctxt,
param['user_id'], param['name'])
class QuotaTestCase(test.TestCase, ModelsObjectComparatorMixin):
"""Tests for db.api.quota_* methods."""

View File

@ -12,8 +12,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import copy
import mock
from oslo_utils import timeutils
@ -33,27 +31,11 @@ fake_keypair = {
'user_id': 'fake-user',
'fingerprint': 'fake-fingerprint',
'public_key': 'fake\npublic\nkey',
}
}
class _TestKeyPairObject(object):
@mock.patch('nova.db.main.api.key_pair_get')
@mock.patch('nova.objects.KeyPair._get_from_db')
def test_get_by_name_main(self, mock_api_get, mock_kp_get):
mock_api_get.side_effect = exception.KeypairNotFound(user_id='foo',
name='foo')
mock_kp_get.return_value = fake_keypair
keypair_obj = keypair.KeyPair.get_by_name(self.context, 'fake-user',
'foo-keypair')
self.compare_obj(keypair_obj, fake_keypair)
mock_kp_get.assert_called_once_with(self.context, 'fake-user',
'foo-keypair')
mock_api_get.assert_called_once_with(self.context, 'fake-user',
'foo-keypair')
@mock.patch('nova.objects.KeyPair._create_in_db')
def test_create(self, mock_kp_create):
mock_kp_create.return_value = fake_keypair
@ -103,26 +85,20 @@ class _TestKeyPairObject(object):
mock_kp_destroy.assert_called_once_with(
self.context, 'fake-user', 'foo-keypair')
@mock.patch('nova.db.main.api.key_pair_get_all_by_user')
@mock.patch('nova.db.main.api.key_pair_count_by_user')
@mock.patch('nova.objects.KeyPairList._get_from_db')
@mock.patch('nova.objects.KeyPairList._get_count_from_db')
def test_get_by_user(self, mock_api_count, mock_api_get, mock_kp_count,
mock_kp_get):
mock_kp_get.return_value = [fake_keypair]
mock_kp_count.return_value = 1
def test_get_by_user(self, mock_api_count, mock_api_get):
mock_api_get.return_value = [fake_keypair]
mock_api_count.return_value = 1
keypairs = keypair.KeyPairList.get_by_user(self.context, 'fake-user')
self.assertEqual(2, len(keypairs))
self.assertEqual(1, len(keypairs))
self.compare_obj(keypairs[0], fake_keypair)
self.compare_obj(keypairs[1], fake_keypair)
self.assertEqual(2, keypair.KeyPairList.get_count_by_user(self.context,
'fake-user'))
mock_kp_get.assert_called_once_with(self.context, 'fake-user',
limit=None, marker=None)
mock_kp_count.assert_called_once_with(self.context, 'fake-user')
keypair_count = keypair.KeyPairList.get_count_by_user(
self.context, 'fake-user')
self.assertEqual(1, keypair_count)
mock_api_get.assert_called_once_with(self.context, 'fake-user',
limit=None, marker=None)
mock_api_count.assert_called_once_with(self.context, 'fake-user')
@ -134,29 +110,21 @@ class _TestKeyPairObject(object):
keypair_obj.obj_make_compatible(fake_keypair_copy, '1.1')
self.assertNotIn('type', fake_keypair_copy)
@mock.patch('nova.db.main.api.key_pair_get_all_by_user')
@mock.patch('nova.objects.KeyPairList._get_from_db')
def test_get_by_user_limit(self, mock_api_get, mock_kp_get):
api_keypair = copy.deepcopy(fake_keypair)
api_keypair['name'] = 'api_kp'
def test_get_by_user_limit(self, mock_api_get):
mock_api_get.return_value = [fake_keypair]
mock_api_get.return_value = [api_keypair]
mock_kp_get.return_value = [fake_keypair]
keypairs = keypair.KeyPairList.get_by_user(self.context, 'fake-user',
limit=1)
keypairs = keypair.KeyPairList.get_by_user(
self.context, 'fake-user', limit=1)
self.assertEqual(1, len(keypairs))
self.compare_obj(keypairs[0], api_keypair)
self.compare_obj(keypairs[0], fake_keypair)
mock_api_get.assert_called_once_with(self.context, 'fake-user',
limit=1, marker=None)
self.assertFalse(mock_kp_get.called)
@mock.patch('nova.db.main.api.key_pair_get_all_by_user')
@mock.patch('nova.objects.KeyPairList._get_from_db')
def test_get_by_user_marker(self, mock_api_get, mock_kp_get):
def test_get_by_user_marker(self, mock_api_get):
api_kp_name = 'api_kp'
mock_api_get.side_effect = exception.MarkerNotFound(marker=api_kp_name)
mock_kp_get.return_value = [fake_keypair]
mock_api_get.return_value = [fake_keypair]
keypairs = keypair.KeyPairList.get_by_user(self.context, 'fake-user',
marker=api_kp_name)
@ -165,59 +133,26 @@ class _TestKeyPairObject(object):
mock_api_get.assert_called_once_with(self.context, 'fake-user',
limit=None,
marker=api_kp_name)
mock_kp_get.assert_called_once_with(self.context, 'fake-user',
limit=None,
marker=api_kp_name)
@mock.patch('nova.db.main.api.key_pair_get_all_by_user')
@mock.patch('nova.objects.KeyPairList._get_from_db')
def test_get_by_user_limit_and_marker_api(self, mock_api_get, mock_kp_get):
def test_get_by_user_limit_and_marker_api(self, mock_api_get):
first_api_kp_name = 'first_api_kp'
api_keypair = copy.deepcopy(fake_keypair)
api_keypair['name'] = 'api_kp'
mock_api_get.return_value = [api_keypair]
mock_kp_get.return_value = [fake_keypair]
mock_api_get.return_value = [fake_keypair]
keypairs = keypair.KeyPairList.get_by_user(self.context, 'fake-user',
limit=5,
marker=first_api_kp_name)
self.assertEqual(2, len(keypairs))
self.compare_obj(keypairs[0], api_keypair)
self.compare_obj(keypairs[1], fake_keypair)
mock_api_get.assert_called_once_with(self.context, 'fake-user',
limit=5,
marker=first_api_kp_name)
mock_kp_get.assert_called_once_with(self.context, 'fake-user',
limit=4, marker=None)
@mock.patch('nova.db.main.api.key_pair_get_all_by_user')
@mock.patch('nova.objects.KeyPairList._get_from_db')
def test_get_by_user_limit_and_marker_main(self, mock_api_get,
mock_kp_get):
first_main_kp_name = 'first_main_kp'
mock_api_get.side_effect = exception.MarkerNotFound(
marker=first_main_kp_name)
mock_kp_get.return_value = [fake_keypair]
keypairs = keypair.KeyPairList.get_by_user(self.context, 'fake-user',
limit=5,
marker=first_main_kp_name)
self.assertEqual(1, len(keypairs))
self.compare_obj(keypairs[0], fake_keypair)
mock_api_get.assert_called_once_with(self.context, 'fake-user',
limit=5,
marker=first_main_kp_name)
mock_kp_get.assert_called_once_with(self.context, 'fake-user',
limit=5, marker=first_main_kp_name)
marker=first_api_kp_name)
@mock.patch('nova.db.main.api.key_pair_get_all_by_user')
@mock.patch('nova.objects.KeyPairList._get_from_db')
def test_get_by_user_limit_and_marker_invalid_marker(
self, mock_api_get, mock_kp_get):
def test_get_by_user_limit_and_marker_invalid_marker(self, mock_api_get):
kp_name = 'unknown_kp'
mock_api_get.side_effect = exception.MarkerNotFound(marker=kp_name)
mock_kp_get.side_effect = exception.MarkerNotFound(marker=kp_name)
self.assertRaises(exception.MarkerNotFound,
keypair.KeyPairList.get_by_user,
@ -225,11 +160,9 @@ class _TestKeyPairObject(object):
limit=5, marker=kp_name)
class TestMigrationObject(test_objects._LocalTest,
_TestKeyPairObject):
class TestMigrationObject(test_objects._LocalTest, _TestKeyPairObject):
pass
class TestRemoteMigrationObject(test_objects._RemoteTest,
_TestKeyPairObject):
class TestRemoteMigrationObject(test_objects._RemoteTest, _TestKeyPairObject):
pass