Merge "Fix wrong totalcount returned by share listing query"

This commit is contained in:
Zuul 2021-03-25 21:36:36 +00:00 committed by Gerrit Code Review
commit 4617e20e7b
9 changed files with 288 additions and 78 deletions

View File

@ -164,12 +164,16 @@ class ShareMixin(object):
common.remove_invalid_options(
context, search_opts, self._get_share_search_options())
shares = self.share_api.get_all(
context, search_opts=search_opts, sort_key=sort_key,
sort_dir=sort_dir)
total_count = None
if show_count:
total_count = len(shares)
count, shares = self.share_api.get_all_with_count(
context, search_opts=search_opts, sort_key=sort_key,
sort_dir=sort_dir)
total_count = count
else:
shares = self.share_api.get_all(
context, search_opts=search_opts, sort_key=sort_key,
sort_dir=sort_dir)
if is_detail:
shares = self._view_builder.detail_list(req, shares, total_count)
@ -189,8 +193,7 @@ class ShareMixin(object):
'is_public', 'metadata', 'extra_specs', 'sort_key', 'sort_dir',
'share_group_id', 'share_group_snapshot_id', 'export_location_id',
'export_location_path', 'display_name~', 'display_description~',
'display_description', 'limit', 'offset'
)
'display_description', 'limit', 'offset')
@wsgi.Controller.authorize
def update(self, req, id, body):

View File

@ -401,13 +401,28 @@ def share_get_all(context, filters=None, sort_key=None, sort_dir=None):
)
def share_get_all_with_count(context, filters=None, sort_key=None,
sort_dir=None):
"""Get all shares."""
return IMPL.share_get_all_with_count(
context, filters=filters, sort_key=sort_key, sort_dir=sort_dir)
def share_get_all_by_project(context, project_id, filters=None,
is_public=False, sort_key=None, sort_dir=None):
"""Returns all shares with given project ID."""
return IMPL.share_get_all_by_project(
context, project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir,
)
sort_key=sort_key, sort_dir=sort_dir)
def share_get_all_by_project_with_count(
context, project_id, filters=None, is_public=False, sort_key=None,
sort_dir=None,):
"""Returns all shares with given project ID."""
return IMPL.share_get_all_by_project_with_count(
context, project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir)
def share_get_all_by_share_group_id(context, share_group_id,
@ -419,13 +434,29 @@ def share_get_all_by_share_group_id(context, share_group_id,
sort_key=sort_key, sort_dir=sort_dir)
def share_get_all_by_share_group_id_with_count(context, share_group_id,
filters=None, sort_key=None,
sort_dir=None):
"""Returns all shares with given project ID and share group id."""
return IMPL.share_get_all_by_share_group_id_with_count(
context, share_group_id, filters=filters, sort_key=sort_key,
sort_dir=sort_dir)
def share_get_all_by_share_server(context, share_server_id, filters=None,
sort_key=None, sort_dir=None):
"""Returns all shares with given share server ID."""
return IMPL.share_get_all_by_share_server(
context, share_server_id, filters=filters, sort_key=sort_key,
sort_dir=sort_dir,
)
sort_dir=sort_dir)
def share_get_all_by_share_server_with_count(
context, share_server_id, filters=None, sort_key=None, sort_dir=None):
"""Returns all shares with given share server ID."""
return IMPL.share_get_all_by_share_server_with_count(
context, share_server_id, filters=filters, sort_key=sort_key,
sort_dir=sort_dir)
def share_delete(context, share_id):

View File

@ -46,6 +46,7 @@ from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import subqueryload
from sqlalchemy.sql import distinct
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import true
from sqlalchemy.sql import func
@ -2123,7 +2124,7 @@ def share_get(context, share_id, session=None):
def _share_get_all_with_filters(context, project_id=None, share_server_id=None,
share_group_id=None, filters=None,
is_public=False, sort_key=None,
sort_dir=None):
sort_dir=None, show_count=False):
"""Returns sorted list of shares that satisfies filters.
:param context: context to query under
@ -2169,12 +2170,23 @@ def _share_get_all_with_filters(context, project_id=None, share_server_id=None,
msg = _("Wrong sorting key provided - '%s'.") % sort_key
raise exception.InvalidInput(reason=msg)
count = None
# NOTE(carloss): Count must be calculated before limit and offset are
# applied into the query.
if show_count:
count = query.with_entities(
func.count(distinct(models.Share.id))).scalar()
if 'limit' in filters:
offset = filters.get('offset', 0)
query = query.limit(filters['limit']).offset(offset)
# Returns list of shares that satisfy filters.
query = query.all()
if show_count:
return count, query
return query
@ -2189,17 +2201,37 @@ def share_get_all(context, filters=None, sort_key=None, sort_dir=None):
return query
@require_admin_context
def share_get_all_with_count(context, filters=None, sort_key=None,
sort_dir=None):
count, query = _share_get_all_with_filters(
context,
filters=filters, sort_key=sort_key, sort_dir=sort_dir,
show_count=True)
return count, query
@require_context
def share_get_all_by_project(context, project_id, filters=None,
is_public=False, sort_key=None, sort_dir=None):
"""Returns list of shares with given project ID."""
query = _share_get_all_with_filters(
context, project_id=project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir,
)
sort_key=sort_key, sort_dir=sort_dir)
return query
@require_context
def share_get_all_by_project_with_count(
context, project_id, filters=None, is_public=False, sort_key=None,
sort_dir=None):
"""Returns list of shares with given project ID."""
count, query = _share_get_all_with_filters(
context, project_id=project_id, filters=filters, is_public=is_public,
sort_key=sort_key, sort_dir=sort_dir, show_count=True)
return count, query
@require_context
def share_get_all_by_share_group_id(context, share_group_id,
filters=None, sort_key=None,
@ -2207,22 +2239,41 @@ def share_get_all_by_share_group_id(context, share_group_id,
"""Returns list of shares with given group ID."""
query = _share_get_all_with_filters(
context, share_group_id=share_group_id,
filters=filters, sort_key=sort_key, sort_dir=sort_dir,
)
filters=filters, sort_key=sort_key, sort_dir=sort_dir)
return query
@require_context
def share_get_all_by_share_group_id_with_count(context, share_group_id,
filters=None, sort_key=None,
sort_dir=None):
"""Returns list of shares with given share group ID."""
count, query = _share_get_all_with_filters(
context, share_group_id=share_group_id,
filters=filters, sort_key=sort_key, sort_dir=sort_dir, show_count=True)
return count, query
@require_context
def share_get_all_by_share_server(context, share_server_id, filters=None,
sort_key=None, sort_dir=None):
"""Returns list of shares with given share server."""
query = _share_get_all_with_filters(
context, share_server_id=share_server_id, filters=filters,
sort_key=sort_key, sort_dir=sort_dir,
)
sort_key=sort_key, sort_dir=sort_dir)
return query
@require_context
def share_get_all_by_share_server_with_count(
context, share_server_id, filters=None, sort_key=None, sort_dir=None):
"""Returns list of shares with given share server."""
count, query = _share_get_all_with_filters(
context, share_server_id=share_server_id, filters=filters,
sort_key=sort_key, sort_dir=sort_dir, show_count=True)
return count, query
@require_context
def share_delete(context, share_id):
session = get_session()

View File

@ -1813,6 +1813,17 @@ class API(base.Base):
def get_all(self, context, search_opts=None, sort_key='created_at',
sort_dir='desc'):
return self._get_all(context, search_opts=search_opts,
sort_key=sort_key, sort_dir=sort_dir)
def get_all_with_count(self, context, search_opts=None,
sort_key='created_at', sort_dir='desc'):
return self._get_all(context, search_opts=search_opts,
sort_key=sort_key, sort_dir=sort_dir,
show_count=True)
def _get_all(self, context, search_opts=None, sort_key='created_at',
sort_dir='desc', show_count=False):
policy.check_policy(context, 'share', 'get_all')
if search_opts is None:
@ -1861,36 +1872,43 @@ class API(base.Base):
is_public = search_opts.pop('is_public', False)
is_public = strutils.bool_from_string(is_public, strict=True)
get_methods = {
'get_by_share_server': (
self.db.share_get_all_by_share_server_with_count
if show_count else self.db.share_get_all_by_share_server),
'get_all': (
self.db.share_get_all_with_count
if show_count else self.db.share_get_all),
'get_all_by_project': (
self.db.share_get_all_by_project_with_count
if show_count else self.db.share_get_all_by_project)}
# Get filtered list of shares
if 'host' in filters:
policy.check_policy(context, 'share', 'list_by_host')
if 'share_server_id' in search_opts:
# NOTE(vponomaryov): this is project_id independent
policy.check_policy(context, 'share', 'list_by_share_server_id')
shares = self.db.share_get_all_by_share_server(
result = get_methods['get_by_share_server'](
context, search_opts.pop('share_server_id'), filters=filters,
sort_key=sort_key, sort_dir=sort_dir)
elif (context.is_admin and utils.is_all_tenants(search_opts)):
shares = self.db.share_get_all(
elif context.is_admin and utils.is_all_tenants(search_opts):
result = get_methods['get_all'](
context, filters=filters, sort_key=sort_key, sort_dir=sort_dir)
else:
shares = self.db.share_get_all_by_project(
result = get_methods['get_all_by_project'](
context, project_id=context.project_id, filters=filters,
is_public=is_public, sort_key=sort_key, sort_dir=sort_dir)
# NOTE(vponomaryov): we do not need 'all_tenants' opt anymore
search_opts.pop('all_tenants', None)
if show_count:
count = result[0]
shares = result[1]
else:
shares = result
if search_opts:
results = []
for s in shares:
# values in search_opts can be only strings
if (all(s.get(k, None) == v or (v in (s.get(k.rstrip('~'))
if k.endswith('~') and s.get(k.rstrip('~')) else ()))
for k, v in search_opts.items())):
results.append(s)
shares = results
return shares
result = (count, shares) if show_count else shares
return result
def get_snapshot(self, context, snapshot_id):
policy.check_policy(context, 'share_snapshot', 'get_snapshot')

View File

@ -1634,9 +1634,19 @@ class ShareAPITest(test.TestCase):
search_opts.update(
{'display_name~': 'fake',
'display_description~': 'fake'})
method = 'get_all'
shares = [
{'id': 'id1', 'display_name': 'n1'},
{'id': 'id2', 'display_name': 'n2'},
{'id': 'id3', 'display_name': 'n3'},
]
mock_action = {'return_value': [shares[1]]}
if (api_version.APIVersionRequest(version) >=
api_version.APIVersionRequest('2.42')):
search_opts.update({'with_count': 'true'})
method = 'get_all_with_count'
mock_action = {'side_effect': [(1, [shares[1]])]}
if use_admin_context:
search_opts['host'] = 'fake_host'
# fake_key should be filtered for non-admin
@ -1646,13 +1656,8 @@ class ShareAPITest(test.TestCase):
req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context)
shares = [
{'id': 'id1', 'display_name': 'n1'},
{'id': 'id2', 'display_name': 'n2'},
{'id': 'id3', 'display_name': 'n3'},
]
self.mock_object(share_api.API, 'get_all',
mock.Mock(return_value=[shares[1]]))
mock_get_all = (
self.mock_object(share_api.API, method, mock.Mock(**mock_action)))
result = self.controller.index(req)
@ -1684,7 +1689,7 @@ class ShareAPITest(test.TestCase):
if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host']
share_api.API.get_all.assert_called_once_with(
mock_get_all.assert_called_once_with(
req.environ['manila.context'],
sort_key=search_opts['sort_key'],
sort_dir=search_opts['sort_dir'],
@ -1718,8 +1723,8 @@ class ShareAPITest(test.TestCase):
req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context)
self.mock_object(share_api.API, 'get_all',
mock.Mock(return_value=[]))
self.mock_object(share_api.API, 'get_all_with_count',
mock.Mock(side_effect=[(0, [])]))
result = self.controller.index(req)
@ -1728,7 +1733,7 @@ class ShareAPITest(test.TestCase):
if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host']
share_api.API.get_all.assert_called_once_with(
share_api.API.get_all_with_count.assert_called_once_with(
req.environ['manila.context'],
sort_key=search_opts['sort_key'],
sort_dir=search_opts['sort_dir'],
@ -1788,18 +1793,6 @@ class ShareAPITest(test.TestCase):
'export_location_id': 'fake_export_location_id',
'export_location_path': 'fake_export_location_path',
}
if (api_version.APIVersionRequest(version) >=
api_version.APIVersionRequest('2.42')):
search_opts.update({'with_count': 'true'})
if use_admin_context:
search_opts['host'] = 'fake_host'
# fake_key should be filtered for non-admin
url = '/v2/fake/shares/detail?fake_key=fake_value'
for k, v in search_opts.items():
url = url + '&' + k + '=' + v
req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context)
shares = [
{'id': 'id1', 'display_name': 'n1'},
{
@ -1817,8 +1810,24 @@ class ShareAPITest(test.TestCase):
{'id': 'id3', 'display_name': 'n3'},
]
self.mock_object(share_api.API, 'get_all',
mock.Mock(return_value=[shares[1]]))
method = 'get_all'
mock_action = {'return_value': [shares[1]]}
if (api_version.APIVersionRequest(version) >=
api_version.APIVersionRequest('2.42')):
search_opts.update({'with_count': 'true'})
method = 'get_all_with_count'
mock_action = {'side_effect': [(1, [shares[1]])]}
if use_admin_context:
search_opts['host'] = 'fake_host'
# fake_key should be filtered for non-admin
url = '/v2/fake/shares/detail?fake_key=fake_value'
for k, v in search_opts.items():
url = url + '&' + k + '=' + v
req = fakes.HTTPRequest.blank(url, version=version,
use_admin_context=use_admin_context)
mock_get_all = self.mock_object(share_api.API, method,
mock.Mock(**mock_action))
result = self.controller.detail(req)
@ -1846,7 +1855,7 @@ class ShareAPITest(test.TestCase):
if use_admin_context:
search_opts_expected.update({'fake_key': 'fake_value'})
search_opts_expected['host'] = search_opts['host']
share_api.API.get_all.assert_called_once_with(
mock_get_all.assert_called_once_with(
req.environ['manila.context'],
sort_key=search_opts['sort_key'],
sort_dir=search_opts['sort_dir'],

View File

@ -534,6 +534,17 @@ class ShareDatabaseAPITestCase(test.TestCase):
self.assertEqual(2, len(actual_result))
self.assertEqual(shares[0]['id'], actual_result[1]['id'])
@ddt.data('id')
def test_share_get_all_sort_by_share_fields(self, sort_key):
shares = [db_utils.create_share(**{sort_key: n, 'size': 1})
for n in ('FAKE_UUID1', 'FAKE_UUID2')]
actual_result = db_api.share_get_all(
self.ctxt, sort_key=sort_key, sort_dir='desc')
self.assertEqual(2, len(actual_result))
self.assertEqual(shares[0]['id'], actual_result[1]['id'])
@ddt.data('id', 'path')
def test_share_get_all_by_export_location(self, type):
share = db_utils.create_share()
@ -585,6 +596,63 @@ class ShareDatabaseAPITestCase(test.TestCase):
self.assertEqual(0, len(
set(shares_requested_ids) & set(shares_not_requested_ids)))
@ddt.data(
({'display_name~': 'fake_name'}, 3, 3),
({'display_name~': 'fake_name', 'limit': 2}, 3, 2)
)
@ddt.unpack
def test_share_get_all_with_count(self, filters, amount_of_shares,
expected_shares_len):
[db_utils.create_share(display_name='fake_name_%s' % str(i))
for i in range(amount_of_shares)]
count, shares = db_api.share_get_all_with_count(
self.ctxt, filters=filters)
self.assertEqual(count, amount_of_shares)
for share in shares:
self.assertIn('fake_name', share['display_name'])
self.assertEqual(expected_shares_len, len(shares))
def test_share_get_all_by_share_group_id_with_count(self):
share_groups = [db_utils.create_share_group() for i in range(2)]
shares = [
db_utils.create_share(share_group_id=share_group['id'])
for share_group in share_groups]
count, result = db_api.share_get_all_by_share_group_id_with_count(
self.ctxt, share_groups[0]['id'])
self.assertEqual(count, 1)
self.assertEqual(shares[0]['id'], result[0]['id'])
self.assertEqual(1, len(result))
def test_share_get_all_by_share_server_with_count(self):
share_servers = [db_utils.create_share_server() for i in range(2)]
shares = [
db_utils.create_share(share_server_id=share_server['id'])
for share_server in share_servers]
count, result = db_api.share_get_all_by_share_server_with_count(
self.ctxt, share_servers[0]['id'])
self.assertEqual(count, 1)
self.assertEqual(shares[0]['id'], result[0]['id'])
self.assertEqual(1, len(result))
def test_share_get_all_by_project_with_count(self):
project_ids = ['fake_id_1', 'fake_id_2']
shares = [
db_utils.create_share(project_id=project_id)
for project_id in project_ids]
count, result = db_api.share_get_all_by_project_with_count(
self.ctxt, project_ids[0])
self.assertEqual(count, 1)
self.assertEqual(shares[0]['id'], result[0]['id'])
self.assertEqual(1, len(result))
@ddt.data(
({'status': constants.STATUS_AVAILABLE}, 'status',
[constants.STATUS_AVAILABLE, constants.STATUS_ERROR]),

View File

@ -339,44 +339,54 @@ class ShareAPITestCase(test.TestCase):
def test_get_all_admin_filter_by_name(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:]))
shares = self.api.get_all(ctx, {'name': 'bar'})
self.mock_object(
db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1::2]))
expected_filters = {'display_name': 'bar'}
shares = self.api.get_all(ctx, {'display_name': 'bar'})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters={}, is_public=False
project_id='fake_pid_2', filters=expected_filters, is_public=False
)
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[1::2], shares)
@ddt.data(({'name': 'fo'}, 0), ({'description': 'd'}, 0),
({'name': 'foo', 'description': 'd'}, 0),
({'name': 'foo'}, 1), ({'description': 'ds'}, 1),
({'name~': 'foo', 'description~': 'ds'}, 2),
({'name': 'foo', 'description~': 'ds'}, 1),
({'name~': 'foo', 'description': 'ds'}, 1))
@ddt.data(({'display_name': 'fo'}, 0), ({'display_description': 'd'}, 0),
({'display_name': 'foo', 'display_description': 'd'}, 0),
({'display_name': 'foo'}, 1), ({'display_description': 'ds'}, 1),
({'display_name~': 'foo', 'display_description~': 'ds'}, 2),
({'display_name': 'foo', 'display_description~': 'ds'}, 1),
({'display_name~': 'foo', 'display_description': 'ds'}, 1))
@ddt.unpack
def test_get_all_admin_filter_by_name_and_description(
self, search_opts, get_share_number):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
expected_result = []
if get_share_number == 2:
expected_result = _FAKE_LIST_OF_ALL_SHARES[0::2]
elif get_share_number == 1:
expected_result = _FAKE_LIST_OF_ALL_SHARES[:1]
self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES))
mock.Mock(return_value=expected_result))
expected_filters = copy.copy(search_opts)
shares = self.api.get_all(ctx, search_opts)
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
])
db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2',
filters={}, is_public=False
filters=expected_filters, is_public=False
)
self.assertEqual(get_share_number, len(shares))
if get_share_number == 2:
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[0::2], shares)
elif get_share_number == 1:
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[:1], shares)
self.assertEqual(expected_result, shares)
@ddt.data('id', 'path')
def test_get_all_admin_filter_by_export_location(self, type):
@ -397,7 +407,7 @@ class ShareAPITestCase(test.TestCase):
def test_get_all_admin_filter_by_name_and_all_tenants(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES))
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[:1]))
shares = self.api.get_all(ctx, {'name': 'foo', 'all_tenants': 1})
share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'),
@ -456,8 +466,11 @@ class ShareAPITestCase(test.TestCase):
def test_get_all_non_admin_with_name_and_status_filters(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=False)
self.mock_object(db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:]))
self.mock_object(
db_api, 'share_get_all_by_project',
mock.Mock(side_effect=[
_FAKE_LIST_OF_ALL_SHARES[1::2],
_FAKE_LIST_OF_ALL_SHARES[2::4]]))
shares = self.api.get_all(
ctx, {'name': 'bar', 'status': constants.STATUS_ERROR})
share_api.policy.check_policy.assert_has_calls([

View File

@ -393,6 +393,16 @@ def is_valid_ip_address(ip_address, ip_version):
return False
def get_bool_param(param_string, params, default=False):
param = params.get(param_string, default)
if not strutils.is_valid_boolstr(param):
msg = _("Value '%(param)s' for '%(param_string)s' is not "
"a boolean.") % {'param': param, 'param_string': param_string}
raise exception.InvalidParameterValue(err=msg)
return strutils.bool_from_string(param, strict=True)
def is_all_tenants(search_opts):
"""Checks to see if the all_tenants flag is in search_opts

View File

@ -0,0 +1,7 @@
---
fixes:
- |
Fixed the issue that caused pagination queries to return erroneous
results when the argument `limit` was specified. Also improved the
queries performance by moving some filtering operations to the
database.