Clean up duplicated change-building code in objects

Almost every object thus far has done this:

  changes = {}
  for key in self.obj_what_changed():
      changes[key] = self[key]

to get a dict of updates to apply to the database. This patch adds
that as part of the base object and makes every place that does
the above to just use that.

Change-Id: I847f5d35181b0305668b107f86faa164e71c3375
This commit is contained in:
Dan Smith
2013-09-05 16:37:22 -07:00
parent 45b28ecb4f
commit 2f981fe643
11 changed files with 34 additions and 47 deletions

View File

@@ -55,9 +55,7 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject):
@base.remotable @base.remotable
def create(self, context): def create(self, context):
self._assert_no_hosts('create') self._assert_no_hosts('create')
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
payload = dict(updates) payload = dict(updates)
if 'metadata' in updates: if 'metadata' in updates:
# NOTE(danms): For some reason the notification format is weird # NOTE(danms): For some reason the notification format is weird
@@ -76,9 +74,7 @@ class Aggregate(base.NovaPersistentObject, base.NovaObject):
@base.remotable @base.remotable
def save(self, context): def save(self, context):
self._assert_no_hosts('save') self._assert_no_hosts('save')
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
payload = {'aggregate_id': self.id} payload = {'aggregate_id': self.id}
if 'metadata' in updates: if 'metadata' in updates:

View File

@@ -317,6 +317,13 @@ class NovaObject(object):
"""Returns a set of fields that have been modified.""" """Returns a set of fields that have been modified."""
return self._changed_fields return self._changed_fields
def obj_get_changes(self):
"""Returns a dict of changed fields and their new values."""
changes = {}
for key in self.obj_what_changed():
changes[key] = self[key]
return changes
def obj_reset_changes(self, fields=None): def obj_reset_changes(self, fields=None):
"""Reset the list of fields that have been changed. """Reset the list of fields that have been changed.

View File

@@ -62,17 +62,13 @@ class ComputeNode(base.NovaPersistentObject, base.NovaObject):
@base.remotable @base.remotable
def create(self, context): def create(self, context):
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
db_compute = db.compute_node_create(context, updates) db_compute = db.compute_node_create(context, updates)
self._from_db_object(context, self, db_compute) self._from_db_object(context, self, db_compute)
@base.remotable @base.remotable
def save(self, context, prune_stats=False): def save(self, context, prune_stats=False):
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
updates.pop('id', None) updates.pop('id', None)
db_compute = db.compute_node_update(context, self.id, updates, db_compute = db.compute_node_update(context, self.id, updates,
prune_stats=prune_stats) prune_stats=prune_stats)

View File

@@ -319,9 +319,8 @@ class Instance(base.NovaPersistentObject, base.NovaObject):
if self.obj_attr_is_set('id'): if self.obj_attr_is_set('id'):
raise exception.ObjectActionError(action='create', raise exception.ObjectActionError(action='create',
reason='already created') reason='already created')
updates = {} updates = self.obj_get_changes()
for attr in self.obj_what_changed() - set(['id']): updates.pop('id', None)
updates[attr] = self[attr]
expected_attrs = [attr for attr in INSTANCE_DEFAULT_FIELDS expected_attrs = [attr for attr in INSTANCE_DEFAULT_FIELDS
if attr in updates] if attr in updates]
if 'security_groups' in updates: if 'security_groups' in updates:

View File

@@ -64,10 +64,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject):
def save(self, context): def save(self, context):
"""Save updates to this instance group.""" """Save updates to this instance group."""
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
if not updates: if not updates:
return return
@@ -95,10 +92,7 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject):
if self.obj_attr_is_set('id'): if self.obj_attr_is_set('id'):
raise exception.ObjectActionError(action='create', raise exception.ObjectActionError(action='create',
reason='already created') reason='already created')
updates = {} updates = self.obj_get_changes()
for attr in self.obj_what_changed():
updates[attr] = self[attr]
updates.pop('id', None) updates.pop('id', None)
policies = updates.pop('policies', None) policies = updates.pop('policies', None)
members = updates.pop('members', None) members = updates.pop('members', None)

View File

@@ -45,9 +45,7 @@ class KeyPair(base.NovaPersistentObject, base.NovaObject):
@base.remotable @base.remotable
def create(self, context): def create(self, context):
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
db_keypair = db.key_pair_create(context, updates) db_keypair = db.key_pair_create(context, updates)
self._from_db_object(context, self, db_keypair) self._from_db_object(context, self, db_keypair)

View File

@@ -53,18 +53,14 @@ class Migration(base.NovaPersistentObject, base.NovaObject):
@base.remotable @base.remotable
def create(self, context): def create(self, context):
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
updates.pop('id', None) updates.pop('id', None)
db_migration = db.migration_create(context, updates) db_migration = db.migration_create(context, updates)
self._from_db_object(context, self, db_migration) self._from_db_object(context, self, db_migration)
@base.remotable @base.remotable
def save(self, context): def save(self, context):
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
updates.pop('id', None) updates.pop('id', None)
db_migration = db.migration_update(context, self.id, updates) db_migration = db.migration_update(context, self.id, updates)
self._from_db_object(context, self, db_migration) self._from_db_object(context, self, db_migration)

View File

@@ -230,12 +230,9 @@ class PciDevice(base.NovaPersistentObject, base.NovaObject):
self.status = 'deleted' self.status = 'deleted'
db.pci_device_destroy(context, self.compute_node_id, self.address) db.pci_device_destroy(context, self.compute_node_id, self.address)
elif self.status != 'deleted': elif self.status != 'deleted':
updates = {} updates = self.obj_get_changes()
for field in self.obj_what_changed(): if 'extra_info' in updates:
if field == 'extra_info': updates['extra_info'] = jsonutils.dumps(updates['extra_info'])
updates['extra_info'] = jsonutils.dumps(self.extra_info)
else:
updates[field] = self[field]
if updates: if updates:
db_pci = db.pci_device_update(context, self.compute_node_id, db_pci = db.pci_device_update(context, self.compute_node_id,
self.address, updates) self.address, updates)

View File

@@ -51,9 +51,7 @@ class SecurityGroup(base.NovaPersistentObject, base.NovaObject):
@base.remotable @base.remotable
def save(self, context): def save(self, context):
updates = {} updates = self.obj_get_changes()
for field in self.obj_what_changed():
updates[field] = self[field]
if updates: if updates:
db_secgroup = db.security_group_update(context, self.id, updates) db_secgroup = db.security_group_update(context, self.id, updates)
SecurityGroup._from_db_object(self, db_secgroup) SecurityGroup._from_db_object(self, db_secgroup)

View File

@@ -99,17 +99,13 @@ class Service(base.NovaPersistentObject, base.NovaObject):
@base.remotable @base.remotable
def create(self, context): def create(self, context):
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
db_service = db.service_create(context, updates) db_service = db.service_create(context, updates)
self._from_db_object(context, self, db_service) self._from_db_object(context, self, db_service)
@base.remotable @base.remotable
def save(self, context): def save(self, context):
updates = {} updates = self.obj_get_changes()
for key in self.obj_what_changed():
updates[key] = self[key]
updates.pop('id', None) updates.pop('id', None)
db_service = db.service_update(context, self.id, updates) db_service = db.service_update(context, self.id, updates)
self._from_db_object(context, self, db_service) self._from_db_object(context, self, db_service)

View File

@@ -557,6 +557,16 @@ class _TestObject(object):
self.assertEqual(set(myobj_fields) | set(myobj3_fields), self.assertEqual(set(myobj_fields) | set(myobj3_fields),
set(TestSubclassedObject.fields.keys())) set(TestSubclassedObject.fields.keys()))
def test_get_changes(self):
obj = MyObj()
self.assertEqual({}, obj.obj_get_changes())
obj.foo = 123
self.assertEqual({'foo': 123}, obj.obj_get_changes())
obj.bar = 'test'
self.assertEqual({'foo': 123, 'bar': 'test'}, obj.obj_get_changes())
obj.obj_reset_changes()
self.assertEqual({}, obj.obj_get_changes())
class TestObject(_LocalTest, _TestObject): class TestObject(_LocalTest, _TestObject):
pass pass