Clean up _make_*_list in object models to use base.obj_make_list

There are a lot of places where _make_foo_list()s are implemented.
Actually base.obj_make_list() can be used to make the list.

The patch is to clean up the redundancies and make the code clean.

Also, in order to make _from_db_objects()s the same format, we add
'context' into the parameter list in instance fault object.

Related to blueprint icehouse-objects.

Change-Id: I1a68b6e9dc402d4e30a06f94b068233e2afa713f
This commit is contained in:
Shane Wang 2013-12-26 16:28:53 +08:00
parent 54b829ad55
commit 07c1aeb560
9 changed files with 58 additions and 82 deletions

View File

@ -279,9 +279,9 @@ class Instance(base.NovaPersistentObject, base.NovaObject):
context, instance.uuid))
if 'pci_devices' in expected_attrs:
pci_devices = pci_device._make_pci_list(
pci_devices = base.obj_make_list(
context, pci_device.PciDeviceList(),
db_inst['pci_devices'])
pci_device.PciDevice, db_inst['pci_devices'])
instance['pci_devices'] = pci_devices
if 'info_cache' in expected_attrs:
if db_inst['info_cache'] is None:
@ -294,9 +294,9 @@ class Instance(base.NovaPersistentObject, base.NovaObject):
instance_info_cache.InstanceInfoCache._from_db_object(
context, instance.info_cache, db_inst['info_cache'])
if 'security_groups' in expected_attrs:
sec_groups = security_group._make_secgroup_list(
sec_groups = base.obj_make_list(
context, security_group.SecurityGroupList(),
db_inst['security_groups'])
security_group.SecurityGroup, db_inst['security_groups'])
instance['security_groups'] = sec_groups
instance._context = context

View File

@ -12,6 +12,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import itertools
from nova import db
from nova.objects import base
from nova.objects import fields
@ -32,10 +34,11 @@ class InstanceFault(base.NovaPersistentObject, base.NovaObject):
}
@staticmethod
def _from_db_object(fault, db_fault):
def _from_db_object(context, fault, db_fault):
# NOTE(danms): These are identical right now
for key in fault.fields:
fault[key] = db_fault[key]
fault._context = context
fault.obj_reset_changes()
return fault
@ -44,17 +47,8 @@ class InstanceFault(base.NovaPersistentObject, base.NovaObject):
db_faults = db.instance_fault_get_by_instance_uuids(context,
[instance_uuid])
if instance_uuid in db_faults and db_faults[instance_uuid]:
return cls._from_db_object(cls(), db_faults[instance_uuid][0])
def _make_fault_list(faultlist, db_faultlist):
faultlist.objects = []
for instance_uuid in db_faultlist:
for db_fault in db_faultlist[instance_uuid]:
faultlist.objects.append(InstanceFault._from_db_object(
InstanceFault(), db_fault))
faultlist.obj_reset_changes()
return faultlist
return cls._from_db_object(context, cls(),
db_faults[instance_uuid][0])
class InstanceFaultList(base.ObjectListBase, base.NovaObject):
@ -72,6 +66,8 @@ class InstanceFaultList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod
def get_by_instance_uuids(cls, context, instance_uuids):
db_faults = db.instance_fault_get_by_instance_uuids(context,
instance_uuids)
return _make_fault_list(cls(), db_faults)
db_faultdict = db.instance_fault_get_by_instance_uuids(context,
instance_uuids)
db_faultlist = itertools.chain(*db_faultdict.values())
return base.obj_make_list(context, InstanceFaultList(), InstanceFault,
db_faultlist)

View File

@ -110,16 +110,6 @@ class InstanceGroup(base.NovaPersistentObject, base.NovaObject):
self.obj_reset_changes()
def _make_instance_group_list(context, inst_list, db_list):
inst_list.objects = []
for group in db_list:
inst_obj = InstanceGroup._from_db_object(context, InstanceGroup(),
group)
inst_list.objects.append(inst_obj)
inst_list.obj_reset_changes()
return inst_list
class InstanceGroupList(base.ObjectListBase, base.NovaObject):
# Version 1.0: Initial version
# InstanceGroup <= version 1.3
@ -136,9 +126,11 @@ class InstanceGroupList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod
def get_by_project_id(cls, context, project_id):
groups = db.instance_group_get_all_by_project_id(context, project_id)
return _make_instance_group_list(context, cls(), groups)
return base.obj_make_list(context, InstanceGroupList(), InstanceGroup,
groups)
@base.remotable_classmethod
def get_all(cls, context):
groups = db.instance_group_get_all(context)
return _make_instance_group_list(context, cls(), groups)
return base.obj_make_list(context, InstanceGroupList(), InstanceGroup,
groups)

View File

@ -80,15 +80,6 @@ class Migration(base.NovaPersistentObject, base.NovaObject):
self.instance_uuid)
def _make_list(context, list_obj, item_cls, db_list):
list_obj.objects = []
for db_item in db_list:
item = item_cls._from_db_object(context, item_cls(), db_item)
list_obj.objects.append(item)
list_obj.obj_reset_changes()
return list_obj
class MigrationList(base.ObjectListBase, base.NovaObject):
# Version 1.0: Initial version
# Migration <= 1.1
@ -109,15 +100,18 @@ class MigrationList(base.ObjectListBase, base.NovaObject):
dest_compute, use_slave=False):
db_migrations = db.migration_get_unconfirmed_by_dest_compute(
context, confirm_window, dest_compute, use_slave=use_slave)
return _make_list(context, MigrationList(), Migration, db_migrations)
return base.obj_make_list(context, MigrationList(), Migration,
db_migrations)
@base.remotable_classmethod
def get_in_progress_by_host_and_node(cls, context, host, node):
db_migrations = db.migration_get_in_progress_by_host_and_node(
context, host, node)
return _make_list(context, MigrationList(), Migration, db_migrations)
return base.obj_make_list(context, MigrationList(), Migration,
db_migrations)
@base.remotable_classmethod
def get_by_filters(cls, context, filters):
db_migrations = db.migration_get_all_by_filters(context, filters)
return _make_list(context, MigrationList(), Migration, db_migrations)
return base.obj_make_list(context, MigrationList(), Migration,
db_migrations)

View File

@ -237,16 +237,6 @@ class PciDevice(base.NovaPersistentObject, base.NovaObject):
self._from_db_object(context, self, db_pci)
def _make_pci_list(context, pci_list, db_list):
pci_list.objects = []
for pci in db_list:
pci_obj = PciDevice._from_db_object(context, PciDevice(), pci)
pci_list.objects.append(pci_obj)
pci_list.obj_reset_changes()
return pci_list
class PciDeviceList(base.ObjectListBase, base.NovaObject):
# Version 1.0: Initial version
# PciDevice <= 1.1
@ -268,9 +258,11 @@ class PciDeviceList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod
def get_by_compute_node(cls, context, node_id):
db_dev_list = db.pci_device_get_all_by_node(context, node_id)
return _make_pci_list(context, cls(), db_dev_list)
return base.obj_make_list(context, PciDeviceList(), PciDevice,
db_dev_list)
@base.remotable_classmethod
def get_by_instance_uuid(cls, context, uuid):
db_dev_list = db.pci_device_get_all_by_instance_uuid(context, uuid)
return _make_pci_list(context, cls(), db_dev_list)
return base.obj_make_list(context, PciDeviceList(), PciDevice,
db_dev_list)

View File

@ -70,17 +70,6 @@ class SecurityGroup(base.NovaPersistentObject, base.NovaObject):
self.id))
def _make_secgroup_list(context, secgroup_list, db_secgroup_list):
secgroup_list.objects = []
for db_secgroup in db_secgroup_list:
secgroup = SecurityGroup._from_db_object(context, SecurityGroup(),
db_secgroup)
secgroup._context = context
secgroup_list.objects.append(secgroup)
secgroup_list.obj_reset_changes()
return secgroup_list
class SecurityGroupList(base.ObjectListBase, base.NovaObject):
# Version 1.0: Initial version
# SecurityGroup <= version 1.1
@ -101,20 +90,21 @@ class SecurityGroupList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod
def get_all(cls, context):
return _make_secgroup_list(context, cls(),
db.security_group_get_all(context))
groups = db.security_group_get_all(context)
return base.obj_make_list(context, SecurityGroupList(), SecurityGroup,
groups)
@base.remotable_classmethod
def get_by_project(cls, context, project_id):
return _make_secgroup_list(context, cls(),
db.security_group_get_by_project(
context, project_id))
groups = db.security_group_get_by_project(context, project_id)
return base.obj_make_list(context, SecurityGroupList(), SecurityGroup,
groups)
@base.remotable_classmethod
def get_by_instance(cls, context, instance):
return _make_secgroup_list(context, cls(),
db.security_group_get_by_instance(
context, instance.uuid))
groups = db.security_group_get_by_instance(context, instance.uuid)
return base.obj_make_list(context, SecurityGroupList(), SecurityGroup,
groups)
def make_secgroup_list(security_groups):

View File

@ -2619,7 +2619,8 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(self.uuid)
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context, self.uuid)
image_bookmark = "http://localhost:9292/images/5"
flavor_bookmark = "http://localhost/flavors/1"
@ -2688,7 +2689,8 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_that_has_been_deleted(self):
self.instance['deleted'] = 1
self.instance['vm_state'] = vm_states.ERROR
fault = fake_instance.fake_fault_obj(self.uuid, code=500,
fault = fake_instance.fake_fault_obj(self.request.context,
self.uuid, code=500,
message="No valid host was found")
self.instance['fault'] = fault
@ -2706,6 +2708,7 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_no_details_not_admin(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context,
self.uuid,
code=500,
message='Error')
@ -2722,6 +2725,7 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_admin(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context,
self.uuid,
code=500,
message='Error')
@ -2739,6 +2743,7 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_no_details_admin(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context,
self.uuid,
code=500,
message='Error',
@ -2756,7 +2761,8 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_but_active(self):
self.instance['vm_state'] = vm_states.ACTIVE
self.instance['progress'] = 100
self.instance['fault'] = fake_instance.fake_fault_obj(self.uuid)
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context, self.uuid)
output = self.view_builder.show(self.request, self.instance)
self.assertNotIn('fault', output['server'])

View File

@ -4008,7 +4008,8 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(self.uuid)
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context, self.uuid)
self.expected_detailed_server["server"]["status"] = "ERROR"
self.expected_detailed_server["server"]["fault"] = {
@ -4027,7 +4028,8 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_that_has_been_deleted(self):
self.instance['deleted'] = 1
self.instance['vm_state'] = vm_states.ERROR
fault = fake_instance.fake_fault_obj(self.uuid, code=500,
fault = fake_instance.fake_fault_obj(self.request.context,
self.uuid, code=500,
message="No valid host was found")
self.instance['fault'] = fault
@ -4048,6 +4050,7 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_no_details_not_admin(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context,
self.uuid,
code=500,
message='Error')
@ -4064,6 +4067,7 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_admin(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context,
self.uuid,
code=500,
message='Error')
@ -4081,6 +4085,7 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_no_details_admin(self):
self.instance['vm_state'] = vm_states.ERROR
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context,
self.uuid,
code=500,
message='Error',
@ -4098,7 +4103,8 @@ class ServersViewBuilderTest(test.TestCase):
def test_build_server_detail_with_fault_but_active(self):
self.instance['vm_state'] = vm_states.ACTIVE
self.instance['progress'] = 100
self.instance['fault'] = fake_instance.fake_fault_obj(self.uuid)
self.instance['fault'] = fake_instance.fake_fault_obj(
self.request.context, self.uuid)
output = self.view_builder.show(self.request, self.instance)
self.assertNotIn('fault', output['server'])

View File

@ -85,7 +85,7 @@ def fake_instance_obj(context, **updates):
expected_attrs=expected_attrs)
def fake_fault_obj(instance_uuid, code=404,
def fake_fault_obj(context, instance_uuid, code=404,
message='HTTPNotFound',
details='Stock details for test',
**updates):
@ -103,6 +103,6 @@ def fake_fault_obj(instance_uuid, code=404,
}
if updates:
fault.update(updates)
return inst_fault_obj.InstanceFault._from_db_object(
return inst_fault_obj.InstanceFault._from_db_object(context,
inst_fault_obj.InstanceFault(),
fault)