diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index d3c85cbef6..65a3e3fceb 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -93,6 +93,15 @@ def add_port_filter(query, value): return add_identity_filter(query, value) +def add_port_filter_by_node(query, value): + if utils.is_int_like(value): + return query.filter_by(node_id=value) + else: + query = query.join(models.Node, + models.Port.node_id == models.Node.id) + return query.filter(models.Node.uuid == value) + + class Connection(api.Connection): """SqlAlchemy connection.""" @@ -221,10 +230,24 @@ class Connection(api.Connection): query = model_query(models.Node, session=session) query = add_identity_filter(query, node) + # Get node ID, if an UUID was supplied. The ID is + # required for deleting all ports, attached to the node. + if uuidutils.is_uuid_like(node): + try: + node_id = query.one()['id'] + except NoResultFound: + raise exception.NodeNotFound(node=node) + else: + node_id = node + count = query.delete() if count != 1: raise exception.NodeNotFound(node=node) + query = model_query(models.Port, session=session) + query = add_port_filter_by_node(query, node_id) + query.delete() + @objects.objectify(objects.Node) def update_node(self, node, values): session = get_session() @@ -256,19 +279,10 @@ class Connection(api.Connection): @objects.objectify(objects.Port) def get_ports_by_node(self, node): - session = get_session() + query = model_query(models.Port) + query = add_port_filter_by_node(query, node) - if utils.is_int_like(node): - query = session.query(models.Port).\ - filter_by(node_id=node) - else: - query = session.query(models.Port).\ - join(models.Node, - models.Port.node_id == models.Node.id).\ - filter(models.Node.uuid == node) - result = query.all() - - return result + return query.all() @objects.objectify(objects.Port) def create_port(self, values): diff --git a/ironic/tests/db/test_nodes.py b/ironic/tests/db/test_nodes.py index 8e259ebe85..348bc71a29 100644 --- a/ironic/tests/db/test_nodes.py +++ b/ironic/tests/db/test_nodes.py @@ -93,7 +93,41 @@ class DbNodeTestCase(base.DbTestCase): self.dbapi.destroy_node(n['id']) self.assertRaises(exception.NodeNotFound, - self.dbapi.destroy_node, n['id']) + self.dbapi.get_node, n['id']) + + def test_destroy_node_by_uuid(self): + n = self._create_test_node() + + self.dbapi.destroy_node(n['uuid']) + self.assertRaises(exception.NodeNotFound, + self.dbapi.get_node, n['uuid']) + + def test_destroy_node_that_does_not_exist(self): + self.assertRaises(exception.NodeNotFound, + self.dbapi.destroy_node, + '12345678-9999-0000-aaaa-123456789012') + + def test_ports_get_destroyed_after_destroying_a_node(self): + n = self._create_test_node() + node_id = n['id'] + + p = utils.get_test_port(node_id=node_id) + p = self.dbapi.create_port(p) + + self.dbapi.destroy_node(node_id) + + self.assertRaises(exception.PortNotFound, self.dbapi.get_port, p['id']) + + def test_ports_get_destroyed_after_destroying_a_node_by_uuid(self): + n = self._create_test_node() + node_id = n['id'] + + p = utils.get_test_port(node_id=node_id) + p = self.dbapi.create_port(p) + + self.dbapi.destroy_node(n['uuid']) + + self.assertRaises(exception.PortNotFound, self.dbapi.get_port, p['id']) def test_update_node(self): n = self._create_test_node()