Pass context objects directly to policy enforcement

The oslo.policy Enforcer() object knows what to do with instances of
oslo.context RequestContext() if you pass it one.

This makes it easier for people to perform policy enforcement since they
don't need to map important authorization information from the context
object into a dictionary (historically called `creds`). This practiced
didn't guarantee any consistency in `creds` implementations.

You also don't need to call context.to_policy_values() anymore. The
oslo.policy library will do that for you under the hood and map context
values into a set of policy attributes it understands.

This commit updates the calls to enforcement to pass in the context
object where applicable.

Change-Id: Ife4ba098303088023e4341354a1e3bc9f378ce93
This commit is contained in:
Lance Bragstad 2021-01-22 19:20:42 +00:00
parent 5640860c81
commit 72044aaa85
2 changed files with 52 additions and 27 deletions

View File

@ -1458,8 +1458,11 @@ def check_policy(policy_name):
:policy_name: Name of the policy to check. :policy_name: Name of the policy to check.
:raises: HTTPForbidden if the policy forbids access. :raises: HTTPForbidden if the policy forbids access.
""" """
# NOTE(lbragstad): Mapping context attributes into a target dictionary is
# effectively a noop from an authorization perspective because the values
# we're comparing are coming from the same place.
cdict = api.request.context.to_policy_values() cdict = api.request.context.to_policy_values()
policy.authorize(policy_name, cdict, cdict) policy.authorize(policy_name, cdict, api.request.context)
def check_owner_policy(object_type, policy_name, owner, lessee=None): def check_owner_policy(object_type, policy_name, owner, lessee=None):
@ -1478,7 +1481,7 @@ def check_owner_policy(object_type, policy_name, owner, lessee=None):
target_dict[object_type + '.owner'] = owner target_dict[object_type + '.owner'] = owner
if lessee: if lessee:
target_dict[object_type + '.lessee'] = lessee target_dict[object_type + '.lessee'] = lessee
policy.authorize(policy_name, target_dict, cdict) policy.authorize(policy_name, target_dict, api.request.context)
def check_node_policy_and_retrieve(policy_name, node_ident, def check_node_policy_and_retrieve(policy_name, node_ident,
@ -1502,7 +1505,7 @@ def check_node_policy_and_retrieve(policy_name, node_ident,
# don't expose non-existence of node unless requester # don't expose non-existence of node unless requester
# has generic access to policy # has generic access to policy
cdict = api.request.context.to_policy_values() cdict = api.request.context.to_policy_values()
policy.authorize(policy_name, cdict, cdict) policy.authorize(policy_name, cdict, api.request.context)
raise raise
check_owner_policy('node', policy_name, check_owner_policy('node', policy_name,
@ -1527,7 +1530,7 @@ def check_allocation_policy_and_retrieve(policy_name, allocation_ident):
# don't expose non-existence unless requester # don't expose non-existence unless requester
# has generic access to policy # has generic access to policy
cdict = api.request.context.to_policy_values() cdict = api.request.context.to_policy_values()
policy.authorize(policy_name, cdict, cdict) policy.authorize(policy_name, cdict, api.request.context)
raise raise
check_owner_policy('allocation', policy_name, rpc_allocation['owner']) check_owner_policy('allocation', policy_name, rpc_allocation['owner'])
@ -1571,12 +1574,13 @@ def check_list_policy(object_type, owner=None):
cdict = api.request.context.to_policy_values() cdict = api.request.context.to_policy_values()
try: try:
policy.authorize('baremetal:%s:list_all' % object_type, policy.authorize('baremetal:%s:list_all' % object_type,
cdict, cdict) cdict, api.request.context)
except exception.HTTPForbidden: except exception.HTTPForbidden:
project_owner = cdict.get('project_id') project_owner = cdict.get('project_id')
if (not project_owner or (owner and owner != project_owner)): if (not project_owner or (owner and owner != project_owner)):
raise raise
policy.authorize('baremetal:%s:list' % object_type, cdict, cdict) policy.authorize('baremetal:%s:list' % object_type,
cdict, api.request.context)
return project_owner return project_owner
return owner return owner
@ -1599,14 +1603,14 @@ def check_port_policy_and_retrieve(policy_name, port_uuid):
except exception.PortNotFound: except exception.PortNotFound:
# don't expose non-existence of port unless requester # don't expose non-existence of port unless requester
# has generic access to policy # has generic access to policy
policy.authorize(policy_name, cdict, cdict) policy.authorize(policy_name, cdict, context)
raise raise
rpc_node = objects.Node.get_by_id(context, rpc_port.node_id) rpc_node = objects.Node.get_by_id(context, rpc_port.node_id)
target_dict = dict(cdict) target_dict = dict(cdict)
target_dict['node.owner'] = rpc_node['owner'] target_dict['node.owner'] = rpc_node['owner']
target_dict['node.lessee'] = rpc_node['lessee'] target_dict['node.lessee'] = rpc_node['lessee']
policy.authorize(policy_name, target_dict, cdict) policy.authorize(policy_name, target_dict, context)
return rpc_port, rpc_node return rpc_port, rpc_node
@ -1619,12 +1623,14 @@ def check_port_list_policy():
""" """
cdict = api.request.context.to_policy_values() cdict = api.request.context.to_policy_values()
try: try:
policy.authorize('baremetal:port:list_all', cdict, cdict) policy.authorize('baremetal:port:list_all',
cdict, api.request.context)
except exception.HTTPForbidden: except exception.HTTPForbidden:
owner = cdict.get('project_id') owner = cdict.get('project_id')
if not owner: if not owner:
raise raise
policy.authorize('baremetal:port:list', cdict, cdict) policy.authorize('baremetal:port:list',
cdict, api.request.context)
return owner return owner

View File

@ -25,6 +25,7 @@ from oslo_utils import uuidutils
from ironic import api from ironic import api
from ironic.api.controllers.v1 import node as api_node from ironic.api.controllers.v1 import node as api_node
from ironic.api.controllers.v1 import utils from ironic.api.controllers.v1 import utils
from ironic.common import context as ironic_context
from ironic.common import exception from ironic.common import exception
from ironic.common import policy from ironic.common import policy
from ironic.common import states from ironic.common import states
@ -992,9 +993,12 @@ class TestVendorPassthru(base.TestCase):
@mock.patch.object(api, 'request', spec_set=["context"]) @mock.patch.object(api, 'request', spec_set=["context"])
@mock.patch.object(policy, 'authorize', spec=True) @mock.patch.object(policy, 'authorize', spec=True)
def test_check_policy(self, mock_authorize, mock_pr): def test_check_policy(self, mock_authorize, mock_pr):
fake_context = ironic_context.RequestContext()
mock_pr.context = fake_context
expected_target = dict(fake_context.to_policy_values())
utils.check_policy('fake-policy') utils.check_policy('fake-policy')
cdict = api.request.context.to_policy_values() mock_authorize.assert_called_once_with('fake-policy', expected_target,
mock_authorize.assert_called_once_with('fake-policy', cdict, cdict) fake_context)
@mock.patch.object(api, 'request', spec_set=["context"]) @mock.patch.object(api, 'request', spec_set=["context"])
@mock.patch.object(policy, 'authorize', spec=True) @mock.patch.object(policy, 'authorize', spec=True)
@ -1048,15 +1052,18 @@ class TestCheckOwnerPolicy(base.TestCase):
def test_check_owner_policy( def test_check_owner_policy(
self, mock_authorize, mock_pr self, mock_authorize, mock_pr
): ):
fake_context = ironic_context.RequestContext()
mock_pr.version.minor = 50 mock_pr.version.minor = 50
mock_pr.context.to_policy_values.return_value = {} mock_pr.context = fake_context
expected_target = dict(fake_context.to_policy_values())
expected_target['node.owner'] = '12345'
expected_target['node.lessee'] = '54321'
utils.check_owner_policy( utils.check_owner_policy(
'node', 'fake_policy', self.node['owner'], self.node['lessee'] 'node', 'fake_policy', self.node['owner'], self.node['lessee']
) )
mock_authorize.assert_called_once_with( mock_authorize.assert_called_once_with(
'fake_policy', 'fake_policy', expected_target, fake_context)
{'node.owner': '12345', 'node.lessee': '54321'}, {})
@mock.patch.object(api, 'request', spec_set=["context", "version"]) @mock.patch.object(api, 'request', spec_set=["context", "version"])
@mock.patch.object(policy, 'authorize', spec=True) @mock.patch.object(policy, 'authorize', spec=True)
@ -1091,8 +1098,13 @@ class TestCheckNodePolicyAndRetrieve(base.TestCase):
def test_check_node_policy_and_retrieve( def test_check_node_policy_and_retrieve(
self, mock_grnws, mock_grn, mock_authorize, mock_pr self, mock_grnws, mock_grn, mock_authorize, mock_pr
): ):
fake_context = ironic_context.RequestContext()
expected_target = dict(fake_context.to_policy_values())
expected_target['node.owner'] = '12345'
expected_target['node.lessee'] = '54321'
mock_pr.context = fake_context
mock_pr.version.minor = 50 mock_pr.version.minor = 50
mock_pr.context.to_policy_values.return_value = {}
mock_grn.return_value = self.node mock_grn.return_value = self.node
rpc_node = utils.check_node_policy_and_retrieve( rpc_node = utils.check_node_policy_and_retrieve(
@ -1101,8 +1113,7 @@ class TestCheckNodePolicyAndRetrieve(base.TestCase):
mock_grn.assert_called_once_with(self.valid_node_uuid) mock_grn.assert_called_once_with(self.valid_node_uuid)
mock_grnws.assert_not_called() mock_grnws.assert_not_called()
mock_authorize.assert_called_once_with( mock_authorize.assert_called_once_with(
'fake_policy', 'fake_policy', expected_target, fake_context)
{'node.owner': '12345', 'node.lessee': '54321'}, {})
self.assertEqual(self.node, rpc_node) self.assertEqual(self.node, rpc_node)
@mock.patch.object(api, 'request', spec_set=["context", "version"]) @mock.patch.object(api, 'request', spec_set=["context", "version"])
@ -1112,8 +1123,12 @@ class TestCheckNodePolicyAndRetrieve(base.TestCase):
def test_check_node_policy_and_retrieve_with_suffix( def test_check_node_policy_and_retrieve_with_suffix(
self, mock_grnws, mock_grn, mock_authorize, mock_pr self, mock_grnws, mock_grn, mock_authorize, mock_pr
): ):
fake_context = ironic_context.RequestContext()
expected_target = fake_context.to_policy_values()
expected_target['node.owner'] = '12345'
expected_target['node.lessee'] = '54321'
mock_pr.context = fake_context
mock_pr.version.minor = 50 mock_pr.version.minor = 50
mock_pr.context.to_policy_values.return_value = {}
mock_grnws.return_value = self.node mock_grnws.return_value = self.node
rpc_node = utils.check_node_policy_and_retrieve( rpc_node = utils.check_node_policy_and_retrieve(
@ -1122,8 +1137,7 @@ class TestCheckNodePolicyAndRetrieve(base.TestCase):
mock_grn.assert_not_called() mock_grn.assert_not_called()
mock_grnws.assert_called_once_with(self.valid_node_uuid) mock_grnws.assert_called_once_with(self.valid_node_uuid)
mock_authorize.assert_called_once_with( mock_authorize.assert_called_once_with(
'fake_policy', 'fake_policy', expected_target, fake_context)
{'node.owner': '12345', 'node.lessee': '54321'}, {})
self.assertEqual(self.node, rpc_node) self.assertEqual(self.node, rpc_node)
@mock.patch.object(api, 'request', spec_set=["context"]) @mock.patch.object(api, 'request', spec_set=["context"])
@ -1193,8 +1207,11 @@ class TestCheckAllocationPolicyAndRetrieve(base.TestCase):
def test_check_node_policy_and_retrieve( def test_check_node_policy_and_retrieve(
self, mock_graws, mock_authorize, mock_pr self, mock_graws, mock_authorize, mock_pr
): ):
fake_context = ironic_context.RequestContext()
expected_target = dict(fake_context.to_policy_values())
expected_target['allocation.owner'] = '12345'
mock_pr.version.minor = 60 mock_pr.version.minor = 60
mock_pr.context.to_policy_values.return_value = {} mock_pr.context = fake_context
mock_graws.return_value = self.allocation mock_graws.return_value = self.allocation
rpc_allocation = utils.check_allocation_policy_and_retrieve( rpc_allocation = utils.check_allocation_policy_and_retrieve(
@ -1202,7 +1219,7 @@ class TestCheckAllocationPolicyAndRetrieve(base.TestCase):
) )
mock_graws.assert_called_once_with(self.valid_allocation_uuid) mock_graws.assert_called_once_with(self.valid_allocation_uuid)
mock_authorize.assert_called_once_with( mock_authorize.assert_called_once_with(
'fake_policy', {'allocation.owner': '12345'}, {}) 'fake_policy', expected_target, fake_context)
self.assertEqual(self.allocation, rpc_allocation) self.assertEqual(self.allocation, rpc_allocation)
@mock.patch.object(api, 'request', spec_set=["context"]) @mock.patch.object(api, 'request', spec_set=["context"])
@ -1444,8 +1461,12 @@ class TestCheckPortPolicyAndRetrieve(base.TestCase):
def test_check_port_policy_and_retrieve( def test_check_port_policy_and_retrieve(
self, mock_ngbi, mock_pgbu, mock_authorize, mock_pr self, mock_ngbi, mock_pgbu, mock_authorize, mock_pr
): ):
fake_context = ironic_context.RequestContext()
expected_target = fake_context.to_policy_values()
expected_target['node.owner'] = '12345'
expected_target['node.lessee'] = '54321'
mock_pr.context = fake_context
mock_pr.version.minor = 50 mock_pr.version.minor = 50
mock_pr.context.to_policy_values.return_value = {}
mock_pgbu.return_value = self.port mock_pgbu.return_value = self.port
mock_ngbi.return_value = self.node mock_ngbi.return_value = self.node
@ -1456,9 +1477,7 @@ class TestCheckPortPolicyAndRetrieve(base.TestCase):
self.valid_port_uuid) self.valid_port_uuid)
mock_ngbi.assert_called_once_with(mock_pr.context, 42) mock_ngbi.assert_called_once_with(mock_pr.context, 42)
mock_authorize.assert_called_once_with( mock_authorize.assert_called_once_with(
'fake_policy', 'fake_policy', expected_target, fake_context)
{'node.owner': '12345', 'node.lessee': '54321'},
{})
self.assertEqual(self.port, rpc_port) self.assertEqual(self.port, rpc_port)
self.assertEqual(self.node, rpc_node) self.assertEqual(self.node, rpc_node)