diff --git a/nova/objects/aggregate.py b/nova/objects/aggregate.py index a9ea5d41e..58347c962 100644 --- a/nova/objects/aggregate.py +++ b/nova/objects/aggregate.py @@ -47,6 +47,21 @@ def _aggregate_get_from_db(context, aggregate_id): return aggregate +@db_api.api_context_manager.reader +def _aggregate_get_from_db_by_uuid(context, aggregate_uuid): + query = context.session.query(api_models.Aggregate).\ + options(joinedload('_hosts')).\ + options(joinedload('_metadata')) + query = query.filter(api_models.Aggregate.uuid == aggregate_uuid) + + aggregate = query.first() + + if not aggregate: + raise exception.AggregateNotFound(aggregate_id=aggregate_uuid) + + return aggregate + + @base.NovaObjectRegistry.register class Aggregate(base.NovaPersistentObject, base.NovaObject): # Version 1.0: Initial version @@ -121,7 +136,11 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject): @base.remotable_classmethod def get_by_uuid(cls, context, aggregate_uuid): - db_aggregate = db.aggregate_get_by_uuid(context, aggregate_uuid) + try: + db_aggregate = _aggregate_get_from_db_by_uuid(context, + aggregate_uuid) + except exception.AggregateNotFound: + db_aggregate = db.aggregate_get_by_uuid(context, aggregate_uuid) return cls._from_db_object(context, cls(), db_aggregate) @base.remotable diff --git a/nova/tests/unit/objects/test_aggregate.py b/nova/tests/unit/objects/test_aggregate.py index 2d63064f1..9a95c4779 100644 --- a/nova/tests/unit/objects/test_aggregate.py +++ b/nova/tests/unit/objects/test_aggregate.py @@ -98,6 +98,14 @@ class _TestAggregateObject(object): self.assertEqual(fake_db_aggregate_hosts, expected.hosts) self.assertEqual(fake_db_aggregate_metadata, expected['metadetails']) + def test_aggregate_get_from_db_by_uuid(self): + result = _create_aggregate_with_hosts(self.context) + expected = aggregate._aggregate_get_from_db_by_uuid( + self.context, result['uuid']) + self.assertEqual(result.uuid, expected.uuid) + self.assertEqual(fake_db_aggregate_hosts, expected.hosts) + self.assertEqual(fake_db_aggregate_metadata, expected['metadetails']) + def test_aggregate_get_from_db_raise_not_found(self): aggregate_id = 5 self.assertRaises(exception.AggregateNotFound, @@ -192,14 +200,27 @@ class _TestAggregateObject(object): self.assertEqual(uuid, obj.uuid) mock_save.assert_called_once_with() + @mock.patch('nova.objects.aggregate._aggregate_get_from_db_by_uuid') @mock.patch('nova.db.aggregate_get_by_uuid') - def test_get_by_uuid(self, get_by_uuid): + def test_get_by_uuid(self, get_by_uuid, get_by_uuid_api): + get_by_uuid_api.side_effect = exception.AggregateNotFound( + aggregate_id=123) get_by_uuid.return_value = fake_aggregate agg = aggregate.Aggregate.get_by_uuid(self.context, uuidsentinel.fake_aggregate) self.assertEqual(uuidsentinel.fake_aggregate, agg.uuid) self.assertEqual(fake_aggregate['id'], agg.id) + @mock.patch('nova.objects.aggregate._aggregate_get_from_db_by_uuid') + @mock.patch('nova.db.aggregate_get_by_uuid') + def test_get_by_uuid_from_api(self, get_by_uuid, get_by_uuid_api): + get_by_uuid_api.return_value = fake_aggregate + agg = aggregate.Aggregate.get_by_uuid(self.context, + uuidsentinel.fake_aggregate) + self.assertEqual(uuidsentinel.fake_aggregate, agg.uuid) + self.assertEqual(fake_aggregate['id'], agg.id) + self.assertFalse(get_by_uuid.called) + def test_create(self): self.mox.StubOutWithMock(db, 'aggregate_create') db.aggregate_create(self.context, {'name': 'foo',