diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index 5e25c0d3c0..2d00fc5d1f 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -29,6 +29,7 @@ from oslo_utils import strutils from oslo_utils import timeutils from oslo_utils import uuidutils from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm import joinedload from sqlalchemy import sql from ironic.common import exception @@ -63,6 +64,10 @@ def _session_for_write(): return enginefacade.writer.using(_CONTEXT) +def _get_node_query_with_tags(): + return model_query(models.Node).options(joinedload('tags')) + + def model_query(model, *args, **kwargs): """Query helper for simpler session usage. @@ -239,14 +244,14 @@ class Connection(api.Connection): def get_node_list(self, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Node) + query = _get_node_query_with_tags() query = self._add_nodes_filters(query, filters) return _paginate_query(models.Node, limit, marker, sort_key, sort_dir, query) def reserve_node(self, tag, node_id): with _session_for_write(): - query = model_query(models.Node) + query = _get_node_query_with_tags() query = add_identity_filter(query, node_id) # be optimistic and assume we usually create a reservation count = query.filter_by(reservation=None).update( @@ -311,24 +316,29 @@ class Connection(api.Connection): instance_uuid=values['instance_uuid'], node=values['uuid']) raise exception.NodeAlreadyExists(uuid=values['uuid']) + # Set tags to [] for new created node + node['tags'] = [] return node def get_node_by_id(self, node_id): - query = model_query(models.Node).filter_by(id=node_id) + query = _get_node_query_with_tags() + query = query.filter_by(id=node_id) try: return query.one() except NoResultFound: raise exception.NodeNotFound(node=node_id) def get_node_by_uuid(self, node_uuid): - query = model_query(models.Node).filter_by(uuid=node_uuid) + query = _get_node_query_with_tags() + query = query.filter_by(uuid=node_uuid) try: return query.one() except NoResultFound: raise exception.NodeNotFound(node=node_uuid) def get_node_by_name(self, node_name): - query = model_query(models.Node).filter_by(name=node_name) + query = _get_node_query_with_tags() + query = query.filter_by(name=node_name) try: return query.one() except NoResultFound: @@ -338,8 +348,8 @@ class Connection(api.Connection): if not uuidutils.is_uuid_like(instance): raise exception.InvalidUUID(uuid=instance) - query = (model_query(models.Node) - .filter_by(instance_uuid=instance)) + query = _get_node_query_with_tags() + query = query.filter_by(instance_uuid=instance) try: result = query.one() diff --git a/ironic/db/sqlalchemy/models.py b/ironic/db/sqlalchemy/models.py index 127233ffa2..9fd73e3b00 100644 --- a/ironic/db/sqlalchemy/models.py +++ b/ironic/db/sqlalchemy/models.py @@ -27,6 +27,7 @@ from sqlalchemy import Boolean, Column, DateTime, Index from sqlalchemy import ForeignKey, Integer from sqlalchemy import schema, String, Text from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import orm from ironic.common.i18n import _ from ironic.common import paths @@ -197,3 +198,10 @@ class NodeTag(Base): node_id = Column(Integer, ForeignKey('nodes.id'), primary_key=True, nullable=False) tag = Column(String(255), primary_key=True, nullable=False) + + node = orm.relationship( + "Node", + backref='tags', + primaryjoin='and_(NodeTag.node_id == Node.id)', + foreign_keys=node_id + ) diff --git a/ironic/tests/unit/db/test_nodes.py b/ironic/tests/unit/db/test_nodes.py index 3587024758..33522a1e4c 100644 --- a/ironic/tests/unit/db/test_nodes.py +++ b/ironic/tests/unit/db/test_nodes.py @@ -61,22 +61,28 @@ class DbNodeTestCase(base.DbTestCase): def test_get_node_by_id(self): node = utils.create_test_node() + self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) res = self.dbapi.get_node_by_id(node.id) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) + self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) def test_get_node_by_uuid(self): node = utils.create_test_node() + self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) res = self.dbapi.get_node_by_uuid(node.uuid) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) + self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) def test_get_node_by_name(self): node = utils.create_test_node() + self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) res = self.dbapi.get_node_by_name(node.name) self.assertEqual(node.id, res.id) self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.name, res.name) + self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) def test_get_node_that_does_not_exist(self): self.assertRaises(exception.NodeNotFound, @@ -217,6 +223,8 @@ class DbNodeTestCase(base.DbTestCase): res = self.dbapi.get_node_list() res_uuids = [r.uuid for r in res] six.assertCountEqual(self, uuids, res_uuids) + for r in res: + self.assertEqual([], r.tags) def test_get_node_list_with_filters(self): ch1 = utils.create_test_chassis(uuid=uuidutils.generate_uuid()) @@ -272,9 +280,11 @@ class DbNodeTestCase(base.DbTestCase): def test_get_node_by_instance(self): node = utils.create_test_node( instance_uuid='12345678-9999-0000-aaaa-123456789012') + self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) res = self.dbapi.get_node_by_instance(node.instance_uuid) self.assertEqual(node.uuid, res.uuid) + self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) def test_get_node_by_instance_wrong_uuid(self): utils.create_test_node( @@ -446,12 +456,14 @@ class DbNodeTestCase(base.DbTestCase): def test_reserve_node(self): node = utils.create_test_node() + self.dbapi.set_node_tags(node.id, ['tag1', 'tag2']) uuid = node.uuid r1 = 'fake-reservation' # reserve the node - self.dbapi.reserve_node(r1, uuid) + res = self.dbapi.reserve_node(r1, uuid) + self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags]) # check reservation res = self.dbapi.get_node_by_uuid(uuid)