From db0c42a95599955b27624429f16ea9f73860729e Mon Sep 17 00:00:00 2001 From: Ruby Loo Date: Tue, 2 May 2017 02:18:08 +0000 Subject: [PATCH] Check if sort key is allowed in API version This checks the sort key, to make sure the specified field is allowed/available in the specified API microversion. If it is not allowed, a 406 HTTP status is returned. This affects requests to get lists of nodes, port groups, and ports. Change-Id: Id5fb44b8b7fe989514dbae4b60cef4a34d47e52b Closes-Bug: #1659419 --- ironic/api/controllers/v1/node.py | 2 + ironic/api/controllers/v1/port.py | 43 +++---- ironic/api/controllers/v1/portgroup.py | 2 + ironic/tests/unit/api/v1/test_nodes.py | 37 +++++- ironic/tests/unit/api/v1/test_portgroups.py | 38 +++++++ ironic/tests/unit/api/v1/test_ports.py | 105 +++++++++++++++++- ...rt_key_allowed_field-091f8eeedd0a2ace.yaml | 6 + 7 files changed, 209 insertions(+), 24 deletions(-) create mode 100644 releasenotes/notes/sort_key_allowed_field-091f8eeedd0a2ace.yaml diff --git a/ironic/api/controllers/v1/node.py b/ironic/api/controllers/v1/node.py index d036f71caa..9e5929d1ad 100644 --- a/ironic/api/controllers/v1/node.py +++ b/ironic/api/controllers/v1/node.py @@ -1442,6 +1442,7 @@ class NodesController(rest.RestController): api_utils.check_allow_specify_fields(fields) api_utils.check_allowed_fields(fields) + api_utils.check_allowed_fields([sort_key]) api_utils.check_for_invalid_state_and_allow_filter(provision_state) api_utils.check_allow_specify_driver(driver) api_utils.check_allow_specify_resource_class(resource_class) @@ -1495,6 +1496,7 @@ class NodesController(rest.RestController): api_utils.check_for_invalid_state_and_allow_filter(provision_state) api_utils.check_allow_specify_driver(driver) api_utils.check_allow_specify_resource_class(resource_class) + api_utils.check_allowed_fields([sort_key]) # /detail should only work against collections parent = pecan.request.path.split('/')[:-1][-1] if parent != "nodes": diff --git a/ironic/api/controllers/v1/port.py b/ironic/api/controllers/v1/port.py index 0a25fb2f5c..4999878ba1 100644 --- a/ironic/api/controllers/v1/port.py +++ b/ironic/api/controllers/v1/port.py @@ -354,6 +354,24 @@ class PortsController(rest.RestController): except exception.PortNotFound: return [] + def _check_allowed_port_fields(self, fields): + """Check if fetching a particular field of a port is allowed. + + Check if the required version is being requested for fields + that are only allowed to be fetched in a particular API version. + + :param fields: list or set of fields to check + :raises: NotAcceptable if a field is not allowed + """ + if fields is None: + return + if (not api_utils.allow_port_advanced_net_fields() and + set(fields).intersection(self.advanced_net_fields)): + raise exception.NotAcceptable() + if ('portgroup_uuid' in fields and not + api_utils.allow_portgroups_subcontrollers()): + raise exception.NotAcceptable() + @METRICS.timer('PortsController.get_all') @expose.expose(PortCollection, types.uuid_or_name, types.uuid, types.macaddress, types.uuid, int, wtypes.text, @@ -389,14 +407,8 @@ class PortsController(rest.RestController): policy.authorize('baremetal:port:get', cdict, cdict) api_utils.check_allow_specify_fields(fields) - if fields: - if (not api_utils.allow_port_advanced_net_fields() and - set(fields).intersection(self.advanced_net_fields)): - raise exception.NotAcceptable() - if ('portgroup_uuid' in fields and not - api_utils.allow_portgroups_subcontrollers()): - raise exception.NotAcceptable() - + self._check_allowed_port_fields(fields) + self._check_allowed_port_fields([sort_key]) if portgroup and not api_utils.allow_portgroups_subcontrollers(): raise exception.NotAcceptable() @@ -446,6 +458,7 @@ class PortsController(rest.RestController): cdict = pecan.request.context.to_policy_values() policy.authorize('baremetal:port:get', cdict, cdict) + self._check_allowed_port_fields([sort_key]) if portgroup and not api_utils.allow_portgroups_subcontrollers(): raise exception.NotAcceptable() @@ -504,12 +517,7 @@ class PortsController(rest.RestController): raise exception.OperationNotPermitted() pdict = port.as_dict() - if (not api_utils.allow_port_advanced_net_fields() and - set(pdict).intersection(self.advanced_net_fields)): - raise exception.NotAcceptable() - if (not api_utils.allow_portgroups_subcontrollers() and - 'portgroup_uuid' in pdict): - raise exception.NotAcceptable() + self._check_allowed_port_fields(pdict) extra = pdict.get('extra') vif = extra.get('vif_port_id') if extra else None @@ -569,12 +577,7 @@ class PortsController(rest.RestController): if (api_utils.get_patch_values(patch, field_path) or api_utils.is_path_removed(patch, field_path)): fields_to_check.add(field) - if (fields_to_check.intersection(self.advanced_net_fields) and - not api_utils.allow_port_advanced_net_fields()): - raise exception.NotAcceptable() - if ('portgroup_uuid' in fields_to_check and - not api_utils.allow_portgroups_subcontrollers()): - raise exception.NotAcceptable() + self._check_allowed_port_fields(fields_to_check) rpc_port = objects.Port.get_by_uuid(context, port_uuid) try: diff --git a/ironic/api/controllers/v1/portgroup.py b/ironic/api/controllers/v1/portgroup.py index 1d884840d1..74c4e3b71a 100644 --- a/ironic/api/controllers/v1/portgroup.py +++ b/ironic/api/controllers/v1/portgroup.py @@ -356,6 +356,7 @@ class PortgroupsController(pecan.rest.RestController): policy.authorize('baremetal:portgroup:get', cdict, cdict) api_utils.check_allowed_portgroup_fields(fields) + api_utils.check_allowed_portgroup_fields([sort_key]) if fields is None: fields = _DEFAULT_RETURN_FIELDS @@ -389,6 +390,7 @@ class PortgroupsController(pecan.rest.RestController): cdict = pecan.request.context.to_policy_values() policy.authorize('baremetal:portgroup:get', cdict, cdict) + api_utils.check_allowed_portgroup_fields([sort_key]) # NOTE: /detail should only work against collections parent = pecan.request.path.split('/')[:-1][-1] diff --git a/ironic/tests/unit/api/v1/test_nodes.py b/ironic/tests/unit/api/v1/test_nodes.py index f9ce20c900..a061c37514 100644 --- a/ironic/tests/unit/api/v1/test_nodes.py +++ b/ironic/tests/unit/api/v1/test_nodes.py @@ -1,5 +1,3 @@ -# -*- encoding: utf-8 -*- -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -504,6 +502,41 @@ class TestListNodes(test_api_base.BaseApiTest): self.assertEqual('application/json', response.content_type) self.assertIn(invalid_key, response.json['error_message']) + def _test_sort_key_allowed(self, detail=False): + node_uuids = [] + for id in range(3, 0, -1): + node = obj_utils.create_test_node(self.context, + uuid=uuidutils.generate_uuid(), + resource_class='rc_%s' % id) + node_uuids.append(node.uuid) + node_uuids.reverse() + headers = {'X-OpenStack-Ironic-API-Version': '1.21'} + detail_str = '/detail' if detail else '' + data = self.get_json('/nodes%s?sort_key=resource_class' % detail_str, + headers=headers) + data_uuids = [n['uuid'] for n in data['nodes']] + self.assertEqual(node_uuids, data_uuids) + + def test_sort_key_allowed(self): + self._test_sort_key_allowed() + + def test_detail_sort_key_allowed(self): + self._test_sort_key_allowed(detail=True) + + def _test_sort_key_not_allowed(self, detail=False): + headers = {'X-OpenStack-Ironic-API-Version': '1.20'} + detail_str = '/detail' if detail else '' + resp = self.get_json('/nodes%s?sort_key=resource_class' % detail_str, + headers=headers, expect_errors=True) + self.assertEqual(http_client.NOT_ACCEPTABLE, resp.status_int) + self.assertEqual('application/json', resp.content_type) + + def test_sort_key_not_allowed(self): + self._test_sort_key_not_allowed() + + def test_detail_sort_key_not_allowed(self): + self._test_sort_key_not_allowed(detail=True) + def test_ports_subresource_link(self): node = obj_utils.create_test_node(self.context) data = self.get_json('/nodes/%s' % node.uuid) diff --git a/ironic/tests/unit/api/v1/test_portgroups.py b/ironic/tests/unit/api/v1/test_portgroups.py index d2901823bc..40120990a5 100644 --- a/ironic/tests/unit/api/v1/test_portgroups.py +++ b/ironic/tests/unit/api/v1/test_portgroups.py @@ -391,6 +391,44 @@ class TestListPortgroups(test_api_base.BaseApiTest): self.assertEqual('application/json', response.content_type) self.assertIn(invalid_key, response.json['error_message']) + def _test_sort_key_allowed(self, detail=False): + portgroup_uuids = [] + for id_ in range(3, 0, -1): + portgroup = obj_utils.create_test_portgroup( + self.context, + node_id=self.node.id, + uuid=uuidutils.generate_uuid(), + name='portgroup%s' % id_, + address='52:54:00:cf:2d:3%s' % id_, + mode='mode_%s' % id_) + portgroup_uuids.append(portgroup.uuid) + portgroup_uuids.reverse() + detail_str = '/detail' if detail else '' + data = self.get_json('/portgroups%s?sort_key=mode' % detail_str, + headers=self.headers) + data_uuids = [p['uuid'] for p in data['portgroups']] + self.assertEqual(portgroup_uuids, data_uuids) + + def test_sort_key_allowed(self): + self._test_sort_key_allowed() + + def test_detail_sort_key_allowed(self): + self._test_sort_key_allowed(detail=True) + + def _test_sort_key_not_allowed(self, detail=False): + headers = {api_base.Version.string: '1.25'} + detail_str = '/detail' if detail else '' + response = self.get_json('/portgroups%s?sort_key=mode' % detail_str, + headers=headers, expect_errors=True) + self.assertEqual(http_client.NOT_ACCEPTABLE, response.status_int) + self.assertEqual('application/json', response.content_type) + + def test_sort_key_not_allowed(self): + self._test_sort_key_not_allowed() + + def test_detail_sort_key_not_allowed(self): + self._test_sort_key_not_allowed(detail=True) + @mock.patch.object(api_utils, 'get_rpc_node') def test_get_all_by_node_name_ok(self, mock_get_rpc_node): # GET /v1/portgroups specifying node_name - success diff --git a/ironic/tests/unit/api/v1/test_ports.py b/ironic/tests/unit/api/v1/test_ports.py index 00a3c459e4..13d3a43dda 100644 --- a/ironic/tests/unit/api/v1/test_ports.py +++ b/ironic/tests/unit/api/v1/test_ports.py @@ -1,5 +1,3 @@ -# -*- encoding: utf-8 -*- -# # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -68,6 +66,72 @@ class TestPortObject(base.TestCase): self.assertEqual(wtypes.Unset, port.extra) +@mock.patch.object(api_utils, 'allow_portgroups_subcontrollers', autospec=True) +@mock.patch.object(api_utils, 'allow_port_advanced_net_fields', autospec=True) +class TestPortsController__CheckAllowedPortFields(base.TestCase): + + def setUp(self): + super(TestPortsController__CheckAllowedPortFields, self).setUp() + self.controller = api_port.PortsController() + + def test__check_allowed_port_fields_none(self, mock_allow_port, + mock_allow_portgroup): + self.assertIsNone( + self.controller._check_allowed_port_fields(None)) + self.assertFalse(mock_allow_port.called) + self.assertFalse(mock_allow_portgroup.called) + + def test__check_allowed_port_fields_empty(self, mock_allow_port, + mock_allow_portgroup): + for v in (True, False): + mock_allow_port.return_value = v + self.assertIsNone( + self.controller._check_allowed_port_fields([])) + mock_allow_port.assert_called_once_with() + mock_allow_port.reset_mock() + self.assertFalse(mock_allow_portgroup.called) + + def test__check_allowed_port_fields_not_allow(self, mock_allow_port, + mock_allow_portgroup): + mock_allow_port.return_value = False + for field in api_port.PortsController.advanced_net_fields: + self.assertRaises(exception.NotAcceptable, + self.controller._check_allowed_port_fields, + [field]) + mock_allow_port.assert_called_once_with() + mock_allow_port.reset_mock() + self.assertFalse(mock_allow_portgroup.called) + + def test__check_allowed_port_fields_allow(self, mock_allow_port, + mock_allow_portgroup): + mock_allow_port.return_value = True + for field in api_port.PortsController.advanced_net_fields: + self.assertIsNone( + self.controller._check_allowed_port_fields([field])) + mock_allow_port.assert_called_once_with() + mock_allow_port.reset_mock() + self.assertFalse(mock_allow_portgroup.called) + + def test__check_allowed_port_fields_portgroup_not_allow( + self, mock_allow_port, mock_allow_portgroup): + mock_allow_port.return_value = True + mock_allow_portgroup.return_value = False + self.assertRaises(exception.NotAcceptable, + self.controller._check_allowed_port_fields, + ['portgroup_uuid']) + mock_allow_port.assert_called_once_with() + mock_allow_portgroup.assert_called_once_with() + + def test__check_allowed_port_fields_portgroup_allow( + self, mock_allow_port, mock_allow_portgroup): + mock_allow_port.return_value = True + mock_allow_portgroup.return_value = True + self.assertIsNone( + self.controller._check_allowed_port_fields(['portgroup_uuid'])) + mock_allow_port.assert_called_once_with() + mock_allow_portgroup.assert_called_once_with() + + class TestListPorts(test_api_base.BaseApiTest): def setUp(self): @@ -325,6 +389,43 @@ class TestListPorts(test_api_base.BaseApiTest): self.assertEqual('application/json', response.content_type) self.assertIn(invalid_key, response.json['error_message']) + def _test_sort_key_allowed(self, detail=False): + port_uuids = [] + for id_ in range(2): + port = obj_utils.create_test_port( + self.context, + node_id=self.node.id, + uuid=uuidutils.generate_uuid(), + address='52:54:00:cf:2d:3%s' % id_, + pxe_enabled=id_ % 2) + port_uuids.append(port.uuid) + headers = {api_base.Version.string: str(api_v1.MAX_VER)} + detail_str = '/detail' if detail else '' + data = self.get_json('/ports%s?sort_key=pxe_enabled' % detail_str, + headers=headers) + data_uuids = [p['uuid'] for p in data['ports']] + self.assertEqual(port_uuids, data_uuids) + + def test_sort_key_allowed(self): + self._test_sort_key_allowed() + + def test_detail_sort_key_allowed(self): + self._test_sort_key_allowed(detail=True) + + def _test_sort_key_not_allowed(self, detail=False): + headers = {api_base.Version.string: '1.18'} + detail_str = '/detail' if detail else '' + resp = self.get_json('/ports%s?sort_key=pxe_enabled' % detail_str, + headers=headers, expect_errors=True) + self.assertEqual(http_client.NOT_ACCEPTABLE, resp.status_int) + self.assertEqual('application/json', resp.content_type) + + def test_sort_key_not_allowed(self): + self._test_sort_key_not_allowed() + + def test_detail_sort_key_not_allowed(self): + self._test_sort_key_not_allowed(detail=True) + @mock.patch.object(api_utils, 'get_rpc_node') def test_get_all_by_node_name_ok(self, mock_get_rpc_node): # GET /v1/ports specifying node_name - success diff --git a/releasenotes/notes/sort_key_allowed_field-091f8eeedd0a2ace.yaml b/releasenotes/notes/sort_key_allowed_field-091f8eeedd0a2ace.yaml new file mode 100644 index 0000000000..fc7e0c25f1 --- /dev/null +++ b/releasenotes/notes/sort_key_allowed_field-091f8eeedd0a2ace.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + When returning lists of nodes, port groups, or ports, checks the sort key + to make sure the field is available in the requested API version. A 406 + (Not Acceptable) HTTP status is returned if the field is not available.