From 8a7d0db7302f225a841f55e1c79a2a099f68e7c4 Mon Sep 17 00:00:00 2001 From: Dan Smith Date: Fri, 20 Sep 2013 11:16:38 -0700 Subject: [PATCH] Make Instance.refresh() extra careful about recursive loads This simply makes sure that the instance copy that we fetch to refresh an instance can not trigger lazy loads itself due to a bug in get_by_uuid() or below. This hasn't actually been the cause of any recent bugs, but this makes it much more careful. It also makes sure that obj_load_attr() refuses to do its work if self._context is None, which was previously only being enforced for remoted calls. Change-Id: I76f433af34620045b079e0a2c30c60b28a4ce525 --- nova/objects/instance.py | 8 ++++++++ nova/objects/service.py | 3 +++ nova/tests/objects/test_instance.py | 15 +++++++++++++++ nova/tests/objects/test_service.py | 7 +++++++ 4 files changed, 33 insertions(+) diff --git a/nova/objects/instance.py b/nova/objects/instance.py index 7a7528e411cc..8b7aaa5d3750 100644 --- a/nova/objects/instance.py +++ b/nova/objects/instance.py @@ -462,6 +462,11 @@ class Instance(base.NovaPersistentObject, base.NovaObject): if self.obj_attr_is_set(field)] current = self.__class__.get_by_uuid(context, uuid=self.uuid, expected_attrs=extra) + # NOTE(danms): We orphan the instance copy so we do not unexpectedly + # trigger a lazy-load (which would mean we failed to calculate the + # expected_attrs properly) + current._context = None + for field in self.fields: if self.obj_attr_is_set(field) and self[field] != current[field]: self[field] = current[field] @@ -472,6 +477,9 @@ class Instance(base.NovaPersistentObject, base.NovaObject): raise exception.ObjectActionError( action='obj_load_attr', reason='attribute %s not lazy-loadable' % attrname) + if not self._context: + raise exception.OrphanedObjectError(method='obj_load_attr', + objtype=self.obj_name()) LOG.debug(_("Lazy-loading `%(attr)s' on %(name) uuid %(uuid)s"), {'attr': attrname, diff --git a/nova/objects/service.py b/nova/objects/service.py index 6d4ec0173925..4130af61ea65 100644 --- a/nova/objects/service.py +++ b/nova/objects/service.py @@ -70,6 +70,9 @@ class Service(base.NovaPersistentObject, base.NovaObject): return service def obj_load_attr(self, attrname): + if not self._context: + raise exception.OrphanedObjectError(method='obj_load_attr', + objtype=self.obj_name()) LOG.debug(_("Lazy-loading `%(attr)s' on %(name)s id %(id)s"), {'attr': attrname, 'name': self.obj_name(), diff --git a/nova/tests/objects/test_instance.py b/nova/tests/objects/test_instance.py index 58bec436505a..d4d6295f054f 100644 --- a/nova/tests/objects/test_instance.py +++ b/nova/tests/objects/test_instance.py @@ -171,6 +171,7 @@ class _TestInstanceObject(object): def test_load_invalid(self): inst = instance.Instance() + inst._context = self.context inst.uuid = 'fake-uuid' self.assertRaises(exception.ObjectActionError, inst.obj_load_attr, 'foo') @@ -214,6 +215,20 @@ class _TestInstanceObject(object): self.assertRemotes() self.assertEqual(set([]), inst.obj_what_changed()) + def test_refresh_does_not_recurse(self): + inst = instance.Instance() + inst._context = self.context + inst.uuid = 'fake-uuid' + inst.metadata = {} + inst_copy = instance.Instance() + inst_copy.uuid = inst.uuid + self.mox.StubOutWithMock(instance.Instance, 'get_by_uuid') + instance.Instance.get_by_uuid(self.context, uuid=inst.uuid, + expected_attrs=['metadata'] + ).AndReturn(inst_copy) + self.mox.ReplayAll() + self.assertRaises(exception.OrphanedObjectError, inst.refresh) + def _save_test_helper(self, cell_type, save_kwargs): """Common code for testing save() for cells/non-cells.""" if cell_type: diff --git a/nova/tests/objects/test_service.py b/nova/tests/objects/test_service.py index 66888848e9fe..c3a9b84a37d6 100644 --- a/nova/tests/objects/test_service.py +++ b/nova/tests/objects/test_service.py @@ -13,6 +13,7 @@ # under the License. from nova import db +from nova import exception from nova.objects import service from nova.openstack.common import timeutils from nova.tests.objects import test_compute_node @@ -165,6 +166,12 @@ class _TestServiceObject(object): # Make sure it doesn't re-fetch this service_obj.compute_node + def test_load_when_orphaned(self): + service_obj = service.Service() + service_obj.id = 123 + self.assertRaises(exception.OrphanedObjectError, + getattr, service_obj, 'compute_node') + class TestServiceObject(test_objects._LocalTest, _TestServiceObject):