From 47181ae3ebcd1533c22378ee31a4b1f0848926d6 Mon Sep 17 00:00:00 2001 From: Ludovic Beliveau Date: Wed, 18 Nov 2015 11:52:39 -0500 Subject: [PATCH] Allow saving empty pci_device_pools in ComputeNode object Prior to this patch, saving a ComputeNode with a pci_device_pools attribute that has no objects specified in it (empty PciDevicePool list) would result in the change not being saved. Object of type PciDevicePoolList are evaluated like a list, thefore a conditional statement like 'if pools' will always evaluate to False even if 'pools' is not None. Without this fix, if 'pci_passthrough_whitelist' is cleared in the configuration, nova scheduler still think a compute node has PCI devices available and can still trigger scheduling an instance with PCI devices on the node. Change-Id: Ib3c19d569b9b3b23a293ad55dd9023291435d5a6 Closes-Bug: #1487451 --- nova/objects/compute_node.py | 12 ++++-- nova/tests/unit/objects/test_compute_node.py | 45 +++++++++++++++++++- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/nova/objects/compute_node.py b/nova/objects/compute_node.py index c3b5cf9fe20a..d94b5d37998e 100644 --- a/nova/objects/compute_node.py +++ b/nova/objects/compute_node.py @@ -208,7 +208,9 @@ class ComputeNode(base.NovaPersistentObject, base.NovaObject, compute['supported_hv_specs'] = hv_specs pci_stats = db_compute.get('pci_stats') - compute.pci_device_pools = pci_device_pool.from_pci_stats(pci_stats) + if pci_stats is not None: + pci_stats = pci_device_pool.from_pci_stats(pci_stats) + compute.pci_device_pools = pci_stats compute._context = context # Make sure that we correctly set the host field depending on either @@ -288,9 +290,11 @@ class ComputeNode(base.NovaPersistentObject, base.NovaObject, @staticmethod def _convert_pci_stats_to_db_format(updates): - pools = updates.pop('pci_device_pools', None) - if pools: - updates['pci_stats'] = jsonutils.dumps(pools.obj_to_primitive()) + if 'pci_device_pools' in updates: + pools = updates.pop('pci_device_pools') + if pools is not None: + pools = jsonutils.dumps(pools.obj_to_primitive()) + updates['pci_stats'] = pools @base.remotable def create(self): diff --git a/nova/tests/unit/objects/test_compute_node.py b/nova/tests/unit/objects/test_compute_node.py index 70c9471ec3a6..a26c19b7005a 100644 --- a/nova/tests/unit/objects/test_compute_node.py +++ b/nova/tests/unit/objects/test_compute_node.py @@ -134,8 +134,11 @@ class _TestComputeNodeObject(object): self.assertJsonEqual(expected, obj_val) def pci_device_pools_comparator(self, expected, obj_val): - obj_val = obj_val.obj_to_primitive() - self.assertJsonEqual(expected, obj_val) + if obj_val is not None: + obj_val = obj_val.obj_to_primitive() + self.assertJsonEqual(expected, obj_val) + else: + self.assertEqual(expected, obj_val) def comparators(self): return {'stats': self.assertJsonEqual, @@ -313,6 +316,44 @@ class _TestComputeNodeObject(object): self.assertEqual(uuidsentinel.fake_compute_node, obj.uuid) self.assertFalse(mock_gu.called) + def test_save_pci_device_pools_empty(self): + fake_pci = jsonutils.dumps( + objects.PciDevicePoolList(objects=[]).obj_to_primitive()) + compute_dict = fake_compute_node.copy() + compute_dict['pci_stats'] = fake_pci + + with mock.patch.object( + db, 'compute_node_update', + return_value=compute_dict) as mock_compute_node_update: + compute = compute_node.ComputeNode(context=self.context) + compute.id = 123 + compute.pci_device_pools = objects.PciDevicePoolList(objects=[]) + compute.save() + self.compare_obj(compute, compute_dict, + subs=self.subs(), + comparators=self.comparators()) + + mock_compute_node_update.assert_called_once_with( + self.context, 123, {'pci_stats': fake_pci}) + + def test_save_pci_device_pools_null(self): + compute_dict = fake_compute_node.copy() + compute_dict['pci_stats'] = None + + with mock.patch.object( + db, 'compute_node_update', + return_value=compute_dict) as mock_compute_node_update: + compute = compute_node.ComputeNode(context=self.context) + compute.id = 123 + compute.pci_device_pools = None + compute.save() + self.compare_obj(compute, compute_dict, + subs=self.subs(), + comparators=self.comparators()) + + mock_compute_node_update.assert_called_once_with( + self.context, 123, {'pci_stats': None}) + @mock.patch.object(db, 'compute_node_create', return_value=fake_compute_node) def test_set_id_failure(self, db_mock):