[rbac] Pull up policy checks on share/snapshot APIs

RBAC enforcement in manila happens in stages:
1) Does user have access to the API
2) Does user have access to the resource
3) Is user permitted to perform the API action on the resource

If (1) fails, user gets a HTTP 403, if (2) fails,
they get a HTTP 404; if (3) fails, they get a HTTP 403.

More often than not, (2) prevents "existence" detection
of resources that don't belong to the user; except in
case of "public" resources (e.g.: shares can be "public").

In some share API methods, policy checks for (1) are
happening after a bunch of processing. This leads to
some inconsistency.

Fix these occurrences to ensure a consistent user
experience.

Change-Id: I5b1f1ce517efed000f17b1e0901e183a1913ba9f
Related-Bug: #2004230
Signed-off-by: Goutham Pacha Ravi <gouthampravi@gmail.com>
(cherry picked from commit 190876809f)
(cherry picked from commit 318140e250)
(cherry picked from commit 8edaa3254c)
(cherry picked from commit f878e15bca)
(cherry picked from commit 8ab5ec6b4f)
This commit is contained in:
Goutham Pacha Ravi 2023-05-25 15:20:31 -07:00
parent 5c76512b9f
commit 96cca4b94e
6 changed files with 74 additions and 77 deletions

View File

@ -26,6 +26,7 @@ from manila.api.views import share_snapshots as snapshot_views
from manila import db from manila import db
from manila import exception from manila import exception
from manila.i18n import _ from manila.i18n import _
from manila import policy
from manila import share from manila import share
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
@ -63,6 +64,7 @@ class ShareSnapshotMixin(object):
context = req.environ['manila.context'] context = req.environ['manila.context']
LOG.info("Delete snapshot with id: %s", id, context=context) LOG.info("Delete snapshot with id: %s", id, context=context)
policy.check_policy(context, 'share', 'delete_snapshot')
try: try:
snapshot = self.share_api.get_snapshot(context, id) snapshot = self.share_api.get_snapshot(context, id)
@ -142,6 +144,7 @@ class ShareSnapshotMixin(object):
def update(self, req, id, body): def update(self, req, id, body):
"""Update a snapshot.""" """Update a snapshot."""
context = req.environ['manila.context'] context = req.environ['manila.context']
policy.check_policy(context, 'share', 'snapshot_update')
if not body or 'snapshot' not in body: if not body or 'snapshot' not in body:
raise exc.HTTPUnprocessableEntity() raise exc.HTTPUnprocessableEntity()

View File

@ -52,6 +52,7 @@ class ShareMixin(object):
def _delete(self, *args, **kwargs): def _delete(self, *args, **kwargs):
return self.share_api.delete(*args, **kwargs) return self.share_api.delete(*args, **kwargs)
@wsgi.Controller.authorize('get')
def show(self, req, id): def show(self, req, id):
"""Return data about the given share.""" """Return data about the given share."""
context = req.environ['manila.context'] context = req.environ['manila.context']
@ -63,6 +64,7 @@ class ShareMixin(object):
return self._view_builder.detail(req, share) return self._view_builder.detail(req, share)
@wsgi.Controller.authorize
def delete(self, req, id): def delete(self, req, id):
"""Delete a share.""" """Delete a share."""
context = req.environ['manila.context'] context = req.environ['manila.context']
@ -97,6 +99,7 @@ class ShareMixin(object):
return webob.Response(status_int=http_client.ACCEPTED) return webob.Response(status_int=http_client.ACCEPTED)
@wsgi.Controller.authorize("get_all")
def index(self, req): def index(self, req):
"""Returns a summary list of shares.""" """Returns a summary list of shares."""
req.GET.pop('export_location_id', None) req.GET.pop('export_location_id', None)
@ -107,6 +110,7 @@ class ShareMixin(object):
req.GET.pop('with_count', None) req.GET.pop('with_count', None)
return self._get_shares(req, is_detail=False) return self._get_shares(req, is_detail=False)
@wsgi.Controller.authorize("get_all")
def detail(self, req): def detail(self, req):
"""Returns a detailed list of shares.""" """Returns a detailed list of shares."""
req.GET.pop('export_location_id', None) req.GET.pop('export_location_id', None)
@ -514,6 +518,7 @@ class ShareMixin(object):
return self._access_view_builder.list_view(req, access_rules) return self._access_view_builder.list_view(req, access_rules)
@wsgi.Controller.authorize("extend")
def _extend(self, req, id, body): def _extend(self, req, id, body):
"""Extend size of a share.""" """Extend size of a share."""
context = req.environ['manila.context'] context = req.environ['manila.context']
@ -529,6 +534,7 @@ class ShareMixin(object):
return webob.Response(status_int=http_client.ACCEPTED) return webob.Response(status_int=http_client.ACCEPTED)
@wsgi.Controller.authorize("shrink")
def _shrink(self, req, id, body): def _shrink(self, req, id, body):
"""Shrink size of a share.""" """Shrink size of a share."""
context = req.environ['manila.context'] context = req.environ['manila.context']

View File

@ -207,6 +207,7 @@ class ShareController(shares.ShareMixin,
@wsgi.Controller.api_version('2.7') @wsgi.Controller.api_version('2.7')
@wsgi.action('reset_status') @wsgi.action('reset_status')
@wsgi.Controller.authorize('reset_status')
def share_reset_status(self, req, id, body): def share_reset_status(self, req, id, body):
return self._reset_status(req, id, body) return self._reset_status(req, id, body)
@ -444,6 +445,7 @@ class ShareController(shares.ShareMixin,
return self._revert(req, id, body) return self._revert(req, id, body)
@wsgi.Controller.api_version("2.0") @wsgi.Controller.api_version("2.0")
@wsgi.Controller.authorize("get_all")
def index(self, req): def index(self, req):
"""Returns a summary list of shares.""" """Returns a summary list of shares."""
if req.api_version_request < api_version.APIVersionRequest("2.35"): if req.api_version_request < api_version.APIVersionRequest("2.35"):
@ -461,6 +463,7 @@ class ShareController(shares.ShareMixin,
return self._get_shares(req, is_detail=False) return self._get_shares(req, is_detail=False)
@wsgi.Controller.api_version("2.0") @wsgi.Controller.api_version("2.0")
@wsgi.Controller.authorize("get_all")
def detail(self, req): def detail(self, req):
"""Returns a detailed list of shares.""" """Returns a detailed list of shares."""
if req.api_version_request < api_version.APIVersionRequest("2.35"): if req.api_version_request < api_version.APIVersionRequest("2.35"):

View File

@ -1830,8 +1830,6 @@ class API(base.Base):
def _get_all(self, context, search_opts=None, sort_key='created_at', def _get_all(self, context, search_opts=None, sort_key='created_at',
sort_dir='desc', show_count=False): sort_dir='desc', show_count=False):
policy.check_policy(context, 'share', 'get_all')
if search_opts is None: if search_opts is None:
search_opts = {} search_opts = {}
@ -2154,7 +2152,7 @@ class API(base.Base):
return self.db.share_network_get(context, share_net_id) return self.db.share_network_get(context, share_net_id)
def extend(self, context, share, new_size): def extend(self, context, share, new_size):
policy.check_policy(context, 'share', 'extend') policy.check_policy(context, 'share', 'extend', share)
if share['status'] != constants.STATUS_AVAILABLE: if share['status'] != constants.STATUS_AVAILABLE:
msg_params = { msg_params = {
@ -2253,8 +2251,6 @@ class API(base.Base):
resource=share) resource=share)
def shrink(self, context, share, new_size): def shrink(self, context, share, new_size):
policy.check_policy(context, 'share', 'shrink')
status = six.text_type(share['status']).lower() status = six.text_type(share['status']).lower()
valid_statuses = (constants.STATUS_AVAILABLE, valid_statuses = (constants.STATUS_AVAILABLE,
constants.STATUS_SHRINKING_POSSIBLE_DATA_LOSS_ERROR) constants.STATUS_SHRINKING_POSSIBLE_DATA_LOSS_ERROR)

View File

@ -106,6 +106,7 @@ class ShareAPITest(test.TestCase):
} }
CONF.set_default("default_share_type", None) CONF.set_default("default_share_type", None)
self.mock_object(policy, 'check_policy')
def _process_expected_share_detailed_response(self, shr_dict, req_version): def _process_expected_share_detailed_response(self, shr_dict, req_version):
"""Sets version based parameters on share dictionary.""" """Sets version based parameters on share dictionary."""
@ -1691,6 +1692,10 @@ class ShareAPITest(test.TestCase):
{'display_name~': search_opts['display_name~'], {'display_name~': search_opts['display_name~'],
'display_description~': search_opts['display_description~']}) 'display_description~': search_opts['display_description~']})
policy.check_policy.assert_called_once_with(
req.environ['manila.context'],
'share',
'get_all')
if use_admin_context: if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'}) search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host'] search_opts_expected['host'] = search_opts['host']
@ -1735,6 +1740,10 @@ class ShareAPITest(test.TestCase):
search_opts_expected = {} search_opts_expected = {}
policy.check_policy.assert_called_once_with(
req.environ['manila.context'],
'share',
'get_all')
if use_admin_context: if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'}) search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host'] search_opts_expected['host'] = search_opts['host']
@ -1770,6 +1779,10 @@ class ShareAPITest(test.TestCase):
} }
] ]
} }
policy.check_policy.assert_called_once_with(
req.environ['manila.context'],
'share',
'get_all')
self.assertEqual(expected, res_dict) self.assertEqual(expected, res_dict)
@ddt.data({'use_admin_context': False, 'version': '2.4'}, @ddt.data({'use_admin_context': False, 'version': '2.4'},
@ -1856,10 +1869,14 @@ class ShareAPITest(test.TestCase):
search_opts['export_location_id']) search_opts['export_location_id'])
search_opts_expected['export_location_path'] = ( search_opts_expected['export_location_path'] = (
search_opts['export_location_path']) search_opts['export_location_path'])
if use_admin_context: if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'}) search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host'] search_opts_expected['host'] = search_opts['host']
policy.check_policy.assert_called_once_with(
req.environ['manila.context'],
'share',
'get_all')
mock_get_all.assert_called_once_with( mock_get_all.assert_called_once_with(
req.environ['manila.context'], req.environ['manila.context'],
sort_key=search_opts['sort_key'], sort_key=search_opts['sort_key'],
@ -1927,7 +1944,13 @@ class ShareAPITest(test.TestCase):
def _list_detail_test_common(self, req, expected): def _list_detail_test_common(self, req, expected):
self.mock_object(share_api.API, 'get_all', self.mock_object(share_api.API, 'get_all',
stubs.stub_share_get_all_by_project) stubs.stub_share_get_all_by_project)
res_dict = self.controller.detail(req) res_dict = self.controller.detail(req)
policy.check_policy.assert_called_once_with(
req.environ['manila.context'],
'share',
'get_all')
self.assertDictListMatch(expected['shares'], res_dict['shares']) self.assertDictListMatch(expected['shares'], res_dict['shares'])
self.assertEqual(res_dict['shares'][0]['volume_type'], self.assertEqual(res_dict['shares'][0]['volume_type'],
res_dict['shares'][0]['share_type']) res_dict['shares'][0]['share_type'])
@ -1986,7 +2009,13 @@ class ShareAPITest(test.TestCase):
req = fakes.HTTPRequest.blank( req = fakes.HTTPRequest.blank(
'/v2/fake/shares/detail', environ=env, '/v2/fake/shares/detail', environ=env,
version=share_replicas.MIN_SUPPORTED_API_VERSION) version=share_replicas.MIN_SUPPORTED_API_VERSION)
res_dict = self.controller.detail(req) res_dict = self.controller.detail(req)
policy.check_policy.assert_called_once_with(
req.environ['manila.context'],
'share',
'get_all')
expected = { expected = {
'shares': [ 'shares': [
{ {
@ -2071,6 +2100,7 @@ class ShareActionsTest(test.TestCase):
super(ShareActionsTest, self).setUp() super(ShareActionsTest, self).setUp()
self.controller = shares.ShareController() self.controller = shares.ShareController()
self.mock_object(share_api.API, 'get', stubs.stub_share_get) self.mock_object(share_api.API, 'get', stubs.stub_share_get)
self.mock_object(policy, 'check_policy')
@ddt.unpack @ddt.unpack
@ddt.data( @ddt.data(

View File

@ -240,9 +240,9 @@ class ShareAPITestCase(test.TestCase):
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0]))
ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True)
shares = self.api.get_all(ctx) shares = self.api.get_all(ctx)
share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all')
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_1', filters={}, is_public=False project_id='fake_pid_1', filters={}, is_public=False
@ -253,9 +253,9 @@ class ShareAPITestCase(test.TestCase):
ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True)
self.mock_object(db_api, 'share_get_all', self.mock_object(db_api, 'share_get_all',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES)) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES))
shares = self.api.get_all(ctx, {'all_tenants': 1}) shares = self.api.get_all(ctx, {'all_tenants': 1})
share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all')
db_api.share_get_all.assert_called_once_with( db_api.share_get_all.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', filters={}) ctx, sort_dir='desc', sort_key='created_at', filters={})
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES, shares) self.assertEqual(_FAKE_LIST_OF_ALL_SHARES, shares)
@ -264,9 +264,9 @@ class ShareAPITestCase(test.TestCase):
ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True)
self.mock_object(db_api, 'share_get_all', self.mock_object(db_api, 'share_get_all',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES)) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES))
shares = self.api.get_all(ctx, {'all_tenants': ''}) shares = self.api.get_all(ctx, {'all_tenants': ''})
share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all')
db_api.share_get_all.assert_called_once_with( db_api.share_get_all.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', filters={}) ctx, sort_dir='desc', sort_key='created_at', filters={})
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES, shares) self.assertEqual(_FAKE_LIST_OF_ALL_SHARES, shares)
@ -275,9 +275,9 @@ class ShareAPITestCase(test.TestCase):
ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=True)
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0]))
shares = self.api.get_all(ctx, {'all_tenants': 'false'}) shares = self.api.get_all(ctx, {'all_tenants': 'false'})
share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all')
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_1', filters={}, is_public=False project_id='fake_pid_1', filters={}, is_public=False
@ -311,10 +311,8 @@ class ShareAPITestCase(test.TestCase):
exception.NotAuthorized, exception.NotAuthorized,
self.api.get_all, ctx, filters) self.api.get_all, ctx, filters)
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_called_once_with(
mock.call(ctx, 'share', 'get_all'), ctx, 'share', policy)
mock.call(ctx, 'share', policy),
])
def test_get_all_admin_filter_by_share_server_and_all_tenants(self): def test_get_all_admin_filter_by_share_server_and_all_tenants(self):
# NOTE(vponomaryov): if share_server_id provided, 'all_tenants' opt # NOTE(vponomaryov): if share_server_id provided, 'all_tenants' opt
@ -326,10 +324,8 @@ class ShareAPITestCase(test.TestCase):
self.mock_object(db_api, 'share_get_all_by_project') self.mock_object(db_api, 'share_get_all_by_project')
shares = self.api.get_all( shares = self.api.get_all(
ctx, {'share_server_id': 'fake_server_3', 'all_tenants': 1}) ctx, {'share_server_id': 'fake_server_3', 'all_tenants': 1})
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_called_once_with(
mock.call(ctx, 'share', 'get_all'), ctx, 'share', 'list_by_share_server_id')
mock.call(ctx, 'share', 'list_by_share_server_id'),
])
db_api.share_get_all_by_share_server.assert_called_once_with( db_api.share_get_all_by_share_server.assert_called_once_with(
ctx, 'fake_server_3', sort_dir='desc', sort_key='created_at', ctx, 'fake_server_3', sort_dir='desc', sort_key='created_at',
filters={}, filters={},
@ -344,10 +340,9 @@ class ShareAPITestCase(test.TestCase):
db_api, 'share_get_all_by_project', db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1::2])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1::2]))
expected_filters = {'display_name': 'bar'} expected_filters = {'display_name': 'bar'}
shares = self.api.get_all(ctx, {'display_name': 'bar'}) shares = self.api.get_all(ctx, {'display_name': 'bar'})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters=expected_filters, is_public=False project_id='fake_pid_2', filters=expected_filters, is_public=False
@ -377,9 +372,6 @@ class ShareAPITestCase(test.TestCase):
expected_filters = copy.copy(search_opts) expected_filters = copy.copy(search_opts)
shares = self.api.get_all(ctx, search_opts) shares = self.api.get_all(ctx, search_opts)
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
@ -394,10 +386,9 @@ class ShareAPITestCase(test.TestCase):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:]))
shares = self.api.get_all(ctx, {'export_location_' + type: 'test'}) shares = self.api.get_all(ctx, {'export_location_' + type: 'test'})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', project_id='fake_pid_2',
@ -409,10 +400,9 @@ class ShareAPITestCase(test.TestCase):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all', self.mock_object(db_api, 'share_get_all',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[:1])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[:1]))
shares = self.api.get_all(ctx, {'name': 'foo', 'all_tenants': 1}) shares = self.api.get_all(ctx, {'name': 'foo', 'all_tenants': 1})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all.assert_called_once_with( db_api.share_get_all.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', filters={}) ctx, sort_dir='desc', sort_key='created_at', filters={})
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[:1], shares) self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[:1], shares)
@ -425,9 +415,7 @@ class ShareAPITestCase(test.TestCase):
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0::2])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0::2]))
shares = self.api.get_all(ctx, {'status': constants.STATUS_AVAILABLE}) shares = self.api.get_all(ctx, {'status': constants.STATUS_AVAILABLE})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters=expected_filter, is_public=False project_id='fake_pid_2', filters=expected_filter, is_public=False
@ -442,9 +430,6 @@ class ShareAPITestCase(test.TestCase):
expected_filter = {'status': constants.STATUS_ERROR} expected_filter = {'status': constants.STATUS_ERROR}
shares = self.api.get_all( shares = self.api.get_all(
ctx, {'status': constants.STATUS_ERROR, 'all_tenants': 1}) ctx, {'status': constants.STATUS_ERROR, 'all_tenants': 1})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all.assert_called_once_with( db_api.share_get_all.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
filters=expected_filter) filters=expected_filter)
@ -455,10 +440,9 @@ class ShareAPITestCase(test.TestCase):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=False) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=False)
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:]))
shares = self.api.get_all(ctx, {'all_tenants': 1}) shares = self.api.get_all(ctx, {'all_tenants': 1})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters={}, is_public=False project_id='fake_pid_2', filters={}, is_public=False
@ -472,11 +456,10 @@ class ShareAPITestCase(test.TestCase):
mock.Mock(side_effect=[ mock.Mock(side_effect=[
_FAKE_LIST_OF_ALL_SHARES[1::2], _FAKE_LIST_OF_ALL_SHARES[1::2],
_FAKE_LIST_OF_ALL_SHARES[2::4]])) _FAKE_LIST_OF_ALL_SHARES[2::4]]))
shares = self.api.get_all( shares = self.api.get_all(
ctx, {'name': 'bar', 'status': constants.STATUS_ERROR}) ctx, {'name': 'bar', 'status': constants.STATUS_ERROR})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
expected_filter_1 = {'status': constants.STATUS_ERROR} expected_filter_1 = {'status': constants.STATUS_ERROR}
expected_filter_2 = {'status': constants.STATUS_AVAILABLE} expected_filter_2 = {'status': constants.STATUS_AVAILABLE}
@ -491,11 +474,8 @@ class ShareAPITestCase(test.TestCase):
# one item expected, two filtered # one item expected, two filtered
shares = self.api.get_all( shares = self.api.get_all(
ctx, {'name': 'foo1', 'status': constants.STATUS_AVAILABLE}) ctx, {'name': 'foo1', 'status': constants.STATUS_AVAILABLE})
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[2::4], shares) self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[2::4], shares)
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_has_calls([ db_api.share_get_all_by_project.assert_has_calls([
mock.call( mock.call(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
@ -514,9 +494,6 @@ class ShareAPITestCase(test.TestCase):
self.mock_object(db_api, 'share_get_all_by_project', mock.Mock( self.mock_object(db_api, 'share_get_all_by_project', mock.Mock(
return_value=_FAKE_LIST_OF_ALL_SHARES[1:])) return_value=_FAKE_LIST_OF_ALL_SHARES[1:]))
shares = self.api.get_all(ctx, {'is_public': is_public}) shares = self.api.get_all(ctx, {'is_public': is_public})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters={}, is_public=True project_id='fake_pid_2', filters={}, is_public=True
@ -530,9 +507,6 @@ class ShareAPITestCase(test.TestCase):
self.mock_object(db_api, 'share_get_all_by_project', mock.Mock( self.mock_object(db_api, 'share_get_all_by_project', mock.Mock(
return_value=_FAKE_LIST_OF_ALL_SHARES[1:])) return_value=_FAKE_LIST_OF_ALL_SHARES[1:]))
shares = self.api.get_all(ctx, {'is_public': is_public}) shares = self.api.get_all(ctx, {'is_public': is_public})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters={}, is_public=False project_id='fake_pid_2', filters={}, is_public=False
@ -545,17 +519,14 @@ class ShareAPITestCase(test.TestCase):
is_admin=False) is_admin=False)
self.assertRaises(ValueError, self.api.get_all, self.assertRaises(ValueError, self.api.get_all,
ctx, {'is_public': is_public}) ctx, {'is_public': is_public})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
def test_get_all_with_sorting_valid(self): def test_get_all_with_sorting_valid(self):
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0]))
ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=False) ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=False)
shares = self.api.get_all(ctx, sort_key='status', sort_dir='asc') shares = self.api.get_all(ctx, sort_key='status', sort_dir='asc')
share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all')
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='asc', sort_key='status', ctx, sort_dir='asc', sort_key='status',
project_id='fake_pid_1', filters={}, is_public=False project_id='fake_pid_1', filters={}, is_public=False
@ -572,8 +543,6 @@ class ShareAPITestCase(test.TestCase):
ctx, ctx,
sort_key=1, sort_key=1,
) )
share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all')
def test_get_all_sort_dir_invalid(self): def test_get_all_sort_dir_invalid(self):
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
@ -585,23 +554,18 @@ class ShareAPITestCase(test.TestCase):
ctx, ctx,
sort_dir=1, sort_dir=1,
) )
share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all')
def _get_all_filter_metadata_or_extra_specs_valid(self, key): def _get_all_filter_metadata_or_extra_specs_valid(self, key):
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0])) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0]))
ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=False) ctx = context.RequestContext('fake_uid', 'fake_pid_1', is_admin=False)
search_opts = {key: {'foo1': 'bar1', 'foo2': 'bar2'}} search_opts = {key: {'foo1': 'bar1', 'foo2': 'bar2'}}
shares = self.api.get_all(ctx, search_opts=search_opts.copy()) shares = self.api.get_all(ctx, search_opts=search_opts.copy())
if key == 'extra_specs': if key == 'extra_specs':
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
mock.call(ctx, 'share_types_extra_spec', 'index'),
])
else:
share_api.policy.check_policy.assert_called_once_with( share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all') ctx, 'share_types_extra_spec', 'index')
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_1', filters=search_opts, is_public=False) project_id='fake_pid_1', filters=search_opts, is_public=False)
@ -621,13 +585,8 @@ class ShareAPITestCase(test.TestCase):
self.assertRaises(exception.InvalidInput, self.api.get_all, ctx, self.assertRaises(exception.InvalidInput, self.api.get_all, ctx,
search_opts=search_opts) search_opts=search_opts)
if key == 'extra_specs': if key == 'extra_specs':
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
mock.call(ctx, 'share_types_extra_spec', 'index'),
])
else:
share_api.policy.check_policy.assert_called_once_with( share_api.policy.check_policy.assert_called_once_with(
ctx, 'share', 'get_all') ctx, 'share_types_extra_spec', 'index')
def test_get_all_filter_by_invalid_metadata(self): def test_get_all_filter_by_invalid_metadata(self):
self._get_all_filter_metadata_or_extra_specs_invalid(key='metadata') self._get_all_filter_metadata_or_extra_specs_invalid(key='metadata')