Clean up duplicated change-building code in objects

Almost every object thus far has done this:

  changes = {}
  for key in self.obj_what_changed():
      changes[key] = self[key]

to get a dict of updates to apply to the database. This patch adds
that as part of the base object and makes every place that does
the above to just use that.

Change-Id: I847f5d35181b0305668b107f86faa164e71c3375
This commit is contained in:
Dan Smith
2013-09-05 16:37:22 -07:00
parent 45b28ecb4f
commit 2f981fe643
11 changed files with 34 additions and 47 deletions

View File

@@ -55,9 +55,7 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject):
@base.remotable
def create(self, context):
self._assert_no_hosts('create')
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
payload = dict(updates)
if 'metadata' in updates:
# NOTE(danms): For some reason the notification format is weird
@@ -76,9 +74,7 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject):
@base.remotable
def save(self, context):
self._assert_no_hosts('save')
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
payload = {'aggregate_id': self.id}
if 'metadata' in updates:

View File

@@ -317,6 +317,13 @@ class NovaObject(object):
"""Returns a set of fields that have been modified."""
return self._changed_fields
def obj_get_changes(self):
"""Returns a dict of changed fields and their new values."""
changes = {}
for key in self.obj_what_changed():
changes[key] = self[key]
return changes
def obj_reset_changes(self, fields=None):
"""Reset the list of fields that have been changed.

View File

@@ -62,17 +62,13 @@ class ComputeNode(base.NovaPersistentObject, base.NovaObject):
@base.remotable
def create(self, context):
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
db_compute = db.compute_node_create(context, updates)
self._from_db_object(context, self, db_compute)
@base.remotable
def save(self, context, prune_stats=False):
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
updates.pop('id', None)
db_compute = db.compute_node_update(context, self.id, updates,
prune_stats=prune_stats)

View File

@@ -319,9 +319,8 @@ class Instance(base.NovaPersistentObject, base.NovaObject):
if self.obj_attr_is_set('id'):
raise exception.ObjectActionError(action='create',
reason='already created')
updates = {}
for attr in self.obj_what_changed() - set(['id']):
updates[attr] = self[attr]
updates = self.obj_get_changes()
updates.pop('id', None)
expected_attrs = [attr for attr in INSTANCE_DEFAULT_FIELDS
if attr in updates]
if 'security_groups' in updates:

View File

@@ -64,10 +64,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject):
def save(self, context):
"""Save updates to this instance group."""
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
if not updates:
return
@@ -95,10 +92,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject):
if self.obj_attr_is_set('id'):
raise exception.ObjectActionError(action='create',
reason='already created')
updates = {}
for attr in self.obj_what_changed():
updates[attr] = self[attr]
updates = self.obj_get_changes()
updates.pop('id', None)
policies = updates.pop('policies', None)
members = updates.pop('members', None)

View File

@@ -45,9 +45,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject):
@base.remotable
def create(self, context):
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
db_keypair = db.key_pair_create(context, updates)
self._from_db_object(context, self, db_keypair)

View File

@@ -53,18 +53,14 @@ class Migration(base.NovaPersistentObject, base.NovaObject):
@base.remotable
def create(self, context):
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
updates.pop('id', None)
db_migration = db.migration_create(context, updates)
self._from_db_object(context, self, db_migration)
@base.remotable
def save(self, context):
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
updates.pop('id', None)
db_migration = db.migration_update(context, self.id, updates)
self._from_db_object(context, self, db_migration)

View File

@@ -230,12 +230,9 @@ class PciDevice(base.NovaPersistentObject, base.NovaObject):
self.status = 'deleted'
db.pci_device_destroy(context, self.compute_node_id, self.address)
elif self.status != 'deleted':
updates = {}
for field in self.obj_what_changed():
if field == 'extra_info':
updates['extra_info'] = jsonutils.dumps(self.extra_info)
else:
updates[field] = self[field]
updates = self.obj_get_changes()
if 'extra_info' in updates:
updates['extra_info'] = jsonutils.dumps(updates['extra_info'])
if updates:
db_pci = db.pci_device_update(context, self.compute_node_id,
self.address, updates)

View File

@@ -51,9 +51,7 @@ class SecurityGroup(base.NovaPersistentObject, base.NovaObject):
@base.remotable
def save(self, context):
updates = {}
for field in self.obj_what_changed():
updates[field] = self[field]
updates = self.obj_get_changes()
if updates:
db_secgroup = db.security_group_update(context, self.id, updates)
SecurityGroup._from_db_object(self, db_secgroup)

View File

@@ -99,17 +99,13 @@ class Service(base.NovaPersistentObject, base.NovaObject):
@base.remotable
def create(self, context):
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
db_service = db.service_create(context, updates)
self._from_db_object(context, self, db_service)
@base.remotable
def save(self, context):
updates = {}
for key in self.obj_what_changed():
updates[key] = self[key]
updates = self.obj_get_changes()
updates.pop('id', None)
db_service = db.service_update(context, self.id, updates)
self._from_db_object(context, self, db_service)

View File

@@ -557,6 +557,16 @@ class _TestObject(object):
self.assertEqual(set(myobj_fields) | set(myobj3_fields),
set(TestSubclassedObject.fields.keys()))
def test_get_changes(self):
obj = MyObj()
self.assertEqual({}, obj.obj_get_changes())
obj.foo = 123
self.assertEqual({'foo': 123}, obj.obj_get_changes())
obj.bar = 'test'
self.assertEqual({'foo': 123, 'bar': 'test'}, obj.obj_get_changes())
obj.obj_reset_changes()
self.assertEqual({}, obj.obj_get_changes())
class TestObject(_LocalTest, _TestObject):
pass