diff --git a/ironic/db/api.py b/ironic/db/api.py index a9ab47a701..c52cc7bd4c 100644 --- a/ironic/db/api.py +++ b/ironic/db/api.py @@ -924,13 +924,14 @@ class Connection(object): # TODO(rloo) Delete this in Rocky cycle. @abc.abstractmethod - def set_node_traits(self, node_id, traits): + def set_node_traits(self, node_id, traits, version): """Replace all of the node traits with specified list of traits. This ignores duplicate traits in the specified list. :param node_id: The id of a node. :param traits: List of traits. + :param version: the version of the object.Trait. :returns: A list of NodeTrait objects. :raises: InvalidParameterValue if setting the traits would exceed the per-node traits limit. @@ -955,7 +956,7 @@ class Connection(object): """ @abc.abstractmethod - def add_node_trait(self, node_id, trait): + def add_node_trait(self, node_id, trait, version): """Add trait to the node. If the node_id and trait pair already exists, this should still @@ -963,6 +964,7 @@ class Connection(object): :param node_id: The id of a node. :param trait: A trait string. + :param version: the version of the object.Trait. :returns: the NodeTrait object. :raises: InvalidParameterValue if adding the trait would exceed the per-node traits limit. diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index 5e46f6889b..ade554b4e3 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -1306,7 +1306,7 @@ class Connection(api.Connection): max_traits=MAX_TRAITS_PER_NODE) @oslo_db_api.retry_on_deadlock - def set_node_traits(self, node_id, traits): + def set_node_traits(self, node_id, traits, version): # Remove duplicate traits traits = set(traits) @@ -1317,7 +1317,8 @@ class Connection(api.Connection): self.unset_node_traits(node_id) node_traits = [] for trait in traits: - node_trait = models.NodeTrait(trait=trait, node_id=node_id) + node_trait = models.NodeTrait(trait=trait, node_id=node_id, + version=version) session.add(node_trait) node_traits.append(node_trait) @@ -1337,8 +1338,9 @@ class Connection(api.Connection): return result @oslo_db_api.retry_on_deadlock - def add_node_trait(self, node_id, trait): - node_trait = models.NodeTrait(trait=trait, node_id=node_id) + def add_node_trait(self, node_id, trait, version): + node_trait = models.NodeTrait(trait=trait, node_id=node_id, + version=version) self._check_node_exists(node_id) try: diff --git a/ironic/tests/unit/db/test_node_traits.py b/ironic/tests/unit/db/test_node_traits.py index 7f78f60012..bbd6874115 100644 --- a/ironic/tests/unit/db/test_node_traits.py +++ b/ironic/tests/unit/db/test_node_traits.py @@ -24,31 +24,34 @@ class DbNodeTraitTestCase(base.DbTestCase): self.node = db_utils.create_test_node() def test_set_node_traits(self): - result = self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) + result = self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2'], + '1.0') self.assertEqual(self.node.id, result[0].node_id) self.assertItemsEqual(['trait1', 'trait2'], [trait.trait for trait in result]) - result = self.dbapi.set_node_traits(self.node.id, []) + result = self.dbapi.set_node_traits(self.node.id, [], '1.0') self.assertEqual([], result) def test_set_node_traits_duplicate(self): result = self.dbapi.set_node_traits(self.node.id, - ['trait1', 'trait2', 'trait2']) + ['trait1', 'trait2', 'trait2'], + '1.0') self.assertEqual(self.node.id, result[0].node_id) self.assertItemsEqual(['trait1', 'trait2'], [trait.trait for trait in result]) def test_set_node_traits_at_limit(self): traits = ['trait%d' % n for n in range(50)] - result = self.dbapi.set_node_traits(self.node.id, traits) + result = self.dbapi.set_node_traits(self.node.id, traits, '1.0') self.assertEqual(self.node.id, result[0].node_id) self.assertItemsEqual(traits, [trait.trait for trait in result]) def test_set_node_traits_over_limit(self): traits = ['trait%d' % n for n in range(51)] self.assertRaises(exception.InvalidParameterValue, - self.dbapi.set_node_traits, self.node.id, traits) + self.dbapi.set_node_traits, self.node.id, traits, + '1.0') # Ensure the traits were not set. result = self.dbapi.get_node_traits_by_node_id(self.node.id) self.assertEqual([], result) @@ -56,10 +59,11 @@ class DbNodeTraitTestCase(base.DbTestCase): def test_set_node_traits_node_not_exist(self): self.assertRaises(exception.NodeNotFound, self.dbapi.set_node_traits, '1234', - ['trait1', 'trait2']) + ['trait1', 'trait2'], '1.0') def test_get_node_traits_by_node_id(self): - self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) + db_utils.create_test_node_traits(node_id=self.node.id, + traits=['trait1', 'trait2']) result = self.dbapi.get_node_traits_by_node_id(self.node.id) self.assertEqual(self.node.id, result[0].node_id) self.assertItemsEqual(['trait1', 'trait2'], @@ -74,7 +78,8 @@ class DbNodeTraitTestCase(base.DbTestCase): self.dbapi.get_node_traits_by_node_id, '123') def test_unset_node_traits(self): - self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) + db_utils.create_test_node_traits(node_id=self.node.id, + traits=['trait1', 'trait2']) self.dbapi.unset_node_traits(self.node.id) result = self.dbapi.get_node_traits_by_node_id(self.node.id) self.assertEqual([], result) @@ -89,13 +94,13 @@ class DbNodeTraitTestCase(base.DbTestCase): self.dbapi.unset_node_traits, '123') def test_add_node_trait(self): - result = self.dbapi.add_node_trait(self.node.id, 'trait1') + result = self.dbapi.add_node_trait(self.node.id, 'trait1', '1.0') self.assertEqual(self.node.id, result.node_id) self.assertEqual('trait1', result.trait) def test_add_node_trait_duplicate(self): - self.dbapi.add_node_trait(self.node.id, 'trait1') - result = self.dbapi.add_node_trait(self.node.id, 'trait1') + self.dbapi.add_node_trait(self.node.id, 'trait1', '1.0') + result = self.dbapi.add_node_trait(self.node.id, 'trait1', '1.0') self.assertEqual(self.node.id, result.node_id) self.assertEqual('trait1', result.trait) result = self.dbapi.get_node_traits_by_node_id(self.node.id) @@ -103,36 +108,38 @@ class DbNodeTraitTestCase(base.DbTestCase): def test_add_node_trait_at_limit(self): traits = ['trait%d' % n for n in range(49)] - self.dbapi.set_node_traits(self.node.id, traits) + db_utils.create_test_node_traits(node_id=self.node.id, traits=traits) - result = self.dbapi.add_node_trait(self.node.id, 'trait49') + result = self.dbapi.add_node_trait(self.node.id, 'trait49', '1.0') self.assertEqual(self.node.id, result.node_id) self.assertEqual('trait49', result.trait) def test_add_node_trait_duplicate_at_limit(self): traits = ['trait%d' % n for n in range(50)] - self.dbapi.set_node_traits(self.node.id, traits) + db_utils.create_test_node_traits(node_id=self.node.id, traits=traits) - result = self.dbapi.add_node_trait(self.node.id, 'trait49') + result = self.dbapi.add_node_trait(self.node.id, 'trait49', '1.0') self.assertEqual(self.node.id, result.node_id) self.assertEqual('trait49', result.trait) def test_add_node_trait_over_limit(self): traits = ['trait%d' % n for n in range(50)] - self.dbapi.set_node_traits(self.node.id, traits) + db_utils.create_test_node_traits(node_id=self.node.id, traits=traits) self.assertRaises(exception.InvalidParameterValue, - self.dbapi.add_node_trait, self.node.id, 'trait50') + self.dbapi.add_node_trait, self.node.id, 'trait50', + '1.0') # Ensure the trait was not added. result = self.dbapi.get_node_traits_by_node_id(self.node.id) self.assertNotIn('trait50', [trait.trait for trait in result]) def test_add_node_trait_node_not_exist(self): self.assertRaises(exception.NodeNotFound, - self.dbapi.add_node_trait, '123', 'trait1') + self.dbapi.add_node_trait, '123', 'trait1', '1.0') def test_delete_node_trait(self): - self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) + db_utils.create_test_node_traits(node_id=self.node.id, + traits=['trait1', 'trait2']) self.dbapi.delete_node_trait(self.node.id, 'trait1') result = self.dbapi.get_node_traits_by_node_id(self.node.id) self.assertEqual(1, len(result)) @@ -147,7 +154,8 @@ class DbNodeTraitTestCase(base.DbTestCase): self.dbapi.delete_node_trait, '123', 'trait1') def test_node_trait_exists(self): - self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) + db_utils.create_test_node_traits(node_id=self.node.id, + traits=['trait1', 'trait2']) result = self.dbapi.node_trait_exists(self.node.id, 'trait1') self.assertTrue(result) diff --git a/ironic/tests/unit/db/test_nodes.py b/ironic/tests/unit/db/test_nodes.py index 01dea51aaa..0a010dc17e 100644 --- a/ironic/tests/unit/db/test_nodes.py +++ b/ironic/tests/unit/db/test_nodes.py @@ -66,7 +66,8 @@ class DbNodeTestCase(base.DbTestCase): def test_get_node_by_id(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) - self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) + utils.create_test_node_traits(node_id=node.id, + traits=['trait1', 'trait2']) res = self.dbapi.get_node_by_id(node.id) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) @@ -77,7 +78,8 @@ class DbNodeTestCase(base.DbTestCase): def test_get_node_by_uuid(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) - self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) + utils.create_test_node_traits(node_id=node.id, + traits=['trait1', 'trait2']) res = self.dbapi.get_node_by_uuid(node.uuid) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) @@ -88,7 +90,8 @@ class DbNodeTestCase(base.DbTestCase): def test_get_node_by_name(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) - self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) + utils.create_test_node_traits(node_id=node.id, + traits=['trait1', 'trait2']) res = self.dbapi.get_node_by_name(node.name) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) @@ -299,7 +302,8 @@ class DbNodeTestCase(base.DbTestCase): node = utils.create_test_node( instance_uuid='12345678-9999-0000-aaaa-123456789012') self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) - self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) + utils.create_test_node_traits(node_id=node.id, + traits=['trait1', 'trait2']) res = self.dbapi.get_node_by_instance(node.instance_uuid) self.assertEqual(node.uuid, res.uuid) @@ -528,7 +532,8 @@ class DbNodeTestCase(base.DbTestCase): def test_reserve_node(self): node = utils.create_test_node() self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) - self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) + utils.create_test_node_traits(node_id=node.id, + traits=['trait1', 'trait2']) uuid = node.uuid r1 = 'fake-reservation' diff --git a/ironic/tests/unit/db/utils.py b/ironic/tests/unit/db/utils.py index 946a4e3d65..71aa3d0987 100644 --- a/ironic/tests/unit/db/utils.py +++ b/ironic/tests/unit/db/utils.py @@ -513,4 +513,17 @@ def create_test_node_trait(**kw): """ trait = get_test_node_trait(**kw) dbapi = db_api.get_instance() - return dbapi.add_node_trait(trait['node_id'], trait['trait']) + return dbapi.add_node_trait(trait['node_id'], trait['trait'], + trait['version']) + + +def create_test_node_traits(traits, **kw): + """Create test node trait entries in DB and return NodeTrait DB objects. + + Function to be used to create test NodeTrait objects in the database. + + :param traits: a list of Strings; traits to create. + :param kw: kwargs with overriding values for trait's attributes. + :returns: a list of test NodeTrait DB objects. + """ + return [create_test_node_trait(trait=trait, **kw) for trait in traits]