diff --git a/nova/objects/aggregate.py b/nova/objects/aggregate.py index 4db09b2b6e7b..210524970762 100644 --- a/nova/objects/aggregate.py +++ b/nova/objects/aggregate.py @@ -15,6 +15,7 @@ from oslo_log import log as logging from oslo_utils import uuidutils +from sqlalchemy.orm import contains_eager from sqlalchemy.orm import joinedload from nova.compute import utils as compute_utils @@ -212,6 +213,41 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject): return self.metadata.get('availability_zone', None) +@db_api.api_context_manager.reader +def _get_all_from_db(context): + query = context.session.query(api_models.Aggregate).\ + options(joinedload('_hosts')).\ + options(joinedload('_metadata')) + + return query.all() + + +@db_api.api_context_manager.reader +def _get_by_host_from_db(context, host, key=None): + query = context.session.query(api_models.Aggregate).\ + options(joinedload('_hosts')).\ + options(joinedload('_metadata')) + query = query.join('_hosts') + query = query.filter(api_models.AggregateHost.host == host) + + if key: + query = query.join("_metadata").filter( + api_models.AggregateMetadata.key == key) + + return query.all() + + +@db_api.api_context_manager.reader +def _get_by_metadata_key_from_db(context, key): + query = context.session.query(api_models.Aggregate) + query = query.join("_metadata") + query = query.filter(api_models.AggregateMetadata.key == key) + query = query.options(contains_eager("_metadata")) + query = query.options(joinedload("_hosts")) + + return query.all() + + @base.NovaObjectRegistry.register class AggregateList(base.ObjectListBase, base.NovaObject): # Version 1.0: Initial version @@ -224,6 +260,14 @@ class AggregateList(base.ObjectListBase, base.NovaObject): 'objects': fields.ListOfObjectsField('Aggregate'), } + # NOTE(mdoff): Calls to this can be removed when we remove + # compatibility with the old aggregate model. + @staticmethod + def _fill_deprecated(db_aggregate): + db_aggregate['deleted_at'] = None + db_aggregate['deleted'] = False + return db_aggregate + @classmethod def _filter_db_aggregates(cls, db_aggregates, hosts): if not isinstance(hosts, set): @@ -238,20 +282,28 @@ class AggregateList(base.ObjectListBase, base.NovaObject): @base.remotable_classmethod def get_all(cls, context): + api_db_aggregates = [cls._fill_deprecated(agg) for agg in + _get_all_from_db(context)] db_aggregates = db.aggregate_get_all(context) return base.obj_make_list(context, cls(context), objects.Aggregate, - db_aggregates) + db_aggregates + api_db_aggregates) @base.remotable_classmethod def get_by_host(cls, context, host, key=None): + api_db_aggregates = [cls._fill_deprecated(agg) for agg in + _get_by_host_from_db(context, host, key=key)] db_aggregates = db.aggregate_get_by_host(context, host, key=key) return base.obj_make_list(context, cls(context), objects.Aggregate, - db_aggregates) + db_aggregates + api_db_aggregates) @base.remotable_classmethod def get_by_metadata_key(cls, context, key, hosts=None): + api_db_aggregates = [cls._fill_deprecated(agg) for agg in + _get_by_metadata_key_from_db(context, key=key)] db_aggregates = db.aggregate_get_by_metadata_key(context, key=key) + + all_aggregates = db_aggregates + api_db_aggregates if hosts is not None: - db_aggregates = cls._filter_db_aggregates(db_aggregates, hosts) + all_aggregates = cls._filter_db_aggregates(all_aggregates, hosts) return base.obj_make_list(context, cls(context), objects.Aggregate, - db_aggregates) + all_aggregates) diff --git a/nova/tests/unit/objects/test_aggregate.py b/nova/tests/unit/objects/test_aggregate.py index 7a39176bd2cf..ade7d05d7521 100644 --- a/nova/tests/unit/objects/test_aggregate.py +++ b/nova/tests/unit/objects/test_aggregate.py @@ -66,12 +66,13 @@ def _create_aggregate(context, values=fake_db_aggregate_values, aggregate.update(values) aggregate.save(context.session) - for key, value in metadata.items(): - aggregate_metadata = api_models.AggregateMetadata() - aggregate_metadata.update({'key': key, - 'value': value, - 'aggregate_id': aggregate['id']}) - aggregate_metadata.save(context.session) + if metadata: + for key, value in metadata.items(): + aggregate_metadata = api_models.AggregateMetadata() + aggregate_metadata.update({'key': key, + 'value': value, + 'aggregate_id': aggregate['id']}) + aggregate_metadata.save(context.session) return aggregate @@ -82,10 +83,10 @@ def _create_aggregate_with_hosts(context, values=fake_db_aggregate_values, hosts=fake_db_aggregate_hosts): aggregate = _create_aggregate(context, values, metadata) for host in hosts: - host = api_models.AggregateHost() - host.update({'host': 'foo.openstack.org', - 'aggregate_id': aggregate.id}) - host.save(context.session) + host_model = api_models.AggregateHost() + host_model.update({'host': host, + 'aggregate_id': aggregate.id}) + host_model.save(context.session) return aggregate @@ -94,7 +95,7 @@ class _TestAggregateObject(object): def test_aggregate_get_from_db(self): result = _create_aggregate_with_hosts(self.context) expected = aggregate._aggregate_get_from_db(self.context, result['id']) - self.assertEqual(fake_db_aggregate_hosts, expected['hosts']) + self.assertEqual(fake_db_aggregate_hosts, expected.hosts) self.assertEqual(fake_db_aggregate_metadata, expected['metadetails']) def test_aggregate_get_from_db_raise_not_found(self): @@ -103,6 +104,56 @@ class _TestAggregateObject(object): aggregate._aggregate_get_from_db, self.context, aggregate_id) + def test_aggregate_get_all_from_db(self): + for c in range(3): + _create_aggregate(self.context, + values={'name': 'fake_aggregate_%d' % c}) + results = aggregate._get_all_from_db(self.context) + self.assertEqual(len(results), 3) + + def test_aggregate_get_by_host_from_db(self): + _create_aggregate_with_hosts(self.context, + values={'name': 'fake_aggregate_1'}, + hosts=['host.1.openstack.org']) + _create_aggregate_with_hosts(self.context, + values={'name': 'fake_aggregate_2'}, + hosts=['host.1.openstack.org']) + _create_aggregate(self.context, + values={'name': 'no_host_aggregate'}) + rh1 = aggregate._get_all_from_db(self.context) + rh2 = aggregate._get_by_host_from_db(self.context, + 'host.1.openstack.org') + self.assertEqual(3, len(rh1)) + self.assertEqual(2, len(rh2)) + + def test_aggregate_get_by_host_with_key_from_db(self): + ah1 = _create_aggregate_with_hosts(self.context, + values={'name': 'fake_aggregate_1'}, + metadata={'goodkey': 'good'}, + hosts=['host.1.openstack.org']) + _create_aggregate_with_hosts(self.context, + values={'name': 'fake_aggregate_2'}, + hosts=['host.1.openstack.org']) + rh1 = aggregate._get_by_host_from_db(self.context, + 'host.1.openstack.org', + key='goodkey') + self.assertEqual(1, len(rh1)) + self.assertEqual(ah1['id'], rh1[0]['id']) + + def test_aggregate_get_by_metadata_key_from_db(self): + _create_aggregate(self.context, + values={'name': 'aggregate_1'}, + metadata={'goodkey': 'good'}) + _create_aggregate(self.context, + values={'name': 'aggregate_2'}, + metadata={'goodkey': 'bad'}) + _create_aggregate(self.context, + values={'name': 'aggregate_3'}, + metadata={'badkey': 'good'}) + rl1 = aggregate._get_by_metadata_key_from_db(self.context, + key='goodkey') + self.assertEqual(2, len(rl1)) + @mock.patch('nova.objects.aggregate._aggregate_get_from_db') @mock.patch('nova.db.aggregate_get') def test_get_by_id_from_api(self, mock_get, mock_get_api): @@ -247,29 +298,36 @@ class _TestAggregateObject(object): agg.metadata = {'availability_zone': 'foo'} self.assertEqual('foo', agg.availability_zone) - def test_get_all(self): - self.mox.StubOutWithMock(db, 'aggregate_get_all') - db.aggregate_get_all(self.context).AndReturn([fake_aggregate]) - self.mox.ReplayAll() + @mock.patch('nova.objects.aggregate._get_all_from_db') + @mock.patch('nova.db.aggregate_get_all') + def test_get_all(self, mock_get_all, mock_api_get_all): + mock_get_all.return_value = [fake_aggregate] + mock_api_get_all.return_value = [fake_api_aggregate] aggs = aggregate.AggregateList.get_all(self.context) - self.assertEqual(1, len(aggs)) + self.assertEqual(2, len(aggs)) self.compare_obj(aggs[0], fake_aggregate, subs=SUBS) + self.compare_obj(aggs[1], fake_api_aggregate, subs=SUBS) - def test_by_host(self): - self.mox.StubOutWithMock(db, 'aggregate_get_by_host') - db.aggregate_get_by_host(self.context, 'fake-host', key=None, - ).AndReturn([fake_aggregate]) - self.mox.ReplayAll() + @mock.patch('nova.objects.aggregate._get_by_host_from_db') + @mock.patch('nova.db.aggregate_get_by_host') + def test_by_host(self, mock_get_by_host, mock_api_get_by_host): + mock_get_by_host.return_value = [fake_aggregate] + mock_api_get_by_host.return_value = [fake_api_aggregate] aggs = aggregate.AggregateList.get_by_host(self.context, 'fake-host') - self.assertEqual(1, len(aggs)) + self.assertEqual(2, len(aggs)) self.compare_obj(aggs[0], fake_aggregate, subs=SUBS) + self.compare_obj(aggs[1], fake_api_aggregate, subs=SUBS) + @mock.patch('nova.objects.aggregate._get_by_metadata_key_from_db') @mock.patch('nova.db.aggregate_get_by_metadata_key') - def test_get_by_metadata_key(self, get_by_metadata_key): - get_by_metadata_key.return_value = [fake_aggregate] + def test_get_by_metadata_key(self, + mock_get_by_metadata_key, + mock_api_get_by_metadata_key): + mock_get_by_metadata_key.return_value = [fake_aggregate] + mock_api_get_by_metadata_key.return_value = [fake_api_aggregate] aggs = aggregate.AggregateList.get_by_metadata_key( self.context, 'this') - self.assertEqual(1, len(aggs)) + self.assertEqual(2, len(aggs)) self.compare_obj(aggs[0], fake_aggregate, subs=SUBS) @mock.patch('nova.db.aggregate_get_by_metadata_key')