diff --git a/nova/objects/pci_device.py b/nova/objects/pci_device.py index 96ff27b602f2..01050fbe36f7 100644 --- a/nova/objects/pci_device.py +++ b/nova/objects/pci_device.py @@ -26,6 +26,21 @@ from nova import utils LOG = logging.getLogger(__name__) +def compare_pci_device_attributes(obj_a, obj_b): + pci_ignore_fields = base.NovaPersistentObject.fields.keys() + for name in obj_a.obj_fields: + if name in pci_ignore_fields: + continue + is_set_a = obj_a.obj_attr_is_set(name) + is_set_b = obj_b.obj_attr_is_set(name) + if is_set_a != is_set_b: + return False + if is_set_a: + if getattr(obj_a, name) != getattr(obj_b, name): + return False + return True + + # TODO(berrange): Remove NovaObjectDictCompat class PciDevice(base.NovaPersistentObject, base.NovaObject, base.NovaObjectDictCompat): @@ -127,6 +142,12 @@ class PciDevice(base.NovaPersistentObject, base.NovaObject, self.obj_reset_changes() self.extra_info = {} + def __eq__(self, other): + return compare_pci_device_attributes(self, other) + + def __ne__(self, other): + return not (self == other) + @staticmethod def _from_db_object(context, pci_device, db_dev): for key in pci_device.fields: diff --git a/nova/pci/stats.py b/nova/pci/stats.py index 4a065ba5d727..cedd11561a43 100644 --- a/nova/pci/stats.py +++ b/nova/pci/stats.py @@ -248,3 +248,9 @@ class PciDeviceStats(object): def clear(self): """Clear all the stats maintained.""" self.pools = [] + + def __eq__(self, other): + return cmp(self.pools, other.pools) == 0 + + def __ne__(self, other): + return not (self == other) diff --git a/nova/tests/unit/objects/test_pci_device.py b/nova/tests/unit/objects/test_pci_device.py index d6386bfc90ad..7c6224b70939 100644 --- a/nova/tests/unit/objects/test_pci_device.py +++ b/nova/tests/unit/objects/test_pci_device.py @@ -15,6 +15,8 @@ import copy +from oslo_utils import timeutils + from nova import context from nova import db from nova.objects import instance @@ -213,6 +215,37 @@ class _TestPciDeviceObject(object): self.pci_device = pci_device.PciDevice.create(self.dev_dict) self.assertEqual(1, self.pci_device.numa_node) + def test_pci_device_equivalent(self): + pci_device1 = pci_device.PciDevice.create(dev_dict) + pci_device2 = pci_device.PciDevice.create(dev_dict) + self.assertEqual(pci_device1, pci_device2) + + def test_pci_device_equivalent_with_ignore_field(self): + pci_device1 = pci_device.PciDevice.create(dev_dict) + pci_device2 = pci_device.PciDevice.create(dev_dict) + pci_device2.updated_at = timeutils.utcnow() + self.assertEqual(pci_device1, pci_device2) + + def test_pci_device_not_equivalent1(self): + pci_device1 = pci_device.PciDevice.create(dev_dict) + dev_dict2 = copy.copy(dev_dict) + dev_dict2['address'] = 'b' + pci_device2 = pci_device.PciDevice.create(dev_dict2) + self.assertNotEqual(pci_device1, pci_device2) + + def test_pci_device_not_equivalent2(self): + pci_device1 = pci_device.PciDevice.create(dev_dict) + pci_device2 = pci_device.PciDevice.create(dev_dict) + delattr(pci_device2, 'address') + self.assertNotEqual(pci_device1, pci_device2) + + def test_pci_device_not_equivalent_with_none(self): + pci_device1 = pci_device.PciDevice.create(dev_dict) + pci_device2 = pci_device.PciDevice.create(dev_dict) + pci_device1.instance_uuid = 'aaa' + pci_device2.instance_uuid = None + self.assertNotEqual(pci_device1, pci_device2) + class TestPciDeviceObject(test_objects._LocalTest, _TestPciDeviceObject): diff --git a/nova/tests/unit/pci/test_stats.py b/nova/tests/unit/pci/test_stats.py index cfc9c3779cfa..8685415eb0fa 100644 --- a/nova/tests/unit/pci/test_stats.py +++ b/nova/tests/unit/pci/test_stats.py @@ -96,6 +96,21 @@ class PciDeviceStatsTestCase(test.NoDBTestCase): self.pci_stats.remove_device, self.fake_dev_2) + def test_pci_stats_equivalent(self): + pci_stats2 = stats.PciDeviceStats() + map(pci_stats2.add_device, [self.fake_dev_1, + self.fake_dev_2, + self.fake_dev_3, + self.fake_dev_4]) + self.assertEqual(self.pci_stats, pci_stats2) + + def test_pci_stats_not_equivalent(self): + pci_stats2 = stats.PciDeviceStats() + map(pci_stats2.add_device, [self.fake_dev_1, + self.fake_dev_2, + self.fake_dev_3]) + self.assertNotEqual(self.pci_stats, pci_stats2) + def test_object_create(self): m = objects.pci_device_pool.from_pci_stats(self.pci_stats.pools) new_stats = stats.PciDeviceStats(m)