Fix wrong totalcount returned by share listing query

This bugfix [1] modified the totalcount returned by pagination
query when the argument 'limit' was specified. It caused
manila to do not return precise count of shares in a query that
satisfied the conditions.

This bug has been fixed and now manila is returning the precise
values of shares matched in a given query. Also, manila is now
performing filtering actions in the database to have more
performatic results.

[1] https://review.opendev.org/#/c/688542/

Closes-Bug: #1860061

Co-Authored-By: Carlos Eduardo <ces.eduardo98@gmail.com>

Change-Id: I6ddd919bbd5180593cc52bf986912f65a2dab3a7
This commit is contained in:
maaoyu 2020-01-17 15:22:32 +08:00 committed by silvacarloss
parent 0d8f415e86
commit 268686c448
9 changed files with 288 additions and 78 deletions

View File

@ -164,12 +164,16 @@ class ShareMixin(object):
common.remove_invalid_options( common.remove_invalid_options(
context, search_opts, self._get_share_search_options()) context, search_opts, self._get_share_search_options())
shares = self.share_api.get_all(
context, search_opts=search_opts, sort_key=sort_key,
sort_dir=sort_dir)
total_count = None total_count = None
if show_count: if show_count:
total_count = len(shares) count, shares = self.share_api.get_all_with_count(
context, search_opts=search_opts, sort_key=sort_key,
sort_dir=sort_dir)
total_count = count
else:
shares = self.share_api.get_all(
context, search_opts=search_opts, sort_key=sort_key,
sort_dir=sort_dir)
if is_detail: if is_detail:
shares = self._view_builder.detail_list(req, shares, total_count) shares = self._view_builder.detail_list(req, shares, total_count)
@ -189,8 +193,7 @@ class ShareMixin(object):
'is_public', 'metadata', 'extra_specs', 'sort_key', 'sort_dir', 'is_public', 'metadata', 'extra_specs', 'sort_key', 'sort_dir',
'share_group_id', 'share_group_snapshot_id', 'export_location_id', 'share_group_id', 'share_group_snapshot_id', 'export_location_id',
'export_location_path', 'display_name~', 'display_description~', 'export_location_path', 'display_name~', 'display_description~',
'display_description', 'limit', 'offset' 'display_description', 'limit', 'offset')
)
@wsgi.Controller.authorize @wsgi.Controller.authorize
def update(self, req, id, body): def update(self, req, id, body):

View File

@ -401,13 +401,28 @@ def share_get_all(context, filters=None, sort_key=None, sort_dir=None):
) )
def share_get_all_with_count(context, filters=None, sort_key=None,
sort_dir=None):
"""Get all shares."""
return IMPL.share_get_all_with_count(
context, filters=filters, sort_key=sort_key, sort_dir=sort_dir)
def share_get_all_by_project(context, project_id, filters=None, def share_get_all_by_project(context, project_id, filters=None,
is_public=False, sort_key=None, sort_dir=None): is_public=False, sort_key=None, sort_dir=None):
"""Returns all shares with given project ID.""" """Returns all shares with given project ID."""
return IMPL.share_get_all_by_project( return IMPL.share_get_all_by_project(
context, project_id, filters=filters, is_public=is_public, context, project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir, sort_key=sort_key, sort_dir=sort_dir)
)
def share_get_all_by_project_with_count(
context, project_id, filters=None, is_public=False, sort_key=None,
sort_dir=None,):
"""Returns all shares with given project ID."""
return IMPL.share_get_all_by_project_with_count(
context, project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir)
def share_get_all_by_share_group_id(context, share_group_id, def share_get_all_by_share_group_id(context, share_group_id,
@ -419,13 +434,29 @@ def share_get_all_by_share_group_id(context, share_group_id,
sort_key=sort_key, sort_dir=sort_dir) sort_key=sort_key, sort_dir=sort_dir)
def share_get_all_by_share_group_id_with_count(context, share_group_id,
filters=None, sort_key=None,
sort_dir=None):
"""Returns all shares with given project ID and share group id."""
return IMPL.share_get_all_by_share_group_id_with_count(
context, share_group_id, filters=filters, sort_key=sort_key,
sort_dir=sort_dir)
def share_get_all_by_share_server(context, share_server_id, filters=None, def share_get_all_by_share_server(context, share_server_id, filters=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
"""Returns all shares with given share server ID.""" """Returns all shares with given share server ID."""
return IMPL.share_get_all_by_share_server( return IMPL.share_get_all_by_share_server(
context, share_server_id, filters=filters, sort_key=sort_key, context, share_server_id, filters=filters, sort_key=sort_key,
sort_dir=sort_dir, sort_dir=sort_dir)
)
def share_get_all_by_share_server_with_count(
context, share_server_id, filters=None, sort_key=None, sort_dir=None):
"""Returns all shares with given share server ID."""
return IMPL.share_get_all_by_share_server_with_count(
context, share_server_id, filters=filters, sort_key=sort_key,
sort_dir=sort_dir)
def share_delete(context, share_id): def share_delete(context, share_id):

View File

@ -43,6 +43,7 @@ from sqlalchemy import MetaData
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.orm import subqueryload from sqlalchemy.orm import subqueryload
from sqlalchemy.sql import distinct
from sqlalchemy.sql.expression import literal from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import true from sqlalchemy.sql.expression import true
from sqlalchemy.sql import func from sqlalchemy.sql import func
@ -2108,7 +2109,7 @@ def share_get(context, share_id, session=None):
def _share_get_all_with_filters(context, project_id=None, share_server_id=None, def _share_get_all_with_filters(context, project_id=None, share_server_id=None,
share_group_id=None, filters=None, share_group_id=None, filters=None,
is_public=False, sort_key=None, is_public=False, sort_key=None,
sort_dir=None): sort_dir=None, show_count=False):
"""Returns sorted list of shares that satisfies filters. """Returns sorted list of shares that satisfies filters.
:param context: context to query under :param context: context to query under
@ -2154,12 +2155,23 @@ def _share_get_all_with_filters(context, project_id=None, share_server_id=None,
msg = _("Wrong sorting key provided - '%s'.") % sort_key msg = _("Wrong sorting key provided - '%s'.") % sort_key
raise exception.InvalidInput(reason=msg) raise exception.InvalidInput(reason=msg)
count = None
# NOTE(carloss): Count must be calculated before limit and offset are
# applied into the query.
if show_count:
count = query.with_entities(
func.count(distinct(models.Share.id))).scalar()
if 'limit' in filters: if 'limit' in filters:
offset = filters.get('offset', 0) offset = filters.get('offset', 0)
query = query.limit(filters['limit']).offset(offset) query = query.limit(filters['limit']).offset(offset)
# Returns list of shares that satisfy filters. # Returns list of shares that satisfy filters.
query = query.all() query = query.all()
if show_count:
return count, query
return query return query
@ -2174,17 +2186,37 @@ def share_get_all(context, filters=None, sort_key=None, sort_dir=None):
return query return query
@require_admin_context
def share_get_all_with_count(context, filters=None, sort_key=None,
sort_dir=None):
count, query = _share_get_all_with_filters(
context,
filters=filters, sort_key=sort_key, sort_dir=sort_dir,
show_count=True)
return count, query
@require_context @require_context
def share_get_all_by_project(context, project_id, filters=None, def share_get_all_by_project(context, project_id, filters=None,
is_public=False, sort_key=None, sort_dir=None): is_public=False, sort_key=None, sort_dir=None):
"""Returns list of shares with given project ID.""" """Returns list of shares with given project ID."""
query = _share_get_all_with_filters( query = _share_get_all_with_filters(
context, project_id=project_id, filters=filters, is_public=is_public, context, project_id=project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir, sort_key=sort_key, sort_dir=sort_dir)
)
return query return query
@require_context
def share_get_all_by_project_with_count(
context, project_id, filters=None, is_public=False, sort_key=None,
sort_dir=None):
"""Returns list of shares with given project ID."""
count, query = _share_get_all_with_filters(
context, project_id=project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir, show_count=True)
return count, query
@require_context @require_context
def share_get_all_by_share_group_id(context, share_group_id, def share_get_all_by_share_group_id(context, share_group_id,
filters=None, sort_key=None, filters=None, sort_key=None,
@ -2192,22 +2224,41 @@ def share_get_all_by_share_group_id(context, share_group_id,
"""Returns list of shares with given group ID.""" """Returns list of shares with given group ID."""
query = _share_get_all_with_filters( query = _share_get_all_with_filters(
context, share_group_id=share_group_id, context, share_group_id=share_group_id,
filters=filters, sort_key=sort_key, sort_dir=sort_dir, filters=filters, sort_key=sort_key, sort_dir=sort_dir)
)
return query return query
@require_context
def share_get_all_by_share_group_id_with_count(context, share_group_id,
filters=None, sort_key=None,
sort_dir=None):
"""Returns list of shares with given share group ID."""
count, query = _share_get_all_with_filters(
context, share_group_id=share_group_id,
filters=filters, sort_key=sort_key, sort_dir=sort_dir, show_count=True)
return count, query
@require_context @require_context
def share_get_all_by_share_server(context, share_server_id, filters=None, def share_get_all_by_share_server(context, share_server_id, filters=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
"""Returns list of shares with given share server.""" """Returns list of shares with given share server."""
query = _share_get_all_with_filters( query = _share_get_all_with_filters(
context, share_server_id=share_server_id, filters=filters, context, share_server_id=share_server_id, filters=filters,
sort_key=sort_key, sort_dir=sort_dir, sort_key=sort_key, sort_dir=sort_dir)
)
return query return query
@require_context
def share_get_all_by_share_server_with_count(
context, share_server_id, filters=None, sort_key=None, sort_dir=None):
"""Returns list of shares with given share server."""
count, query = _share_get_all_with_filters(
context, share_server_id=share_server_id, filters=filters,
sort_key=sort_key, sort_dir=sort_dir, show_count=True)
return count, query
@require_context @require_context
def share_delete(context, share_id): def share_delete(context, share_id):
session = get_session() session = get_session()

View File

@ -1763,6 +1763,17 @@ class API(base.Base):
def get_all(self, context, search_opts=None, sort_key='created_at', def get_all(self, context, search_opts=None, sort_key='created_at',
sort_dir='desc'): sort_dir='desc'):
return self._get_all(context, search_opts=search_opts,
sort_key=sort_key, sort_dir=sort_dir)
def get_all_with_count(self, context, search_opts=None,
sort_key='created_at', sort_dir='desc'):
return self._get_all(context, search_opts=search_opts,
sort_key=sort_key, sort_dir=sort_dir,
show_count=True)
def _get_all(self, context, search_opts=None, sort_key='created_at',
sort_dir='desc', show_count=False):
policy.check_policy(context, 'share', 'get_all') policy.check_policy(context, 'share', 'get_all')
if search_opts is None: if search_opts is None:
@ -1811,36 +1822,43 @@ class API(base.Base):
is_public = search_opts.pop('is_public', False) is_public = search_opts.pop('is_public', False)
is_public = strutils.bool_from_string(is_public, strict=True) is_public = strutils.bool_from_string(is_public, strict=True)
get_methods = {
'get_by_share_server': (
self.db.share_get_all_by_share_server_with_count
if show_count else self.db.share_get_all_by_share_server),
'get_all': (
self.db.share_get_all_with_count
if show_count else self.db.share_get_all),
'get_all_by_project': (
self.db.share_get_all_by_project_with_count
if show_count else self.db.share_get_all_by_project)}
# Get filtered list of shares # Get filtered list of shares
if 'host' in filters: if 'host' in filters:
policy.check_policy(context, 'share', 'list_by_host') policy.check_policy(context, 'share', 'list_by_host')
if 'share_server_id' in search_opts: if 'share_server_id' in search_opts:
# NOTE(vponomaryov): this is project_id independent # NOTE(vponomaryov): this is project_id independent
policy.check_policy(context, 'share', 'list_by_share_server_id') policy.check_policy(context, 'share', 'list_by_share_server_id')
shares = self.db.share_get_all_by_share_server( result = get_methods['get_by_share_server'](
context, search_opts.pop('share_server_id'), filters=filters, context, search_opts.pop('share_server_id'), filters=filters,
sort_key=sort_key, sort_dir=sort_dir) sort_key=sort_key, sort_dir=sort_dir)
elif (context.is_admin and utils.is_all_tenants(search_opts)): elif context.is_admin and utils.is_all_tenants(search_opts):
shares = self.db.share_get_all( result = get_methods['get_all'](
context, filters=filters, sort_key=sort_key, sort_dir=sort_dir) context, filters=filters, sort_key=sort_key, sort_dir=sort_dir)
else: else:
shares = self.db.share_get_all_by_project( result = get_methods['get_all_by_project'](
context, project_id=context.project_id, filters=filters, context, project_id=context.project_id, filters=filters,
is_public=is_public, sort_key=sort_key, sort_dir=sort_dir) is_public=is_public, sort_key=sort_key, sort_dir=sort_dir)
# NOTE(vponomaryov): we do not need 'all_tenants' opt anymore if show_count:
search_opts.pop('all_tenants', None) count = result[0]
shares = result[1]
else:
shares = result
if search_opts: result = (count, shares) if show_count else shares
results = []
for s in shares: return result
# values in search_opts can be only strings
if (all(s.get(k, None) == v or (v in (s.get(k.rstrip('~'))
if k.endswith('~') and s.get(k.rstrip('~')) else ()))
for k, v in search_opts.items())):
results.append(s)
shares = results
return shares
def get_snapshot(self, context, snapshot_id): def get_snapshot(self, context, snapshot_id):
policy.check_policy(context, 'share_snapshot', 'get_snapshot') policy.check_policy(context, 'share_snapshot', 'get_snapshot')

View File

@ -1622,9 +1622,19 @@ class ShareAPITest(test.TestCase):
search_opts.update( search_opts.update(
{'display_name~': 'fake', {'display_name~': 'fake',
'display_description~': 'fake'}) 'display_description~': 'fake'})
method = 'get_all'
shares = [
{'id': 'id1', 'display_name': 'n1'},
{'id': 'id2', 'display_name': 'n2'},
{'id': 'id3', 'display_name': 'n3'},
]
mock_action = {'return_value': [shares[1]]}
if (api_version.APIVersionRequest(version) >= if (api_version.APIVersionRequest(version) >=
api_version.APIVersionRequest('2.42')): api_version.APIVersionRequest('2.42')):
search_opts.update({'with_count': 'true'}) search_opts.update({'with_count': 'true'})
method = 'get_all_with_count'
mock_action = {'side_effect': [(1, [shares[1]])]}
if use_admin_context: if use_admin_context:
search_opts['host'] = 'fake_host' search_opts['host'] = 'fake_host'
# fake_key should be filtered for non-admin # fake_key should be filtered for non-admin
@ -1634,13 +1644,8 @@ class ShareAPITest(test.TestCase):
req = fakes.HTTPRequest.blank(url, version=version, req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context) use_admin_context=use_admin_context)
shares = [ mock_get_all = (
{'id': 'id1', 'display_name': 'n1'}, self.mock_object(share_api.API, method, mock.Mock(**mock_action)))
{'id': 'id2', 'display_name': 'n2'},
{'id': 'id3', 'display_name': 'n3'},
]
self.mock_object(share_api.API, 'get_all',
mock.Mock(return_value=[shares[1]]))
result = self.controller.index(req) result = self.controller.index(req)
@ -1672,7 +1677,7 @@ class ShareAPITest(test.TestCase):
if use_admin_context: if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'}) search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host'] search_opts_expected['host'] = search_opts['host']
share_api.API.get_all.assert_called_once_with( mock_get_all.assert_called_once_with(
req.environ['manila.context'], req.environ['manila.context'],
sort_key=search_opts['sort_key'], sort_key=search_opts['sort_key'],
sort_dir=search_opts['sort_dir'], sort_dir=search_opts['sort_dir'],
@ -1706,8 +1711,8 @@ class ShareAPITest(test.TestCase):
req = fakes.HTTPRequest.blank(url, version=version, req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context) use_admin_context=use_admin_context)
self.mock_object(share_api.API, 'get_all', self.mock_object(share_api.API, 'get_all_with_count',
mock.Mock(return_value=[])) mock.Mock(side_effect=[(0, [])]))
result = self.controller.index(req) result = self.controller.index(req)
@ -1716,7 +1721,7 @@ class ShareAPITest(test.TestCase):
if use_admin_context: if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'}) search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host'] search_opts_expected['host'] = search_opts['host']
share_api.API.get_all.assert_called_once_with( share_api.API.get_all_with_count.assert_called_once_with(
req.environ['manila.context'], req.environ['manila.context'],
sort_key=search_opts['sort_key'], sort_key=search_opts['sort_key'],
sort_dir=search_opts['sort_dir'], sort_dir=search_opts['sort_dir'],
@ -1776,18 +1781,6 @@ class ShareAPITest(test.TestCase):
'export_location_id': 'fake_export_location_id', 'export_location_id': 'fake_export_location_id',
'export_location_path': 'fake_export_location_path', 'export_location_path': 'fake_export_location_path',
} }
if (api_version.APIVersionRequest(version) >=
api_version.APIVersionRequest('2.42')):
search_opts.update({'with_count': 'true'})
if use_admin_context:
search_opts['host'] = 'fake_host'
# fake_key should be filtered for non-admin
url = '/v2/fake/shares/detail?fake_key=fake_value'
for k, v in search_opts.items():
url = url + '&' + k + '=' + v
req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context)
shares = [ shares = [
{'id': 'id1', 'display_name': 'n1'}, {'id': 'id1', 'display_name': 'n1'},
{ {
@ -1805,8 +1798,24 @@ class ShareAPITest(test.TestCase):
{'id': 'id3', 'display_name': 'n3'}, {'id': 'id3', 'display_name': 'n3'},
] ]
self.mock_object(share_api.API, 'get_all', method = 'get_all'
mock.Mock(return_value=[shares[1]])) mock_action = {'return_value': [shares[1]]}
if (api_version.APIVersionRequest(version) >=
api_version.APIVersionRequest('2.42')):
search_opts.update({'with_count': 'true'})
method = 'get_all_with_count'
mock_action = {'side_effect': [(1, [shares[1]])]}
if use_admin_context:
search_opts['host'] = 'fake_host'
# fake_key should be filtered for non-admin
url = '/v2/fake/shares/detail?fake_key=fake_value'
for k, v in search_opts.items():
url = url + '&' + k + '=' + v
req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context)
mock_get_all = self.mock_object(share_api.API, method,
mock.Mock(**mock_action))
result = self.controller.detail(req) result = self.controller.detail(req)
@ -1834,7 +1843,7 @@ class ShareAPITest(test.TestCase):
if use_admin_context: if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'}) search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host'] search_opts_expected['host'] = search_opts['host']
share_api.API.get_all.assert_called_once_with( mock_get_all.assert_called_once_with(
req.environ['manila.context'], req.environ['manila.context'],
sort_key=search_opts['sort_key'], sort_key=search_opts['sort_key'],
sort_dir=search_opts['sort_dir'], sort_dir=search_opts['sort_dir'],

View File

@ -534,6 +534,17 @@ class ShareDatabaseAPITestCase(test.TestCase):
self.assertEqual(2, len(actual_result)) self.assertEqual(2, len(actual_result))
self.assertEqual(shares[0]['id'], actual_result[1]['id']) self.assertEqual(shares[0]['id'], actual_result[1]['id'])
@ddt.data('id')
def test_share_get_all_sort_by_share_fields(self, sort_key):
shares = [db_utils.create_share(**{sort_key: n, 'size': 1})
for n in ('FAKE_UUID1', 'FAKE_UUID2')]
actual_result = db_api.share_get_all(
self.ctxt, sort_key=sort_key, sort_dir='desc')
self.assertEqual(2, len(actual_result))
self.assertEqual(shares[0]['id'], actual_result[1]['id'])
@ddt.data('id', 'path') @ddt.data('id', 'path')
def test_share_get_all_by_export_location(self, type): def test_share_get_all_by_export_location(self, type):
share = db_utils.create_share() share = db_utils.create_share()
@ -585,6 +596,63 @@ class ShareDatabaseAPITestCase(test.TestCase):
self.assertEqual(0, len( self.assertEqual(0, len(
set(shares_requested_ids) & set(shares_not_requested_ids))) set(shares_requested_ids) & set(shares_not_requested_ids)))
@ddt.data(
({'display_name~': 'fake_name'}, 3, 3),
({'display_name~': 'fake_name', 'limit': 2}, 3, 2)
)
@ddt.unpack
def test_share_get_all_with_count(self, filters, amount_of_shares,
expected_shares_len):
[db_utils.create_share(display_name='fake_name_%s' % str(i))
for i in range(amount_of_shares)]
count, shares = db_api.share_get_all_with_count(
self.ctxt, filters=filters)
self.assertEqual(count, amount_of_shares)
for share in shares:
self.assertIn('fake_name', share['display_name'])
self.assertEqual(expected_shares_len, len(shares))
def test_share_get_all_by_share_group_id_with_count(self):
share_groups = [db_utils.create_share_group() for i in range(2)]
shares = [
db_utils.create_share(share_group_id=share_group['id'])
for share_group in share_groups]
count, result = db_api.share_get_all_by_share_group_id_with_count(
self.ctxt, share_groups[0]['id'])
self.assertEqual(count, 1)
self.assertEqual(shares[0]['id'], result[0]['id'])
self.assertEqual(1, len(result))
def test_share_get_all_by_share_server_with_count(self):
share_servers = [db_utils.create_share_server() for i in range(2)]
shares = [
db_utils.create_share(share_server_id=share_server['id'])
for share_server in share_servers]
count, result = db_api.share_get_all_by_share_server_with_count(
self.ctxt, share_servers[0]['id'])
self.assertEqual(count, 1)
self.assertEqual(shares[0]['id'], result[0]['id'])
self.assertEqual(1, len(result))
def test_share_get_all_by_project_with_count(self):
project_ids = ['fake_id_1', 'fake_id_2']
shares = [
db_utils.create_share(project_id=project_id)
for project_id in project_ids]
count, result = db_api.share_get_all_by_project_with_count(
self.ctxt, project_ids[0])
self.assertEqual(count, 1)
self.assertEqual(shares[0]['id'], result[0]['id'])
self.assertEqual(1, len(result))
@ddt.data( @ddt.data(
({'status': constants.STATUS_AVAILABLE}, 'status', ({'status': constants.STATUS_AVAILABLE}, 'status',
[constants.STATUS_AVAILABLE, constants.STATUS_ERROR]), [constants.STATUS_AVAILABLE, constants.STATUS_ERROR]),

View File

@ -329,44 +329,54 @@ class ShareAPITestCase(test.TestCase):
def test_get_all_admin_filter_by_name(self): def test_get_all_admin_filter_by_name(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:])) db_api, 'share_get_all_by_project',
shares = self.api.get_all(ctx, {'name': 'bar'}) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1::2]))
expected_filters = {'display_name': 'bar'}
shares = self.api.get_all(ctx, {'display_name': 'bar'})
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'), mock.call(ctx, 'share', 'get_all'),
]) ])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters={}, is_public=False project_id='fake_pid_2', filters=expected_filters, is_public=False
) )
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[1::2], shares) self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[1::2], shares)
@ddt.data(({'name': 'fo'}, 0), ({'description': 'd'}, 0), @ddt.data(({'display_name': 'fo'}, 0), ({'display_description': 'd'}, 0),
({'name': 'foo', 'description': 'd'}, 0), ({'display_name': 'foo', 'display_description': 'd'}, 0),
({'name': 'foo'}, 1), ({'description': 'ds'}, 1), ({'display_name': 'foo'}, 1), ({'display_description': 'ds'}, 1),
({'name~': 'foo', 'description~': 'ds'}, 2), ({'display_name~': 'foo', 'display_description~': 'ds'}, 2),
({'name': 'foo', 'description~': 'ds'}, 1), ({'display_name': 'foo', 'display_description~': 'ds'}, 1),
({'name~': 'foo', 'description': 'ds'}, 1)) ({'display_name~': 'foo', 'display_description': 'ds'}, 1))
@ddt.unpack @ddt.unpack
def test_get_all_admin_filter_by_name_and_description( def test_get_all_admin_filter_by_name_and_description(
self, search_opts, get_share_number): self, search_opts, get_share_number):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
expected_result = []
if get_share_number == 2:
expected_result = _FAKE_LIST_OF_ALL_SHARES[0::2]
elif get_share_number == 1:
expected_result = _FAKE_LIST_OF_ALL_SHARES[:1]
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES)) mock.Mock(return_value=expected_result))
expected_filters = copy.copy(search_opts)
shares = self.api.get_all(ctx, search_opts) shares = self.api.get_all(ctx, search_opts)
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'), mock.call(ctx, 'share', 'get_all'),
]) ])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', project_id='fake_pid_2',
filters={}, is_public=False filters=expected_filters, is_public=False
) )
self.assertEqual(get_share_number, len(shares)) self.assertEqual(get_share_number, len(shares))
if get_share_number == 2: self.assertEqual(expected_result, shares)
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[0::2], shares)
elif get_share_number == 1:
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[:1], shares)
@ddt.data('id', 'path') @ddt.data('id', 'path')
def test_get_all_admin_filter_by_export_location(self, type): def test_get_all_admin_filter_by_export_location(self, type):
@ -387,7 +397,7 @@ class ShareAPITestCase(test.TestCase):
def test_get_all_admin_filter_by_name_and_all_tenants(self): def test_get_all_admin_filter_by_name_and_all_tenants(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all', self.mock_object(db_api, 'share_get_all',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES)) mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[:1]))
shares = self.api.get_all(ctx, {'name': 'foo', 'all_tenants': 1}) shares = self.api.get_all(ctx, {'name': 'foo', 'all_tenants': 1})
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'), mock.call(ctx, 'share', 'get_all'),
@ -446,8 +456,11 @@ class ShareAPITestCase(test.TestCase):
def test_get_all_non_admin_with_name_and_status_filters(self): def test_get_all_non_admin_with_name_and_status_filters(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=False) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=False)
self.mock_object(db_api, 'share_get_all_by_project', self.mock_object(
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:])) db_api, 'share_get_all_by_project',
mock.Mock(side_effect=[
_FAKE_LIST_OF_ALL_SHARES[1::2],
_FAKE_LIST_OF_ALL_SHARES[2::4]]))
shares = self.api.get_all( shares = self.api.get_all(
ctx, {'name': 'bar', 'status': constants.STATUS_ERROR}) ctx, {'name': 'bar', 'status': constants.STATUS_ERROR})
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_has_calls([

View File

@ -393,6 +393,16 @@ def is_valid_ip_address(ip_address, ip_version):
return False return False
def get_bool_param(param_string, params, default=False):
param = params.get(param_string, default)
if not strutils.is_valid_boolstr(param):
msg = _("Value '%(param)s' for '%(param_string)s' is not "
"a boolean.") % {'param': param, 'param_string': param_string}
raise exception.InvalidParameterValue(err=msg)
return strutils.bool_from_string(param, strict=True)
def is_all_tenants(search_opts): def is_all_tenants(search_opts):
"""Checks to see if the all_tenants flag is in search_opts """Checks to see if the all_tenants flag is in search_opts

View File

@ -0,0 +1,7 @@
---
fixes:
- |
Fixed the issue that caused pagination queries to return erroneous
results when the argument `limit` was specified. Also improved the
queries performance by moving some filtering operations to the
database.