Make AggregateList.get_ return API & cell db items

Make the AggregateList get_all, get_by_hosts and get_by_metadata
functions return items from both the API and cell databases.

blueprint cells-aggregate-api-db

Change-Id: I0e1e11283e6ab2f7d2b976d126e914f9d7b4d8fb
This commit is contained in:
Mark Doffman 2016-03-22 14:46:56 -05:00
parent 5dad61b1f3
commit 92bcd0ae42
2 changed files with 139 additions and 29 deletions

View File

@ -15,6 +15,7 @@
from oslo_log import log as logging from oslo_log import log as logging
from oslo_utils import uuidutils from oslo_utils import uuidutils
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from nova.compute import utils as compute_utils from nova.compute import utils as compute_utils
@ -212,6 +213,41 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject):
return self.metadata.get('availability_zone', None) return self.metadata.get('availability_zone', None)
@db_api.api_context_manager.reader
def _get_all_from_db(context):
query = context.session.query(api_models.Aggregate).\
options(joinedload('_hosts')).\
options(joinedload('_metadata'))
return query.all()
@db_api.api_context_manager.reader
def _get_by_host_from_db(context, host, key=None):
query = context.session.query(api_models.Aggregate).\
options(joinedload('_hosts')).\
options(joinedload('_metadata'))
query = query.join('_hosts')
query = query.filter(api_models.AggregateHost.host == host)
if key:
query = query.join("_metadata").filter(
api_models.AggregateMetadata.key == key)
return query.all()
@db_api.api_context_manager.reader
def _get_by_metadata_key_from_db(context, key):
query = context.session.query(api_models.Aggregate)
query = query.join("_metadata")
query = query.filter(api_models.AggregateMetadata.key == key)
query = query.options(contains_eager("_metadata"))
query = query.options(joinedload("_hosts"))
return query.all()
@base.NovaObjectRegistry.register @base.NovaObjectRegistry.register
class AggregateList(base.ObjectListBase, base.NovaObject): class AggregateList(base.ObjectListBase, base.NovaObject):
# Version 1.0: Initial version # Version 1.0: Initial version
@ -224,6 +260,14 @@ class AggregateList(base.ObjectListBase, base.NovaObject):
'objects': fields.ListOfObjectsField('Aggregate'), 'objects': fields.ListOfObjectsField('Aggregate'),
} }
# NOTE(mdoff): Calls to this can be removed when we remove
# compatibility with the old aggregate model.
@staticmethod
def _fill_deprecated(db_aggregate):
db_aggregate['deleted_at'] = None
db_aggregate['deleted'] = False
return db_aggregate
@classmethod @classmethod
def _filter_db_aggregates(cls, db_aggregates, hosts): def _filter_db_aggregates(cls, db_aggregates, hosts):
if not isinstance(hosts, set): if not isinstance(hosts, set):
@ -238,20 +282,28 @@ class AggregateList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_all(cls, context): def get_all(cls, context):
api_db_aggregates = [cls._fill_deprecated(agg) for agg in
_get_all_from_db(context)]
db_aggregates = db.aggregate_get_all(context) db_aggregates = db.aggregate_get_all(context)
return base.obj_make_list(context, cls(context), objects.Aggregate, return base.obj_make_list(context, cls(context), objects.Aggregate,
db_aggregates) db_aggregates + api_db_aggregates)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_host(cls, context, host, key=None): def get_by_host(cls, context, host, key=None):
api_db_aggregates = [cls._fill_deprecated(agg) for agg in
_get_by_host_from_db(context, host, key=key)]
db_aggregates = db.aggregate_get_by_host(context, host, key=key) db_aggregates = db.aggregate_get_by_host(context, host, key=key)
return base.obj_make_list(context, cls(context), objects.Aggregate, return base.obj_make_list(context, cls(context), objects.Aggregate,
db_aggregates) db_aggregates + api_db_aggregates)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_metadata_key(cls, context, key, hosts=None): def get_by_metadata_key(cls, context, key, hosts=None):
api_db_aggregates = [cls._fill_deprecated(agg) for agg in
_get_by_metadata_key_from_db(context, key=key)]
db_aggregates = db.aggregate_get_by_metadata_key(context, key=key) db_aggregates = db.aggregate_get_by_metadata_key(context, key=key)
all_aggregates = db_aggregates + api_db_aggregates
if hosts is not None: if hosts is not None:
db_aggregates = cls._filter_db_aggregates(db_aggregates, hosts) all_aggregates = cls._filter_db_aggregates(all_aggregates, hosts)
return base.obj_make_list(context, cls(context), objects.Aggregate, return base.obj_make_list(context, cls(context), objects.Aggregate,
db_aggregates) all_aggregates)

View File

@ -66,12 +66,13 @@ def _create_aggregate(context, values=fake_db_aggregate_values,
aggregate.update(values) aggregate.update(values)
aggregate.save(context.session) aggregate.save(context.session)
for key, value in metadata.items(): if metadata:
aggregate_metadata = api_models.AggregateMetadata() for key, value in metadata.items():
aggregate_metadata.update({'key': key, aggregate_metadata = api_models.AggregateMetadata()
'value': value, aggregate_metadata.update({'key': key,
'aggregate_id': aggregate['id']}) 'value': value,
aggregate_metadata.save(context.session) 'aggregate_id': aggregate['id']})
aggregate_metadata.save(context.session)
return aggregate return aggregate
@ -82,10 +83,10 @@ def _create_aggregate_with_hosts(context, values=fake_db_aggregate_values,
hosts=fake_db_aggregate_hosts): hosts=fake_db_aggregate_hosts):
aggregate = _create_aggregate(context, values, metadata) aggregate = _create_aggregate(context, values, metadata)
for host in hosts: for host in hosts:
host = api_models.AggregateHost() host_model = api_models.AggregateHost()
host.update({'host': 'foo.openstack.org', host_model.update({'host': host,
'aggregate_id': aggregate.id}) 'aggregate_id': aggregate.id})
host.save(context.session) host_model.save(context.session)
return aggregate return aggregate
@ -94,7 +95,7 @@ class _TestAggregateObject(object):
def test_aggregate_get_from_db(self): def test_aggregate_get_from_db(self):
result = _create_aggregate_with_hosts(self.context) result = _create_aggregate_with_hosts(self.context)
expected = aggregate._aggregate_get_from_db(self.context, result['id']) expected = aggregate._aggregate_get_from_db(self.context, result['id'])
self.assertEqual(fake_db_aggregate_hosts, expected['hosts']) self.assertEqual(fake_db_aggregate_hosts, expected.hosts)
self.assertEqual(fake_db_aggregate_metadata, expected['metadetails']) self.assertEqual(fake_db_aggregate_metadata, expected['metadetails'])
def test_aggregate_get_from_db_raise_not_found(self): def test_aggregate_get_from_db_raise_not_found(self):
@ -103,6 +104,56 @@ class _TestAggregateObject(object):
aggregate._aggregate_get_from_db, aggregate._aggregate_get_from_db,
self.context, aggregate_id) self.context, aggregate_id)
def test_aggregate_get_all_from_db(self):
for c in range(3):
_create_aggregate(self.context,
values={'name': 'fake_aggregate_%d' % c})
results = aggregate._get_all_from_db(self.context)
self.assertEqual(len(results), 3)
def test_aggregate_get_by_host_from_db(self):
_create_aggregate_with_hosts(self.context,
values={'name': 'fake_aggregate_1'},
hosts=['host.1.openstack.org'])
_create_aggregate_with_hosts(self.context,
values={'name': 'fake_aggregate_2'},
hosts=['host.1.openstack.org'])
_create_aggregate(self.context,
values={'name': 'no_host_aggregate'})
rh1 = aggregate._get_all_from_db(self.context)
rh2 = aggregate._get_by_host_from_db(self.context,
'host.1.openstack.org')
self.assertEqual(3, len(rh1))
self.assertEqual(2, len(rh2))
def test_aggregate_get_by_host_with_key_from_db(self):
ah1 = _create_aggregate_with_hosts(self.context,
values={'name': 'fake_aggregate_1'},
metadata={'goodkey': 'good'},
hosts=['host.1.openstack.org'])
_create_aggregate_with_hosts(self.context,
values={'name': 'fake_aggregate_2'},
hosts=['host.1.openstack.org'])
rh1 = aggregate._get_by_host_from_db(self.context,
'host.1.openstack.org',
key='goodkey')
self.assertEqual(1, len(rh1))
self.assertEqual(ah1['id'], rh1[0]['id'])
def test_aggregate_get_by_metadata_key_from_db(self):
_create_aggregate(self.context,
values={'name': 'aggregate_1'},
metadata={'goodkey': 'good'})
_create_aggregate(self.context,
values={'name': 'aggregate_2'},
metadata={'goodkey': 'bad'})
_create_aggregate(self.context,
values={'name': 'aggregate_3'},
metadata={'badkey': 'good'})
rl1 = aggregate._get_by_metadata_key_from_db(self.context,
key='goodkey')
self.assertEqual(2, len(rl1))
@mock.patch('nova.objects.aggregate._aggregate_get_from_db') @mock.patch('nova.objects.aggregate._aggregate_get_from_db')
@mock.patch('nova.db.aggregate_get') @mock.patch('nova.db.aggregate_get')
def test_get_by_id_from_api(self, mock_get, mock_get_api): def test_get_by_id_from_api(self, mock_get, mock_get_api):
@ -247,29 +298,36 @@ class _TestAggregateObject(object):
agg.metadata = {'availability_zone': 'foo'} agg.metadata = {'availability_zone': 'foo'}
self.assertEqual('foo', agg.availability_zone) self.assertEqual('foo', agg.availability_zone)
def test_get_all(self): @mock.patch('nova.objects.aggregate._get_all_from_db')
self.mox.StubOutWithMock(db, 'aggregate_get_all') @mock.patch('nova.db.aggregate_get_all')
db.aggregate_get_all(self.context).AndReturn([fake_aggregate]) def test_get_all(self, mock_get_all, mock_api_get_all):
self.mox.ReplayAll() mock_get_all.return_value = [fake_aggregate]
mock_api_get_all.return_value = [fake_api_aggregate]
aggs = aggregate.AggregateList.get_all(self.context) aggs = aggregate.AggregateList.get_all(self.context)
self.assertEqual(1, len(aggs)) self.assertEqual(2, len(aggs))
self.compare_obj(aggs[0], fake_aggregate, subs=SUBS) self.compare_obj(aggs[0], fake_aggregate, subs=SUBS)
self.compare_obj(aggs[1], fake_api_aggregate, subs=SUBS)
def test_by_host(self): @mock.patch('nova.objects.aggregate._get_by_host_from_db')
self.mox.StubOutWithMock(db, 'aggregate_get_by_host') @mock.patch('nova.db.aggregate_get_by_host')
db.aggregate_get_by_host(self.context, 'fake-host', key=None, def test_by_host(self, mock_get_by_host, mock_api_get_by_host):
).AndReturn([fake_aggregate]) mock_get_by_host.return_value = [fake_aggregate]
self.mox.ReplayAll() mock_api_get_by_host.return_value = [fake_api_aggregate]
aggs = aggregate.AggregateList.get_by_host(self.context, 'fake-host') aggs = aggregate.AggregateList.get_by_host(self.context, 'fake-host')
self.assertEqual(1, len(aggs)) self.assertEqual(2, len(aggs))
self.compare_obj(aggs[0], fake_aggregate, subs=SUBS) self.compare_obj(aggs[0], fake_aggregate, subs=SUBS)
self.compare_obj(aggs[1], fake_api_aggregate, subs=SUBS)
@mock.patch('nova.objects.aggregate._get_by_metadata_key_from_db')
@mock.patch('nova.db.aggregate_get_by_metadata_key') @mock.patch('nova.db.aggregate_get_by_metadata_key')
def test_get_by_metadata_key(self, get_by_metadata_key): def test_get_by_metadata_key(self,
get_by_metadata_key.return_value = [fake_aggregate] mock_get_by_metadata_key,
mock_api_get_by_metadata_key):
mock_get_by_metadata_key.return_value = [fake_aggregate]
mock_api_get_by_metadata_key.return_value = [fake_api_aggregate]
aggs = aggregate.AggregateList.get_by_metadata_key( aggs = aggregate.AggregateList.get_by_metadata_key(
self.context, 'this') self.context, 'this')
self.assertEqual(1, len(aggs)) self.assertEqual(2, len(aggs))
self.compare_obj(aggs[0], fake_aggregate, subs=SUBS) self.compare_obj(aggs[0], fake_aggregate, subs=SUBS)
@mock.patch('nova.db.aggregate_get_by_metadata_key') @mock.patch('nova.db.aggregate_get_by_metadata_key')