Check if sort key is allowed in API version

This checks the sort key, to make sure the specified field
is allowed/available in the specified API microversion.
If it is not allowed, a 406 HTTP status is returned.

This affects requests to get lists of nodes, port groups,
and ports.

Change-Id: Id5fb44b8b7fe989514dbae4b60cef4a34d47e52b
Closes-Bug: #1659419
This commit is contained in:
Ruby Loo 2017-05-02 02:18:08 +00:00
parent 6f237ecfd9
commit db0c42a955
7 changed files with 209 additions and 24 deletions

View File

@ -1442,6 +1442,7 @@ class NodesController(rest.RestController):
api_utils.check_allow_specify_fields(fields)
api_utils.check_allowed_fields(fields)
api_utils.check_allowed_fields([sort_key])
api_utils.check_for_invalid_state_and_allow_filter(provision_state)
api_utils.check_allow_specify_driver(driver)
api_utils.check_allow_specify_resource_class(resource_class)
@ -1495,6 +1496,7 @@ class NodesController(rest.RestController):
api_utils.check_for_invalid_state_and_allow_filter(provision_state)
api_utils.check_allow_specify_driver(driver)
api_utils.check_allow_specify_resource_class(resource_class)
api_utils.check_allowed_fields([sort_key])
# /detail should only work against collections
parent = pecan.request.path.split('/')[:-1][-1]
if parent != "nodes":

View File

@ -354,6 +354,24 @@ class PortsController(rest.RestController):
except exception.PortNotFound:
return []
def _check_allowed_port_fields(self, fields):
"""Check if fetching a particular field of a port is allowed.
Check if the required version is being requested for fields
that are only allowed to be fetched in a particular API version.
:param fields: list or set of fields to check
:raises: NotAcceptable if a field is not allowed
"""
if fields is None:
return
if (not api_utils.allow_port_advanced_net_fields() and
set(fields).intersection(self.advanced_net_fields)):
raise exception.NotAcceptable()
if ('portgroup_uuid' in fields and not
api_utils.allow_portgroups_subcontrollers()):
raise exception.NotAcceptable()
@METRICS.timer('PortsController.get_all')
@expose.expose(PortCollection, types.uuid_or_name, types.uuid,
types.macaddress, types.uuid, int, wtypes.text,
@ -389,14 +407,8 @@ class PortsController(rest.RestController):
policy.authorize('baremetal:port:get', cdict, cdict)
api_utils.check_allow_specify_fields(fields)
if fields:
if (not api_utils.allow_port_advanced_net_fields() and
set(fields).intersection(self.advanced_net_fields)):
raise exception.NotAcceptable()
if ('portgroup_uuid' in fields and not
api_utils.allow_portgroups_subcontrollers()):
raise exception.NotAcceptable()
self._check_allowed_port_fields(fields)
self._check_allowed_port_fields([sort_key])
if portgroup and not api_utils.allow_portgroups_subcontrollers():
raise exception.NotAcceptable()
@ -446,6 +458,7 @@ class PortsController(rest.RestController):
cdict = pecan.request.context.to_policy_values()
policy.authorize('baremetal:port:get', cdict, cdict)
self._check_allowed_port_fields([sort_key])
if portgroup and not api_utils.allow_portgroups_subcontrollers():
raise exception.NotAcceptable()
@ -504,12 +517,7 @@ class PortsController(rest.RestController):
raise exception.OperationNotPermitted()
pdict = port.as_dict()
if (not api_utils.allow_port_advanced_net_fields() and
set(pdict).intersection(self.advanced_net_fields)):
raise exception.NotAcceptable()
if (not api_utils.allow_portgroups_subcontrollers() and
'portgroup_uuid' in pdict):
raise exception.NotAcceptable()
self._check_allowed_port_fields(pdict)
extra = pdict.get('extra')
vif = extra.get('vif_port_id') if extra else None
@ -569,12 +577,7 @@ class PortsController(rest.RestController):
if (api_utils.get_patch_values(patch, field_path) or
api_utils.is_path_removed(patch, field_path)):
fields_to_check.add(field)
if (fields_to_check.intersection(self.advanced_net_fields) and
not api_utils.allow_port_advanced_net_fields()):
raise exception.NotAcceptable()
if ('portgroup_uuid' in fields_to_check and
not api_utils.allow_portgroups_subcontrollers()):
raise exception.NotAcceptable()
self._check_allowed_port_fields(fields_to_check)
rpc_port = objects.Port.get_by_uuid(context, port_uuid)
try:

View File

@ -356,6 +356,7 @@ class PortgroupsController(pecan.rest.RestController):
policy.authorize('baremetal:portgroup:get', cdict, cdict)
api_utils.check_allowed_portgroup_fields(fields)
api_utils.check_allowed_portgroup_fields([sort_key])
if fields is None:
fields = _DEFAULT_RETURN_FIELDS
@ -389,6 +390,7 @@ class PortgroupsController(pecan.rest.RestController):
cdict = pecan.request.context.to_policy_values()
policy.authorize('baremetal:portgroup:get', cdict, cdict)
api_utils.check_allowed_portgroup_fields([sort_key])
# NOTE: /detail should only work against collections
parent = pecan.request.path.split('/')[:-1][-1]

View File

@ -1,5 +1,3 @@
# -*- encoding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
@ -504,6 +502,41 @@ class TestListNodes(test_api_base.BaseApiTest):
self.assertEqual('application/json', response.content_type)
self.assertIn(invalid_key, response.json['error_message'])
def _test_sort_key_allowed(self, detail=False):
node_uuids = []
for id in range(3, 0, -1):
node = obj_utils.create_test_node(self.context,
uuid=uuidutils.generate_uuid(),
resource_class='rc_%s' % id)
node_uuids.append(node.uuid)
node_uuids.reverse()
headers = {'X-OpenStack-Ironic-API-Version': '1.21'}
detail_str = '/detail' if detail else ''
data = self.get_json('/nodes%s?sort_key=resource_class' % detail_str,
headers=headers)
data_uuids = [n['uuid'] for n in data['nodes']]
self.assertEqual(node_uuids, data_uuids)
def test_sort_key_allowed(self):
self._test_sort_key_allowed()
def test_detail_sort_key_allowed(self):
self._test_sort_key_allowed(detail=True)
def _test_sort_key_not_allowed(self, detail=False):
headers = {'X-OpenStack-Ironic-API-Version': '1.20'}
detail_str = '/detail' if detail else ''
resp = self.get_json('/nodes%s?sort_key=resource_class' % detail_str,
headers=headers, expect_errors=True)
self.assertEqual(http_client.NOT_ACCEPTABLE, resp.status_int)
self.assertEqual('application/json', resp.content_type)
def test_sort_key_not_allowed(self):
self._test_sort_key_not_allowed()
def test_detail_sort_key_not_allowed(self):
self._test_sort_key_not_allowed(detail=True)
def test_ports_subresource_link(self):
node = obj_utils.create_test_node(self.context)
data = self.get_json('/nodes/%s' % node.uuid)

View File

@ -391,6 +391,44 @@ class TestListPortgroups(test_api_base.BaseApiTest):
self.assertEqual('application/json', response.content_type)
self.assertIn(invalid_key, response.json['error_message'])
def _test_sort_key_allowed(self, detail=False):
portgroup_uuids = []
for id_ in range(3, 0, -1):
portgroup = obj_utils.create_test_portgroup(
self.context,
node_id=self.node.id,
uuid=uuidutils.generate_uuid(),
name='portgroup%s' % id_,
address='52:54:00:cf:2d:3%s' % id_,
mode='mode_%s' % id_)
portgroup_uuids.append(portgroup.uuid)
portgroup_uuids.reverse()
detail_str = '/detail' if detail else ''
data = self.get_json('/portgroups%s?sort_key=mode' % detail_str,
headers=self.headers)
data_uuids = [p['uuid'] for p in data['portgroups']]
self.assertEqual(portgroup_uuids, data_uuids)
def test_sort_key_allowed(self):
self._test_sort_key_allowed()
def test_detail_sort_key_allowed(self):
self._test_sort_key_allowed(detail=True)
def _test_sort_key_not_allowed(self, detail=False):
headers = {api_base.Version.string: '1.25'}
detail_str = '/detail' if detail else ''
response = self.get_json('/portgroups%s?sort_key=mode' % detail_str,
headers=headers, expect_errors=True)
self.assertEqual(http_client.NOT_ACCEPTABLE, response.status_int)
self.assertEqual('application/json', response.content_type)
def test_sort_key_not_allowed(self):
self._test_sort_key_not_allowed()
def test_detail_sort_key_not_allowed(self):
self._test_sort_key_not_allowed(detail=True)
@mock.patch.object(api_utils, 'get_rpc_node')
def test_get_all_by_node_name_ok(self, mock_get_rpc_node):
# GET /v1/portgroups specifying node_name - success

View File

@ -1,5 +1,3 @@
# -*- encoding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
@ -68,6 +66,72 @@ class TestPortObject(base.TestCase):
self.assertEqual(wtypes.Unset, port.extra)
@mock.patch.object(api_utils, 'allow_portgroups_subcontrollers', autospec=True)
@mock.patch.object(api_utils, 'allow_port_advanced_net_fields', autospec=True)
class TestPortsController__CheckAllowedPortFields(base.TestCase):
def setUp(self):
super(TestPortsController__CheckAllowedPortFields, self).setUp()
self.controller = api_port.PortsController()
def test__check_allowed_port_fields_none(self, mock_allow_port,
mock_allow_portgroup):
self.assertIsNone(
self.controller._check_allowed_port_fields(None))
self.assertFalse(mock_allow_port.called)
self.assertFalse(mock_allow_portgroup.called)
def test__check_allowed_port_fields_empty(self, mock_allow_port,
mock_allow_portgroup):
for v in (True, False):
mock_allow_port.return_value = v
self.assertIsNone(
self.controller._check_allowed_port_fields([]))
mock_allow_port.assert_called_once_with()
mock_allow_port.reset_mock()
self.assertFalse(mock_allow_portgroup.called)
def test__check_allowed_port_fields_not_allow(self, mock_allow_port,
mock_allow_portgroup):
mock_allow_port.return_value = False
for field in api_port.PortsController.advanced_net_fields:
self.assertRaises(exception.NotAcceptable,
self.controller._check_allowed_port_fields,
[field])
mock_allow_port.assert_called_once_with()
mock_allow_port.reset_mock()
self.assertFalse(mock_allow_portgroup.called)
def test__check_allowed_port_fields_allow(self, mock_allow_port,
mock_allow_portgroup):
mock_allow_port.return_value = True
for field in api_port.PortsController.advanced_net_fields:
self.assertIsNone(
self.controller._check_allowed_port_fields([field]))
mock_allow_port.assert_called_once_with()
mock_allow_port.reset_mock()
self.assertFalse(mock_allow_portgroup.called)
def test__check_allowed_port_fields_portgroup_not_allow(
self, mock_allow_port, mock_allow_portgroup):
mock_allow_port.return_value = True
mock_allow_portgroup.return_value = False
self.assertRaises(exception.NotAcceptable,
self.controller._check_allowed_port_fields,
['portgroup_uuid'])
mock_allow_port.assert_called_once_with()
mock_allow_portgroup.assert_called_once_with()
def test__check_allowed_port_fields_portgroup_allow(
self, mock_allow_port, mock_allow_portgroup):
mock_allow_port.return_value = True
mock_allow_portgroup.return_value = True
self.assertIsNone(
self.controller._check_allowed_port_fields(['portgroup_uuid']))
mock_allow_port.assert_called_once_with()
mock_allow_portgroup.assert_called_once_with()
class TestListPorts(test_api_base.BaseApiTest):
def setUp(self):
@ -325,6 +389,43 @@ class TestListPorts(test_api_base.BaseApiTest):
self.assertEqual('application/json', response.content_type)
self.assertIn(invalid_key, response.json['error_message'])
def _test_sort_key_allowed(self, detail=False):
port_uuids = []
for id_ in range(2):
port = obj_utils.create_test_port(
self.context,
node_id=self.node.id,
uuid=uuidutils.generate_uuid(),
address='52:54:00:cf:2d:3%s' % id_,
pxe_enabled=id_ % 2)
port_uuids.append(port.uuid)
headers = {api_base.Version.string: str(api_v1.MAX_VER)}
detail_str = '/detail' if detail else ''
data = self.get_json('/ports%s?sort_key=pxe_enabled' % detail_str,
headers=headers)
data_uuids = [p['uuid'] for p in data['ports']]
self.assertEqual(port_uuids, data_uuids)
def test_sort_key_allowed(self):
self._test_sort_key_allowed()
def test_detail_sort_key_allowed(self):
self._test_sort_key_allowed(detail=True)
def _test_sort_key_not_allowed(self, detail=False):
headers = {api_base.Version.string: '1.18'}
detail_str = '/detail' if detail else ''
resp = self.get_json('/ports%s?sort_key=pxe_enabled' % detail_str,
headers=headers, expect_errors=True)
self.assertEqual(http_client.NOT_ACCEPTABLE, resp.status_int)
self.assertEqual('application/json', resp.content_type)
def test_sort_key_not_allowed(self):
self._test_sort_key_not_allowed()
def test_detail_sort_key_not_allowed(self):
self._test_sort_key_not_allowed(detail=True)
@mock.patch.object(api_utils, 'get_rpc_node')
def test_get_all_by_node_name_ok(self, mock_get_rpc_node):
# GET /v1/ports specifying node_name - success

View File

@ -0,0 +1,6 @@
---
fixes:
- |
When returning lists of nodes, port groups, or ports, checks the sort key
to make sure the field is available in the requested API version. A 406
(Not Acceptable) HTTP status is returned if the field is not available.