diff --git a/nova/exception.py b/nova/exception.py index 6a22d57591ba..8447fb94c31d 100644 --- a/nova/exception.py +++ b/nova/exception.py @@ -1376,3 +1376,15 @@ class PciDeviceNotFoundById(NotFound): class PciDeviceNotFound(NovaException): msg_fmt = _("PCI Device %(node_id)s:%(address)s not found.") + + +class PciDeviceInvalidStatus(NovaException): + msg_fmt = _( + "PCI Device %(compute_node_id)s:%(address)s is %(status)s " + "instead of %(hopestatus)s") + + +class PciDeviceInvalidOwner(NovaException): + msg_fmt = _( + "PCI Device %(compute_node_id)s:%(address)s is owned by %(owner)s " + "instead of %(hopeowner)s") diff --git a/nova/objects/instance.py b/nova/objects/instance.py index 379dd0f460e1..78f9d1112b81 100644 --- a/nova/objects/instance.py +++ b/nova/objects/instance.py @@ -22,6 +22,7 @@ from nova import notifications from nova.objects import base from nova.objects import instance_fault from nova.objects import instance_info_cache +from nova.objects import pci_device from nova.objects import security_group from nova.objects import utils as obj_utils from nova import utils @@ -33,7 +34,8 @@ CONF = cfg.CONF # These are fields that can be specified as expected_attrs -INSTANCE_OPTIONAL_FIELDS = ['metadata', 'system_metadata', 'fault'] +INSTANCE_OPTIONAL_FIELDS = ['metadata', 'system_metadata', 'fault', + 'pci_devices'] # These are fields that are always joined by the db right now INSTANCE_IMPLIED_FIELDS = ['info_cache', 'security_groups'] # These are fields that are optional but don't translate to db columns @@ -50,7 +52,8 @@ class Instance(base.NovaObject): # save() # Version 1.4: Added locked_by and deprecated locked # Version 1.5: Added cleaned - VERSION = '1.5' + # Version 1.6: Added pci_devices + VERSION = '1.6' fields = { 'id': int, @@ -136,6 +139,8 @@ class Instance(base.NovaObject): 'cleaned': bool, + 'pci_devices': obj_utils.nested_object_or_none( + pci_device.PciDeviceList), } obj_extra_fields = ['name'] @@ -199,6 +204,8 @@ class Instance(base.NovaObject): _attr_info_cache_to_primitive = obj_utils.obj_serializer('info_cache') _attr_security_groups_to_primitive = obj_utils.obj_serializer( 'security_groups') + _attr_pci_devices_to_primitive = obj_utils.obj_serializer( + 'pci_devices') _attr_scheduled_at_from_primitive = obj_utils.dt_deserializer _attr_launched_at_from_primitive = obj_utils.dt_deserializer @@ -210,6 +217,9 @@ class Instance(base.NovaObject): def _attr_security_groups_from_primitive(self, val): return base.NovaObject.obj_from_primitive(val) + def _attr_pci_devices_from_primitive(self, val): + return base.NovaObject.obj_from_primitive(val) + @staticmethod def _from_db_object(context, instance, db_inst, expected_attrs=None): """Method to help with migration to objects. @@ -237,6 +247,13 @@ class Instance(base.NovaObject): instance['fault'] = ( instance_fault.InstanceFault.get_latest_for_instance( context, instance.uuid)) + + if 'pci_devices' in expected_attrs: + instance['pci_devices'] =\ + pci_device._make_pci_list(context, + pci_device.PciDeviceList(), + db_inst['pci_devices']) + # NOTE(danms): info_cache and security_groups are almost # always joined in the DB layer right now, so check to see if # they are asked for and are present in the resulting object @@ -263,6 +280,8 @@ class Instance(base.NovaObject): columns_to_join.append('metadata') if 'system_metadata' in attrs: columns_to_join.append('system_metadata') + if 'pci_devices' in attrs: + columns_to_join.append('pci_devices') # NOTE(danms): The DB API currently always joins info_cache and # security_groups for get operations, so don't add them to the # list of columns @@ -340,6 +359,12 @@ class Instance(base.NovaObject): # NOTE(danms): I don't think we need to worry about this, do we? pass + def _save_pci_devices(self, context): + # NOTE(yjiang5): All devices held by PCI tracker, only PCI tracker + # permitted to update the DB. all change to devices from here will + # be dropped. + pass + @base.remotable def save(self, context, expected_vm_state=None, expected_task_state=None, admin_state_reset=False): @@ -448,6 +473,8 @@ class Instance(base.NovaObject): extra.append('info_cache') elif attrname == 'security_groups': extra.append('security_groups') + elif attrname == 'pci_devices': + extra.append('pci_devices') elif attrname == 'fault': extra.append('fault') diff --git a/nova/objects/pci_device.py b/nova/objects/pci_device.py new file mode 100644 index 000000000000..e964622080b8 --- /dev/null +++ b/nova/objects/pci_device.py @@ -0,0 +1,269 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Intel Corporation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# @author: Yongli He, Intel Corporation. + +import copy +import functools + +from nova import db +from nova import exception +from nova.objects import base +from nova.objects import utils as obj_utils +from nova.openstack.common import jsonutils +from nova.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +def check_device_status(dev_status=None): + """Decorator to check device status before changing it.""" + + if dev_status is not None and not isinstance(dev_status, set): + dev_status = set(dev_status) + + def outer(f): + @functools.wraps(f) + def inner(self, instance=None): + if self['status'] not in dev_status: + raise exception.PciDeviceInvalidStatus( + compute_node_id=self.compute_node_id, + address=self.address, status=self.status, + hopestatus=dev_status) + if instance: + return f(self, instance) + else: + return f(self) + return inner + return outer + + +class PciDevice(base.NovaObject): + + """Object to represent a PCI device on a compute node. + + PCI devices are managed by the compute resource tracker, which discovers + the devices from the hardware platform, claims, allocates and frees + devices for instances. + + The PCI device information is permanently maintained in a database. + This makes it convenient to get PCI device information, like physical + function for a VF device, adjacent switch IP address for a NIC, + hypervisor identification for a PCI device, etc. It also provides a + convenient way to check device allocation information for administrator + purposes. + + A device can be in available/claimed/allocated/deleted/removed state. + + A device is available when it is discovered.. + + A device is claimed prior to being allocated to an instance. Normally the + transition from claimed to allocated is quick. However, during a resize + operation the transition can take longer, because devices are claimed in + prep_resize and allocated in finish_resize. + + A device becomes removed when hot removed from a node (i.e. not found in + the next auto-discover) but not yet synced with the DB. A removed device + should not be allocated to any instance, and once deleted from the DB, + the device object is changed to deleted state and no longer synced with + the DB. + + Filed notes: + 'dev_id': + Hypervisor's identification for the device, the string format + is hypervisor specific + 'extra_info': + Device-specific properties like PF address, switch ip address etc. + """ + + # Version 1.0: Initial version + VERSION = '1.0' + + fields = { + 'id': int, + # Note(yjiang5): the compute_node_id may be None because the pci + # device objects are created before the compute node is created in DB + 'compute_node_id': obj_utils.int_or_none, + 'address': str, + 'vendor_id': str, + 'product_id': str, + 'dev_type': str, + 'status': str, + 'dev_id': obj_utils.str_or_none, + 'label': obj_utils.str_or_none, + 'instance_uuid': obj_utils.str_or_none, + 'extra_info': dict, + } + + def update_device(self, dev_dict): + """Sync the content from device dictionary to device object. + + The resource tracker updates the available devices periodically. + To avoid meaningless syncs with the database, we update the device + object only if a value changed. + """ + + # Note(yjiang5): status/instance_uuid should only be updated by + # functions like claim/allocate etc. The id is allocated by + # database. The extra_info is created by the object. + no_changes = ('status', 'instance_uuid', 'id', 'extra_info') + map(lambda x: dev_dict.pop(x, None), + [key for key in no_changes]) + + for k, v in dev_dict.items(): + if k in self.fields.keys(): + self[k] = v + else: + # Note (yjiang5) extra_info.update does not update + # obj_what_changed, set it explicitely + extra_info = self.extra_info + extra_info.update({k: v}) + self.extra_info = extra_info + + def __init__(self): + super(PciDevice, self).__init__() + self.extra_info = {} + self.obj_reset_changes() + + @staticmethod + def _from_db_object(context, pci_device, db_dev): + for key in pci_device.fields: + if key != 'extra_info': + pci_device[key] = db_dev.get(key) + else: + extra_info = db_dev.get("extra_info") + pci_device.extra_info = jsonutils.loads(extra_info) + pci_device._context = context + pci_device.obj_reset_changes() + return pci_device + + @base.remotable_classmethod + def get_by_dev_addr(cls, context, compute_node_id, dev_addr): + db_dev = db.pci_device_get_by_addr( + context, compute_node_id, dev_addr) + return cls._from_db_object(context, cls(), db_dev) + + @base.remotable_classmethod + def get_by_dev_id(cls, context, id): + db_dev = db.pci_device_get_by_id(context, id) + return cls._from_db_object(context, cls(), db_dev) + + @classmethod + def create(cls, dev_dict): + """Create a PCI device based on hypervisor information. + + As the device object is just created and is not synced with db yet + thus we should not reset changes here for fields from dict. + """ + pci_device = cls() + pci_device.update_device(dev_dict) + pci_device.status = 'available' + return pci_device + + @check_device_status(dev_status=['available']) + def claim(self, instance): + self.status = 'claimed' + self.instance_uuid = instance['uuid'] + + @check_device_status(dev_status=['available', 'claimed']) + def allocate(self, instance): + if self.status == 'claimed' and self.instance_uuid != instance['uuid']: + raise exception.PciDeviceInvalidOwner( + compute_node_id=self.compute_node_id, + address=self.address, owner=self.instance_uuid, + hopeowner=instance['uuid']) + + self.status = 'allocated' + self.instance_uuid = instance['uuid'] + + # Notes(yjiang5): remove this check when instance object for + # compute manager is finished + if isinstance(instance, dict): + if 'pci_devices' not in instance: + instance['pci_devices'] = [] + instance['pci_devices'].append(copy.copy(self)) + else: + instance.pci_devices.objects.append(copy.copy(self)) + + @check_device_status(dev_status=['available']) + def remove(self): + self.status = 'removed' + self.instance_uuid = None + + @check_device_status(dev_status=['claimed', 'allocated']) + def free(self, instance=None): + if instance and self.instance_uuid != instance['uuid']: + raise exception.PciDeviceInvalidOwner( + compute_node_id=self.compute_node_id, + address=self.address, owner=self.instance_uuid, + hopeowner=instance['uuid']) + old_status = self.status + self.status = 'available' + self.instance_uuid = None + if old_status == 'allocated' and instance: + # Notes(yjiang5): remove this check when instance object for + # compute manager is finished + existed = next((dev for dev in instance['pci_devices'] + if dev.id == self.id)) + if isinstance(instance, dict): + instance['pci_devices'].remove(existed) + else: + instance.pci_devices.objects.remove(existed) + + @base.remotable + def save(self, context): + if self.status == 'removed': + self.status = 'deleted' + db.pci_device_destroy(context, self.compute_node_id, self.address) + elif self.status != 'deleted': + updates = {} + for field in self.obj_what_changed(): + if field == 'extra_info': + updates['extra_info'] = jsonutils.dumps(self.extra_info) + else: + updates[field] = self[field] + if updates: + db_pci = db.pci_device_update(context, self.compute_node_id, + self.address, updates) + self._from_db_object(context, self, db_pci) + + +def _make_pci_list(context, pci_list, db_list): + pci_list.objects = [] + for pci in db_list: + pci_obj = PciDevice._from_db_object(context, PciDevice(), pci) + pci_list.objects.append(pci_obj) + + pci_list.obj_reset_changes() + return pci_list + + +class PciDeviceList(base.ObjectListBase, base.NovaObject): + def __init__(self): + super(PciDeviceList, self).__init__() + self.objects = [] + self.obj_reset_changes() + + @base.remotable_classmethod + def get_by_compute_node(cls, context, node_id): + db_dev_list = db.pci_device_get_all_by_node(context, node_id) + return _make_pci_list(context, cls(), db_dev_list) + + @base.remotable_classmethod + def get_by_instance_uuid(cls, context, uuid): + db_dev_list = db.pci_device_get_all_by_instance_uuid(context, uuid) + return _make_pci_list(context, cls(), db_dev_list) diff --git a/nova/tests/api/openstack/fakes.py b/nova/tests/api/openstack/fakes.py index b6c23e6c7f3d..8a1e440a7ecc 100644 --- a/nova/tests/api/openstack/fakes.py +++ b/nova/tests/api/openstack/fakes.py @@ -561,6 +561,7 @@ def stub_instance(id, user_id=None, project_id=None, host=None, "security_groups": security_groups, "root_device_name": root_device_name, "system_metadata": utils.dict_to_metadata(sys_meta), + "pci_devices": [], "vm_mode": "", "default_swap_device": "", "default_ephemeral_device": "", diff --git a/nova/tests/compute/test_compute.py b/nova/tests/compute/test_compute.py index 4e41fcf9d113..5b9535a21f04 100644 --- a/nova/tests/compute/test_compute.py +++ b/nova/tests/compute/test_compute.py @@ -123,6 +123,10 @@ def unify_instance(instance): elif k == 'fault': # NOTE(danms): DB models don't have 'fault' continue + elif k == 'pci_devices': + # NOTE(yonlig.he) pci devices need lazy loading + # fake db does not support it yet. + continue newdict[k] = v return newdict @@ -282,6 +286,7 @@ class BaseTestCase(test.TestCase): inst['updated_at'] = timeutils.utcnow() inst['launched_at'] = timeutils.utcnow() inst['security_groups'] = [] + inst['pci_devices'] = [] inst.update(params) if services: _create_service_entries(self.context.elevated(), diff --git a/nova/tests/objects/test_instance.py b/nova/tests/objects/test_instance.py index bcd7a292b3a6..f83102441ef2 100644 --- a/nova/tests/objects/test_instance.py +++ b/nova/tests/objects/test_instance.py @@ -52,6 +52,7 @@ class _TestInstanceObject(object): fake_instance['deleted'] = False fake_instance['info_cache']['instance_uuid'] = fake_instance['uuid'] fake_instance['security_groups'] = None + fake_instance['pci_devices'] = [] return fake_instance def test_datetime_deserialization(self): @@ -113,7 +114,7 @@ class _TestInstanceObject(object): self.mox.StubOutWithMock(db, 'instance_get_by_uuid') db.instance_get_by_uuid( ctxt, 'uuid', - columns_to_join=['metadata', 'system_metadata'] + columns_to_join=['metadata', 'system_metadata', 'pci_devices'] ).AndReturn(self.fake_instance) self.mox.ReplayAll() inst = instance.Instance.get_by_uuid( @@ -387,6 +388,61 @@ class _TestInstanceObject(object): inst = instance.Instance.get_by_uuid(ctxt, fake_uuid) self.assertEqual(0, len(inst.security_groups)) + def test_with_empty_pci_devices(self): + ctxt = context.get_admin_context() + fake_inst = dict(self.fake_instance, pci_devices=[]) + fake_uuid = fake_inst['uuid'] + self.mox.StubOutWithMock(db, 'instance_get_by_uuid') + db.instance_get_by_uuid(ctxt, fake_uuid, + columns_to_join=['pci_devices'] + ).AndReturn(fake_inst) + self.mox.ReplayAll() + inst = instance.Instance.get_by_uuid(ctxt, fake_uuid, + ['pci_devices']) + self.assertEqual(len(inst.pci_devices), 0) + + def test_with_pci_devices(self): + ctxt = context.get_admin_context() + fake_inst = dict(self.fake_instance) + fake_uuid = fake_inst['uuid'] + fake_inst['pci_devices'] = [ + {'created_at': None, + 'updated_at': None, + 'deleted_at': None, + 'deleted': None, + 'id': 2, + 'compute_node_id': 1, + 'address': 'a1', + 'product_id': 'p1', + 'vendor_id': 'v1', + 'status': 'allocated', + 'instance_uuid': fake_uuid, + 'extra_info': '{}'}, + { + 'created_at': None, + 'updated_at': None, + 'deleted_at': None, + 'deleted': None, + 'id': 1, + 'compute_node_id': 1, + 'address': 'a', + 'product_id': 'p', + 'vendor_id': 'v', + 'status': 'allocated', + 'instance_uuid': fake_uuid, + 'extra_info': '{}'}, + ] + self.mox.StubOutWithMock(db, 'instance_get_by_uuid') + db.instance_get_by_uuid(ctxt, fake_uuid, + columns_to_join=['pci_devices'] + ).AndReturn(fake_inst) + self.mox.ReplayAll() + inst = instance.Instance.get_by_uuid(ctxt, fake_uuid, + ['pci_devices']) + self.assertEqual(len(inst.pci_devices), 2) + self.assertEqual(inst.pci_devices[0].instance_uuid, fake_uuid) + self.assertEqual(inst.pci_devices[1].instance_uuid, fake_uuid) + def test_with_fault(self): ctxt = context.get_admin_context() fake_inst = dict(self.fake_instance) diff --git a/nova/tests/objects/test_pci_device.py b/nova/tests/objects/test_pci_device.py new file mode 100644 index 000000000000..24d2b78cfbe6 --- /dev/null +++ b/nova/tests/objects/test_pci_device.py @@ -0,0 +1,315 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2012 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy + +from nova import context +from nova import db +from nova import exception +from nova.objects import instance +from nova.objects import pci_device +from nova.tests.objects import test_objects + +dev_dict = { + 'compute_node_id': 1, + 'address': 'a', + 'product_id': 'p', + 'vendor_id': 'v', + 'status': 'available'} + + +fake_db_dev = { + 'created_at': None, + 'updated_at': None, + 'deleted_at': None, + 'deleted': None, + 'id': 1, + 'compute_node_id': 1, + 'address': 'a', + 'product_id': 'p', + 'vendor_id': 'v', + 'status': 'available', + 'extra_info': '{}', + } + + +fake_db_dev_1 = { + 'created_at': None, + 'updated_at': None, + 'deleted_at': None, + 'deleted': None, + 'id': 2, + 'compute_node_id': 1, + 'address': 'a1', + 'product_id': 'p1', + 'vendor_id': 'v1', + 'status': 'available', + 'extra_info': '{}', + } + + +class _TestPciDeviceObject(object): + def _create_fake_instance(self): + self.inst = instance.Instance() + self.inst.uuid = 'fake-inst-uuid' + self.inst.pci_devices = pci_device.PciDeviceList() + + def _create_fake_pci_device(self): + ctxt = context.get_admin_context() + self.mox.StubOutWithMock(db, 'pci_device_get_by_addr') + db.pci_device_get_by_addr(ctxt, 1, 'a').AndReturn(fake_db_dev) + self.mox.ReplayAll() + self.pci_device = pci_device.PciDevice.get_by_dev_addr(ctxt, 1, 'a') + + def test_create_pci_device(self): + self.pci_device = pci_device.PciDevice.create(dev_dict) + self.assertEqual(self.pci_device.product_id, 'p') + self.assertEqual(self.pci_device.obj_what_changed(), + set(['compute_node_id', 'product_id', 'vendor_id', + 'status', 'address'])) + + def test_pci_device_extra_info(self): + self.dev_dict = copy.copy(dev_dict) + self.dev_dict['k1'] = 'v1' + self.dev_dict['k2'] = 'v2' + self.pci_device = pci_device.PciDevice.create(self.dev_dict) + extra_value = self.pci_device.extra_info + self.assertEqual(extra_value.get('k1'), 'v1') + self.assertEqual(set(extra_value.keys()), set(('k1', 'k2'))) + self.assertEqual(self.pci_device.obj_what_changed(), + set(['compute_node_id', 'address', 'product_id', + 'vendor_id', 'status', 'extra_info'])) + + def test_update_device(self): + self.pci_device = pci_device.PciDevice.create(dev_dict) + self.pci_device.obj_reset_changes() + changes = {'product_id': 'p2', 'vendor_id': 'v2'} + self.pci_device.update_device(changes) + self.assertEqual(self.pci_device.vendor_id, 'v2') + self.assertEqual(self.pci_device.obj_what_changed(), + set(['vendor_id', 'product_id'])) + + def test_update_device_same_value(self): + self.pci_device = pci_device.PciDevice.create(dev_dict) + self.pci_device.obj_reset_changes() + changes = {'product_id': 'p', 'vendor_id': 'v2'} + self.pci_device.update_device(changes) + self.assertEqual(self.pci_device.product_id, 'p') + self.assertEqual(self.pci_device.vendor_id, 'v2') + self.assertEqual(self.pci_device.obj_what_changed(), + set(['vendor_id', 'product_id'])) + + def test_get_by_dev_addr(self): + ctxt = context.get_admin_context() + self.mox.StubOutWithMock(db, 'pci_device_get_by_addr') + db.pci_device_get_by_addr(ctxt, 1, 'a').AndReturn(fake_db_dev) + self.mox.ReplayAll() + self.pci_device = pci_device.PciDevice.get_by_dev_addr(ctxt, 1, 'a') + self.assertEqual(self.pci_device.product_id, 'p') + self.assertEqual(self.pci_device.obj_what_changed(), set()) + self.assertRemotes() + + def test_get_by_dev_id(self): + ctxt = context.get_admin_context() + self.mox.StubOutWithMock(db, 'pci_device_get_by_id') + db.pci_device_get_by_id(ctxt, 1).AndReturn(fake_db_dev) + self.mox.ReplayAll() + self.pci_device = pci_device.PciDevice.get_by_dev_id(ctxt, 1) + self.assertEqual(self.pci_device.product_id, 'p') + self.assertEqual(self.pci_device.obj_what_changed(), set()) + self.assertRemotes() + + def test_claim_device(self): + self._create_fake_instance() + self.pci_device = pci_device.PciDevice.create(dev_dict) + self.pci_device.claim(self.inst) + self.assertEqual(self.pci_device.status, 'claimed') + self.assertEqual(self.pci_device.instance_uuid, + 'fake-inst-uuid') + self.assertEqual(len(self.inst.pci_devices), 0) + + def test_claim_device_fail(self): + self._create_fake_instance() + self._create_fake_pci_device() + self.pci_device.status = 'allocated' + self.assertRaises(exception.PciDeviceInvalidStatus, + self.pci_device.claim, self.inst) + + def test_allocate_device(self): + self._create_fake_instance() + self._create_fake_pci_device() + self.pci_device.claim(self.inst) + self.pci_device.allocate(self.inst) + self.assertEqual(self.pci_device.status, 'allocated') + self.assertEqual(self.pci_device.instance_uuid, 'fake-inst-uuid') + self.assertEqual(len(self.inst.pci_devices), 1) + self.assertEqual(self.inst.pci_devices[0]['vendor_id'], 'v') + self.assertEqual(self.inst.pci_devices[0]['status'], 'allocated') + + def test_allocacte_device_fail_status(self): + self._create_fake_instance() + self._create_fake_pci_device() + self.pci_device.status = 'removed' + self.assertRaises(exception.PciDeviceInvalidStatus, + self.pci_device.allocate, + self.inst) + + def test_allocacte_device_fail_owner(self): + self._create_fake_instance() + self._create_fake_pci_device() + inst_2 = instance.Instance() + inst_2.uuid = 'fake-inst-uuid-2' + self.pci_device.claim(self.inst) + self.assertRaises(exception.PciDeviceInvalidOwner, + self.pci_device.allocate, inst_2) + + def test_free_claimed_device(self): + self._create_fake_instance() + self._create_fake_pci_device() + self.pci_device.claim(self.inst) + self.pci_device.free(self.inst) + self.assertEqual(self.pci_device.status, 'available') + self.assertEqual(self.pci_device.instance_uuid, None) + + def test_free_allocated_device(self): + self._create_fake_instance() + self._create_fake_pci_device() + self.pci_device.claim(self.inst) + self.pci_device.allocate(self.inst) + self.assertEqual(len(self.inst.pci_devices), 1) + self.pci_device.free(self.inst) + self.assertEqual(len(self.inst.pci_devices), 0) + self.assertEqual(self.pci_device.status, 'available') + self.assertEqual(self.pci_device.instance_uuid, None) + + def test_free_device_fail(self): + self._create_fake_pci_device() + self.pci_device.status = 'removed' + self.assertRaises(exception.PciDeviceInvalidStatus, + self.pci_device.free) + + def test_remove_device(self): + self._create_fake_pci_device() + self.pci_device.remove() + self.assertEqual(self.pci_device.status, 'removed') + self.assertEqual(self.pci_device.instance_uuid, None) + + def test_remove_device_fail(self): + self._create_fake_instance() + self._create_fake_pci_device() + self.pci_device.claim(self.inst) + self.assertRaises(exception.PciDeviceInvalidStatus, + self.pci_device.remove) + + def test_save(self): + ctxt = context.get_admin_context() + self._create_fake_pci_device() + return_dev = dict(fake_db_dev, status='available', + instance_uuid='fake-uuid-3') + self.pci_device.status = 'allocated' + self.pci_device.instance_uuid = 'fake-uuid-2' + expected_updates = dict(status='allocated', + instance_uuid='fake-uuid-2') + self.mox.StubOutWithMock(db, 'pci_device_update') + db.pci_device_update(ctxt, 1, 'a', + expected_updates).AndReturn(return_dev) + self.mox.ReplayAll() + self.pci_device.save(ctxt) + self.assertEqual(self.pci_device.status, 'available') + self.assertEqual(self.pci_device.instance_uuid, + 'fake-uuid-3') + self.assertRemotes() + + def test_save_removed(self): + ctxt = context.get_admin_context() + self._create_fake_pci_device() + self.pci_device.status = 'removed' + self.mox.StubOutWithMock(db, 'pci_device_destroy') + db.pci_device_destroy(ctxt, 1, 'a') + self.mox.ReplayAll() + self.pci_device.save(ctxt) + self.assertEqual(self.pci_device.status, 'deleted') + self.assertRemotes() + + def test_save_deleted(self): + def _fake_destroy(ctxt, node_id, addr): + self.called = True + + def _fake_update(ctxt, node_id, addr, updates): + self.called = True + ctxt = context.get_admin_context() + self.stubs.Set(db, 'pci_device_destroy', _fake_destroy) + self.stubs.Set(db, 'pci_device_update', _fake_update) + self._create_fake_pci_device() + self.pci_device.status = 'deleted' + self.called = False + self.pci_device.save(ctxt) + self.assertEqual(self.called, False) + + +class TestPciDeviceObject(test_objects._LocalTest, + _TestPciDeviceObject): + pass + + +class TestPciDeviceObjectRemote(test_objects._RemoteTest, + _TestPciDeviceObject): + pass + + +fake_pci_devs = [fake_db_dev, fake_db_dev_1] + + +class _TestPciDeviceListObject(object): + def test_get_by_compute_node(self): + ctxt = context.get_admin_context() + self.mox.StubOutWithMock(db, 'pci_device_get_all_by_node') + db.pci_device_get_all_by_node(ctxt, 1).AndReturn(fake_pci_devs) + self.mox.ReplayAll() + devs = pci_device.PciDeviceList.get_by_compute_node(ctxt, 1) + for i in range(len(fake_pci_devs)): + self.assertTrue(isinstance(devs[i], pci_device.PciDevice)) + self.assertEqual(fake_pci_devs[i]['vendor_id'], devs[i].vendor_id) + self.assertRemotes() + + def test_get_by_instance_uuid(self): + ctxt = context.get_admin_context() + fake_db_1 = dict(fake_db_dev, address='a1', + status='allocated', instance_uuid='1') + fake_db_2 = dict(fake_db_dev, address='a2', + status='allocated', instance_uuid='1') + self.mox.StubOutWithMock(db, 'pci_device_get_all_by_instance_uuid') + db.pci_device_get_all_by_instance_uuid(ctxt, '1').AndReturn( + [fake_db_1, fake_db_2]) + self.mox.ReplayAll() + devs = pci_device.PciDeviceList.get_by_instance_uuid(ctxt, '1') + self.assertEqual(len(devs), 2) + for i in range(len(fake_pci_devs)): + self.assertTrue(isinstance(devs[i], pci_device.PciDevice)) + self.assertEqual(devs[0].vendor_id, 'v') + self.assertEqual(devs[1].vendor_id, 'v') + self.assertRemotes() + + +class TestPciDeviceListObject(test_objects._LocalTest, + _TestPciDeviceListObject): + pass + + +class TestPciDeviceListObjectRemote(test_objects._RemoteTest, + _TestPciDeviceListObject): + pass