Add a version argument to traits DB API

We need to save the Trait object version to the DB when creating traits.

Change-Id: I2c43c27455de6e7017477b0f12b18873c66455ad
Partial-Bug: #1722194
This commit is contained in:
Mark Goddard 2018-01-18 17:02:50 +00:00
parent 58ebae91ba
commit c3ed7dfb9e
5 changed files with 62 additions and 32 deletions

@ -924,13 +924,14 @@ class Connection(object):
# TODO(rloo) Delete this in Rocky cycle. # TODO(rloo) Delete this in Rocky cycle.
@abc.abstractmethod @abc.abstractmethod
def set_node_traits(self, node_id, traits): def set_node_traits(self, node_id, traits, version):
"""Replace all of the node traits with specified list of traits. """Replace all of the node traits with specified list of traits.
This ignores duplicate traits in the specified list. This ignores duplicate traits in the specified list.
:param node_id: The id of a node. :param node_id: The id of a node.
:param traits: List of traits. :param traits: List of traits.
:param version: the version of the object.Trait.
:returns: A list of NodeTrait objects. :returns: A list of NodeTrait objects.
:raises: InvalidParameterValue if setting the traits would exceed the :raises: InvalidParameterValue if setting the traits would exceed the
per-node traits limit. per-node traits limit.
@ -955,7 +956,7 @@ class Connection(object):
""" """
@abc.abstractmethod @abc.abstractmethod
def add_node_trait(self, node_id, trait): def add_node_trait(self, node_id, trait, version):
"""Add trait to the node. """Add trait to the node.
If the node_id and trait pair already exists, this should still If the node_id and trait pair already exists, this should still
@ -963,6 +964,7 @@ class Connection(object):
:param node_id: The id of a node. :param node_id: The id of a node.
:param trait: A trait string. :param trait: A trait string.
:param version: the version of the object.Trait.
:returns: the NodeTrait object. :returns: the NodeTrait object.
:raises: InvalidParameterValue if adding the trait would exceed the :raises: InvalidParameterValue if adding the trait would exceed the
per-node traits limit. per-node traits limit.

@ -1306,7 +1306,7 @@ class Connection(api.Connection):
max_traits=MAX_TRAITS_PER_NODE) max_traits=MAX_TRAITS_PER_NODE)
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
def set_node_traits(self, node_id, traits): def set_node_traits(self, node_id, traits, version):
# Remove duplicate traits # Remove duplicate traits
traits = set(traits) traits = set(traits)
@ -1317,7 +1317,8 @@ class Connection(api.Connection):
self.unset_node_traits(node_id) self.unset_node_traits(node_id)
node_traits = [] node_traits = []
for trait in traits: for trait in traits:
node_trait = models.NodeTrait(trait=trait, node_id=node_id) node_trait = models.NodeTrait(trait=trait, node_id=node_id,
version=version)
session.add(node_trait) session.add(node_trait)
node_traits.append(node_trait) node_traits.append(node_trait)
@ -1337,8 +1338,9 @@ class Connection(api.Connection):
return result return result
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
def add_node_trait(self, node_id, trait): def add_node_trait(self, node_id, trait, version):
node_trait = models.NodeTrait(trait=trait, node_id=node_id) node_trait = models.NodeTrait(trait=trait, node_id=node_id,
version=version)
self._check_node_exists(node_id) self._check_node_exists(node_id)
try: try:

@ -24,31 +24,34 @@ class DbNodeTraitTestCase(base.DbTestCase):
self.node = db_utils.create_test_node() self.node = db_utils.create_test_node()
def test_set_node_traits(self): def test_set_node_traits(self):
result = self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) result = self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2'],
'1.0')
self.assertEqual(self.node.id, result[0].node_id) self.assertEqual(self.node.id, result[0].node_id)
self.assertItemsEqual(['trait1', 'trait2'], self.assertItemsEqual(['trait1', 'trait2'],
[trait.trait for trait in result]) [trait.trait for trait in result])
result = self.dbapi.set_node_traits(self.node.id, []) result = self.dbapi.set_node_traits(self.node.id, [], '1.0')
self.assertEqual([], result) self.assertEqual([], result)
def test_set_node_traits_duplicate(self): def test_set_node_traits_duplicate(self):
result = self.dbapi.set_node_traits(self.node.id, result = self.dbapi.set_node_traits(self.node.id,
['trait1', 'trait2', 'trait2']) ['trait1', 'trait2', 'trait2'],
'1.0')
self.assertEqual(self.node.id, result[0].node_id) self.assertEqual(self.node.id, result[0].node_id)
self.assertItemsEqual(['trait1', 'trait2'], self.assertItemsEqual(['trait1', 'trait2'],
[trait.trait for trait in result]) [trait.trait for trait in result])
def test_set_node_traits_at_limit(self): def test_set_node_traits_at_limit(self):
traits = ['trait%d' % n for n in range(50)] traits = ['trait%d' % n for n in range(50)]
result = self.dbapi.set_node_traits(self.node.id, traits) result = self.dbapi.set_node_traits(self.node.id, traits, '1.0')
self.assertEqual(self.node.id, result[0].node_id) self.assertEqual(self.node.id, result[0].node_id)
self.assertItemsEqual(traits, [trait.trait for trait in result]) self.assertItemsEqual(traits, [trait.trait for trait in result])
def test_set_node_traits_over_limit(self): def test_set_node_traits_over_limit(self):
traits = ['trait%d' % n for n in range(51)] traits = ['trait%d' % n for n in range(51)]
self.assertRaises(exception.InvalidParameterValue, self.assertRaises(exception.InvalidParameterValue,
self.dbapi.set_node_traits, self.node.id, traits) self.dbapi.set_node_traits, self.node.id, traits,
'1.0')
# Ensure the traits were not set. # Ensure the traits were not set.
result = self.dbapi.get_node_traits_by_node_id(self.node.id) result = self.dbapi.get_node_traits_by_node_id(self.node.id)
self.assertEqual([], result) self.assertEqual([], result)
@ -56,10 +59,11 @@ class DbNodeTraitTestCase(base.DbTestCase):
def test_set_node_traits_node_not_exist(self): def test_set_node_traits_node_not_exist(self):
self.assertRaises(exception.NodeNotFound, self.assertRaises(exception.NodeNotFound,
self.dbapi.set_node_traits, '1234', self.dbapi.set_node_traits, '1234',
['trait1', 'trait2']) ['trait1', 'trait2'], '1.0')
def test_get_node_traits_by_node_id(self): def test_get_node_traits_by_node_id(self):
self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) db_utils.create_test_node_traits(node_id=self.node.id,
traits=['trait1', 'trait2'])
result = self.dbapi.get_node_traits_by_node_id(self.node.id) result = self.dbapi.get_node_traits_by_node_id(self.node.id)
self.assertEqual(self.node.id, result[0].node_id) self.assertEqual(self.node.id, result[0].node_id)
self.assertItemsEqual(['trait1', 'trait2'], self.assertItemsEqual(['trait1', 'trait2'],
@ -74,7 +78,8 @@ class DbNodeTraitTestCase(base.DbTestCase):
self.dbapi.get_node_traits_by_node_id, '123') self.dbapi.get_node_traits_by_node_id, '123')
def test_unset_node_traits(self): def test_unset_node_traits(self):
self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) db_utils.create_test_node_traits(node_id=self.node.id,
traits=['trait1', 'trait2'])
self.dbapi.unset_node_traits(self.node.id) self.dbapi.unset_node_traits(self.node.id)
result = self.dbapi.get_node_traits_by_node_id(self.node.id) result = self.dbapi.get_node_traits_by_node_id(self.node.id)
self.assertEqual([], result) self.assertEqual([], result)
@ -89,13 +94,13 @@ class DbNodeTraitTestCase(base.DbTestCase):
self.dbapi.unset_node_traits, '123') self.dbapi.unset_node_traits, '123')
def test_add_node_trait(self): def test_add_node_trait(self):
result = self.dbapi.add_node_trait(self.node.id, 'trait1') result = self.dbapi.add_node_trait(self.node.id, 'trait1', '1.0')
self.assertEqual(self.node.id, result.node_id) self.assertEqual(self.node.id, result.node_id)
self.assertEqual('trait1', result.trait) self.assertEqual('trait1', result.trait)
def test_add_node_trait_duplicate(self): def test_add_node_trait_duplicate(self):
self.dbapi.add_node_trait(self.node.id, 'trait1') self.dbapi.add_node_trait(self.node.id, 'trait1', '1.0')
result = self.dbapi.add_node_trait(self.node.id, 'trait1') result = self.dbapi.add_node_trait(self.node.id, 'trait1', '1.0')
self.assertEqual(self.node.id, result.node_id) self.assertEqual(self.node.id, result.node_id)
self.assertEqual('trait1', result.trait) self.assertEqual('trait1', result.trait)
result = self.dbapi.get_node_traits_by_node_id(self.node.id) result = self.dbapi.get_node_traits_by_node_id(self.node.id)
@ -103,36 +108,38 @@ class DbNodeTraitTestCase(base.DbTestCase):
def test_add_node_trait_at_limit(self): def test_add_node_trait_at_limit(self):
traits = ['trait%d' % n for n in range(49)] traits = ['trait%d' % n for n in range(49)]
self.dbapi.set_node_traits(self.node.id, traits) db_utils.create_test_node_traits(node_id=self.node.id, traits=traits)
result = self.dbapi.add_node_trait(self.node.id, 'trait49') result = self.dbapi.add_node_trait(self.node.id, 'trait49', '1.0')
self.assertEqual(self.node.id, result.node_id) self.assertEqual(self.node.id, result.node_id)
self.assertEqual('trait49', result.trait) self.assertEqual('trait49', result.trait)
def test_add_node_trait_duplicate_at_limit(self): def test_add_node_trait_duplicate_at_limit(self):
traits = ['trait%d' % n for n in range(50)] traits = ['trait%d' % n for n in range(50)]
self.dbapi.set_node_traits(self.node.id, traits) db_utils.create_test_node_traits(node_id=self.node.id, traits=traits)
result = self.dbapi.add_node_trait(self.node.id, 'trait49') result = self.dbapi.add_node_trait(self.node.id, 'trait49', '1.0')
self.assertEqual(self.node.id, result.node_id) self.assertEqual(self.node.id, result.node_id)
self.assertEqual('trait49', result.trait) self.assertEqual('trait49', result.trait)
def test_add_node_trait_over_limit(self): def test_add_node_trait_over_limit(self):
traits = ['trait%d' % n for n in range(50)] traits = ['trait%d' % n for n in range(50)]
self.dbapi.set_node_traits(self.node.id, traits) db_utils.create_test_node_traits(node_id=self.node.id, traits=traits)
self.assertRaises(exception.InvalidParameterValue, self.assertRaises(exception.InvalidParameterValue,
self.dbapi.add_node_trait, self.node.id, 'trait50') self.dbapi.add_node_trait, self.node.id, 'trait50',
'1.0')
# Ensure the trait was not added. # Ensure the trait was not added.
result = self.dbapi.get_node_traits_by_node_id(self.node.id) result = self.dbapi.get_node_traits_by_node_id(self.node.id)
self.assertNotIn('trait50', [trait.trait for trait in result]) self.assertNotIn('trait50', [trait.trait for trait in result])
def test_add_node_trait_node_not_exist(self): def test_add_node_trait_node_not_exist(self):
self.assertRaises(exception.NodeNotFound, self.assertRaises(exception.NodeNotFound,
self.dbapi.add_node_trait, '123', 'trait1') self.dbapi.add_node_trait, '123', 'trait1', '1.0')
def test_delete_node_trait(self): def test_delete_node_trait(self):
self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) db_utils.create_test_node_traits(node_id=self.node.id,
traits=['trait1', 'trait2'])
self.dbapi.delete_node_trait(self.node.id, 'trait1') self.dbapi.delete_node_trait(self.node.id, 'trait1')
result = self.dbapi.get_node_traits_by_node_id(self.node.id) result = self.dbapi.get_node_traits_by_node_id(self.node.id)
self.assertEqual(1, len(result)) self.assertEqual(1, len(result))
@ -147,7 +154,8 @@ class DbNodeTraitTestCase(base.DbTestCase):
self.dbapi.delete_node_trait, '123', 'trait1') self.dbapi.delete_node_trait, '123', 'trait1')
def test_node_trait_exists(self): def test_node_trait_exists(self):
self.dbapi.set_node_traits(self.node.id, ['trait1', 'trait2']) db_utils.create_test_node_traits(node_id=self.node.id,
traits=['trait1', 'trait2'])
result = self.dbapi.node_trait_exists(self.node.id, 'trait1') result = self.dbapi.node_trait_exists(self.node.id, 'trait1')
self.assertTrue(result) self.assertTrue(result)

@ -66,7 +66,8 @@ class DbNodeTestCase(base.DbTestCase):
def test_get_node_by_id(self): def test_get_node_by_id(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) utils.create_test_node_traits(node_id=node.id,
traits=['trait1', 'trait2'])
res = self.dbapi.get_node_by_id(node.id) res = self.dbapi.get_node_by_id(node.id)
self.assertEqual(node.id, res.id) self.assertEqual(node.id, res.id)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
@ -77,7 +78,8 @@ class DbNodeTestCase(base.DbTestCase):
def test_get_node_by_uuid(self): def test_get_node_by_uuid(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) utils.create_test_node_traits(node_id=node.id,
traits=['trait1', 'trait2'])
res = self.dbapi.get_node_by_uuid(node.uuid) res = self.dbapi.get_node_by_uuid(node.uuid)
self.assertEqual(node.id, res.id) self.assertEqual(node.id, res.id)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
@ -88,7 +90,8 @@ class DbNodeTestCase(base.DbTestCase):
def test_get_node_by_name(self): def test_get_node_by_name(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) utils.create_test_node_traits(node_id=node.id,
traits=['trait1', 'trait2'])
res = self.dbapi.get_node_by_name(node.name) res = self.dbapi.get_node_by_name(node.name)
self.assertEqual(node.id, res.id) self.assertEqual(node.id, res.id)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
@ -299,7 +302,8 @@ class DbNodeTestCase(base.DbTestCase):
node = utils.create_test_node( node = utils.create_test_node(
instance_uuid='12345678-9999-0000-aaaa-123456789012') instance_uuid='12345678-9999-0000-aaaa-123456789012')
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) utils.create_test_node_traits(node_id=node.id,
traits=['trait1', 'trait2'])
res = self.dbapi.get_node_by_instance(node.instance_uuid) res = self.dbapi.get_node_by_instance(node.instance_uuid)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
@ -528,7 +532,8 @@ class DbNodeTestCase(base.DbTestCase):
def test_reserve_node(self): def test_reserve_node(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
self.dbapi.set_node_traits(node.id, ['trait1', 'trait2']) utils.create_test_node_traits(node_id=node.id,
traits=['trait1', 'trait2'])
uuid = node.uuid uuid = node.uuid
r1 = 'fake-reservation' r1 = 'fake-reservation'

@ -513,4 +513,17 @@ def create_test_node_trait(**kw):
""" """
trait = get_test_node_trait(**kw) trait = get_test_node_trait(**kw)
dbapi = db_api.get_instance() dbapi = db_api.get_instance()
return dbapi.add_node_trait(trait['node_id'], trait['trait']) return dbapi.add_node_trait(trait['node_id'], trait['trait'],
trait['version'])
def create_test_node_traits(traits, **kw):
"""Create test node trait entries in DB and return NodeTrait DB objects.
Function to be used to create test NodeTrait objects in the database.
:param traits: a list of Strings; traits to create.
:param kw: kwargs with overriding values for trait's attributes.
:returns: a list of test NodeTrait DB objects.
"""
return [create_test_node_trait(trait=trait, **kw) for trait in traits]