Generalize the _make_list() function for objects

Each object with a list duplicated the _make_list() method in its
own module. This removes that duplication and adds a generalized
helper in objects/base.py. The instance object still uses its own
because it has to do a bunch of other stuff in the loop for
efficiency.

Change-Id: Ic910f39087ebc167b2b930979f7951116caf8598
This commit is contained in:
Dan Smith
2013-08-26 11:39:44 -07:00
parent ff80ac8d03
commit ee600da2b6
8 changed files with 63 additions and 56 deletions

View File

@@ -22,6 +22,7 @@ from nova.cells import utils as cells_utils
from nova.compute import api as compute_api from nova.compute import api as compute_api
from nova.compute import rpcapi as compute_rpcapi from nova.compute import rpcapi as compute_rpcapi
from nova import exception from nova import exception
from nova.objects import base as obj_base
from nova.objects import service as service_obj from nova.objects import service as service_obj
from nova.openstack.common import excutils from nova.openstack.common import excutils
from nova import rpcclient from nova import rpcclient
@@ -480,7 +481,7 @@ class HostAPI(compute_api.HostAPI):
# NOTE(danms): Currently cells does not support objects as # NOTE(danms): Currently cells does not support objects as
# return values, so just convert the db-formatted service objects # return values, so just convert the db-formatted service objects
# to new-world objects here # to new-world objects here
return service_obj._make_list(context, return obj_base.obj_make_list(context,
service_obj.ServiceList(), service_obj.ServiceList(),
service_obj.Service, service_obj.Service,
services) services)

View File

@@ -144,22 +144,15 @@ class Aggregate(base.NovaObject):
return self.metadata.get('availability_zone', None) return self.metadata.get('availability_zone', None)
def _make_list(context, list_obj, item_cls, db_list):
list_obj.objects = []
for db_item in db_list:
item = item_cls._from_db_object(context, item_cls(), db_item)
list_obj.objects.append(item)
list_obj.obj_reset_changes()
return list_obj
class AggregateList(base.ObjectListBase, base.NovaObject): class AggregateList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_all(cls, context): def get_all(cls, context):
db_aggregates = db.aggregate_get_all(context) db_aggregates = db.aggregate_get_all(context)
return _make_list(context, AggregateList(), Aggregate, db_aggregates) return base.obj_make_list(context, AggregateList(), Aggregate,
db_aggregates)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_host(cls, context, host): def get_by_host(cls, context, host):
db_aggregates = db.aggregate_get_by_host(context, host) db_aggregates = db.aggregate_get_by_host(context, host)
return _make_list(context, AggregateList(), Aggregate, db_aggregates) return base.obj_make_list(context, AggregateList(), Aggregate,
db_aggregates)

View File

@@ -528,3 +528,26 @@ def obj_to_primitive(obj):
return result return result
else: else:
return obj return obj
def obj_make_list(context, list_obj, item_cls, db_list, **extra_args):
"""Construct an object list from a list of primitives.
This calls item_cls._from_db_object() on each item of db_list, and
adds the resulting object to list_obj.
:param:context: Request contextr
:param:list_obj: An ObjectListBase object
:param:item_cls: The NovaObject class of the objects within the list
:param:db_list: The list of primitives to convert to objects
:param:extra_args: Extra arguments to pass to _from_db_object()
:returns: list_obj
"""
list_obj.objects = []
for db_item in db_list:
item = item_cls._from_db_object(context, item_cls(), db_item,
**extra_args)
list_obj.objects.append(item)
list_obj._context = context
list_obj.obj_reset_changes()
return list_obj

View File

@@ -92,23 +92,16 @@ class ComputeNode(base.NovaObject):
return self._cached_service return self._cached_service
def _make_list(context, list_obj, item_cls, db_list):
list_obj.objects = []
for db_item in db_list:
item = item_cls._from_db_object(context, item_cls(), db_item)
list_obj.objects.append(item)
list_obj.obj_reset_changes()
return list_obj
class ComputeNodeList(base.ObjectListBase, base.NovaObject): class ComputeNodeList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_all(cls, context): def get_all(cls, context):
db_computes = db.compute_node_get_all(context) db_computes = db.compute_node_get_all(context)
return _make_list(context, ComputeNodeList(), ComputeNode, db_computes) return base.obj_make_list(context, ComputeNodeList(), ComputeNode,
db_computes)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_hypervisor(cls, context, hypervisor_match): def get_by_hypervisor(cls, context, hypervisor_match):
db_computes = db.compute_node_search_by_hypervisor(context, db_computes = db.compute_node_search_by_hypervisor(context,
hypervisor_match) hypervisor_match)
return _make_list(context, ComputeNodeList(), ComputeNode, db_computes) return base.obj_make_list(context, ComputeNodeList(), ComputeNode,
db_computes)

View File

@@ -77,20 +77,11 @@ class InstanceAction(base.NovaObject):
self._from_db_object(context, self, db_action) self._from_db_object(context, self, db_action)
def _make_list(context, list_obj, item_cls, db_list):
list_obj.objects = []
for db_item in db_list:
item = item_cls._from_db_object(context, item_cls(), db_item)
list_obj.objects.append(item)
list_obj.obj_reset_changes()
return list_obj
class InstanceActionList(base.ObjectListBase, base.NovaObject): class InstanceActionList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_by_instance_uuid(cls, context, instance_uuid): def get_by_instance_uuid(cls, context, instance_uuid):
db_actions = db.actions_get(context, instance_uuid) db_actions = db.actions_get(context, instance_uuid)
return _make_list(context, cls(), InstanceAction, db_actions) return base.obj_make_list(context, cls(), InstanceAction, db_actions)
class InstanceActionEvent(base.NovaObject): class InstanceActionEvent(base.NovaObject):
@@ -168,4 +159,5 @@ class InstanceActionEventList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_by_action(cls, context, action_id): def get_by_action(cls, context, action_id):
db_events = db.action_events_get(context, action_id) db_events = db.action_events_get(context, action_id)
return _make_list(context, cls(), InstanceActionEvent, db_events) return base.obj_make_list(context, cls(), InstanceActionEvent,
db_events)

View File

@@ -56,20 +56,11 @@ class KeyPair(base.NovaObject):
db.key_pair_destroy(context, self.user_id, self.name) db.key_pair_destroy(context, self.user_id, self.name)
def _make_list(context, list_obj, item_cls, db_list):
list_obj.objects = []
for db_item in db_list:
item = item_cls._from_db_object(context, item_cls(), db_item)
list_obj.objects.append(item)
list_obj.obj_reset_changes()
return list_obj
class KeyPairList(base.ObjectListBase, base.NovaObject): class KeyPairList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_by_user(cls, context, user_id): def get_by_user(cls, context, user_id):
db_keypairs = db.key_pair_get_all_by_user(context, user_id) db_keypairs = db.key_pair_get_all_by_user(context, user_id)
return _make_list(context, KeyPairList(), KeyPair, db_keypairs) return base.obj_make_list(context, KeyPairList(), KeyPair, db_keypairs)
@base.remotable_classmethod @base.remotable_classmethod
def get_count_by_user(cls, context, user_id): def get_count_by_user(cls, context, user_id):

View File

@@ -119,25 +119,16 @@ class Service(base.NovaObject):
db.service_destroy(context, self.id) db.service_destroy(context, self.id)
def _make_list(context, list_obj, item_cls, db_list):
list_obj.objects = []
for db_item in db_list:
item = item_cls._from_db_object(context, item_cls(), db_item)
list_obj.objects.append(item)
list_obj.obj_reset_changes()
return list_obj
class ServiceList(base.ObjectListBase, base.NovaObject): class ServiceList(base.ObjectListBase, base.NovaObject):
@base.remotable_classmethod @base.remotable_classmethod
def get_by_topic(cls, context, topic): def get_by_topic(cls, context, topic):
db_services = db.service_get_all_by_topic(context, topic) db_services = db.service_get_all_by_topic(context, topic)
return _make_list(context, ServiceList(), Service, db_services) return base.obj_make_list(context, ServiceList(), Service, db_services)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_host(cls, context, host): def get_by_host(cls, context, host):
db_services = db.service_get_all_by_host(context, host) db_services = db.service_get_all_by_host(context, host)
return _make_list(context, ServiceList(), Service, db_services) return base.obj_make_list(context, ServiceList(), Service, db_services)
@base.remotable_classmethod @base.remotable_classmethod
def get_all(cls, context, disabled=None, set_zones=False): def get_all(cls, context, disabled=None, set_zones=False):
@@ -145,4 +136,4 @@ class ServiceList(base.ObjectListBase, base.NovaObject):
if set_zones: if set_zones:
db_services = availability_zones.set_availability_zones( db_services = availability_zones.set_availability_zones(
context, db_services) context, db_services)
return _make_list(context, ServiceList(), Service, db_services) return base.obj_make_list(context, ServiceList(), Service, db_services)

View File

@@ -33,6 +33,14 @@ class MyObj(base.NovaObject):
'missing': str, 'missing': str,
} }
@staticmethod
def _from_db_object(context, obj, db_obj):
self = MyObj()
self.foo = db_obj['foo']
self.bar = db_obj['bar']
self.missing = db_obj['missing']
return self
def obj_load_attr(self, attrname): def obj_load_attr(self, attrname):
setattr(self, attrname, 'loaded!') setattr(self, attrname, 'loaded!')
@@ -234,6 +242,21 @@ class TestUtils(test.TestCase):
self.assertEqual([{'foo': 0}, {'foo': 1}], self.assertEqual([{'foo': 0}, {'foo': 1}],
base.obj_to_primitive(mylist)) base.obj_to_primitive(mylist))
def test_obj_make_list(self):
class MyList(base.ObjectListBase, base.NovaObject):
pass
db_objs = [{'foo': 1, 'bar': 'baz', 'missing': 'banana'},
{'foo': 2, 'bar': 'bat', 'missing': 'apple'},
]
mylist = base.obj_make_list('ctxt', MyList(), MyObj, db_objs)
self.assertEqual(2, len(mylist))
self.assertEqual('ctxt', mylist._context)
for index, item in enumerate(mylist):
self.assertEqual(db_objs[index]['foo'], item.foo)
self.assertEqual(db_objs[index]['bar'], item.bar)
self.assertEqual(db_objs[index]['missing'], item.missing)
class _BaseTestCase(test.TestCase): class _BaseTestCase(test.TestCase):
def setUp(self): def setUp(self):