Merge "Move shares filtering to database layer"

This commit is contained in:
Zuul 2021-01-14 17:39:25 +00:00 committed by Gerrit Code Review
commit a6b8d5b170
4 changed files with 207 additions and 73 deletions

View File

@ -207,9 +207,19 @@ def apply_sorting(model, query, sort_key, sort_dir):
"sort_key": sort_key, "sort_dir": sort_dir} "sort_key": sort_key, "sort_dir": sort_dir}
raise exception.InvalidInput(reason=msg) raise exception.InvalidInput(reason=msg)
sort_attr = getattr(model, sort_key) # NOTE(maaoyu): We add the additional sort by ID in this case to
sort_method = getattr(sort_attr, sort_dir.lower()) # get deterministic results. Without the ordering by ID this could
return query.order_by(sort_method()) # lead to flapping return lists.
sort_keys = [sort_key]
if sort_key != 'id':
sort_keys.append('id')
for sort_key in sort_keys:
sort_attr = getattr(model, sort_key)
sort_method = getattr(sort_attr, sort_dir.lower())
query = query.order_by(sort_method())
return query
def handle_db_data_error(f): def handle_db_data_error(f):
@ -1911,6 +1921,7 @@ def share_replica_delete(context, share_replica_id, session=None,
################ ################
@require_context
def _share_get_query(context, session=None): def _share_get_query(context, session=None):
if session is None: if session is None:
session = get_session() session = get_session()
@ -1918,6 +1929,96 @@ def _share_get_query(context, session=None):
options(joinedload('share_metadata'))) options(joinedload('share_metadata')))
def _process_share_filters(query, filters, project_id=None, is_public=False):
if filters is None:
filters = {}
share_filter_keys = ['share_group_id', 'snapshot_id']
instance_filter_keys = ['share_server_id', 'status', 'share_type_id',
'host', 'share_network_id']
share_filters = {}
instance_filters = {}
for k, v in filters.items():
share_filters.update({k: v}) if k in share_filter_keys else None
instance_filters.update({k: v}) if k in instance_filter_keys else None
no_key = 'key_is_absent'
def _filter_data(query, model, desired_filters):
for key, value in desired_filters.items():
filter_attr = getattr(model, key, no_key)
if filter_attr == no_key:
pass
query = query.filter(filter_attr == value)
return query
if share_filters:
query = _filter_data(query, models.Share, share_filters)
if instance_filters:
query = _filter_data(query, models.ShareInstance, instance_filters)
if project_id:
if is_public:
query = query.filter(or_(models.Share.project_id == project_id,
models.Share.is_public))
else:
query = query.filter(models.Share.project_id == project_id)
display_name = filters.get('display_name')
if display_name:
query = query.filter(
models.Share.display_name == display_name)
else:
display_name = filters.get('display_name~')
if display_name:
query = query.filter(models.Share.display_name.op('LIKE')(
u'%' + display_name + u'%'))
display_description = filters.get('display_description')
if display_description:
query = query.filter(
models.Share.display_description == display_description)
else:
display_description = filters.get('display_description~')
if display_description:
query = query.filter(models.Share.display_description.op('LIKE')(
u'%' + display_description + u'%'))
export_location_id = filters.pop('export_location_id', None)
export_location_path = filters.pop('export_location_path', None)
if export_location_id or export_location_path:
query = query.join(
models.ShareInstanceExportLocations,
models.ShareInstanceExportLocations.share_instance_id ==
models.ShareInstance.id)
if export_location_path:
query = query.filter(
models.ShareInstanceExportLocations.path ==
export_location_path)
if export_location_id:
query = query.filter(
models.ShareInstanceExportLocations.uuid ==
export_location_id)
if 'metadata' in filters:
for k, v in filters['metadata'].items():
# pylint: disable=no-member
query = query.filter(
or_(models.Share.share_metadata.any(
key=k, value=v)))
if 'extra_specs' in filters:
query = query.join(
models.ShareTypeExtraSpecs,
models.ShareTypeExtraSpecs.share_type_id ==
models.ShareInstance.share_type_id)
for k, v in filters['extra_specs'].items():
query = query.filter(or_(models.ShareTypeExtraSpecs.key == k,
models.ShareTypeExtraSpecs.value == v))
return query
def _metadata_refs(metadata_dict, meta_class): def _metadata_refs(metadata_dict, meta_class):
metadata_refs = [] metadata_refs = []
if metadata_dict: if metadata_dict:
@ -2022,6 +2123,9 @@ def _share_get_all_with_filters(context, project_id=None, share_server_id=None,
:returns: list -- models.Share :returns: list -- models.Share
:raises: exception.InvalidInput :raises: exception.InvalidInput
""" """
if filters is None:
filters = {}
if not sort_key: if not sort_key:
sort_key = 'created_at' sort_key = 'created_at'
if not sort_dir: if not sort_dir:
@ -2033,54 +2137,13 @@ def _share_get_all_with_filters(context, project_id=None, share_server_id=None,
) )
) )
if project_id:
if is_public:
query = query.filter(or_(models.Share.project_id == project_id,
models.Share.is_public))
else:
query = query.filter(models.Share.project_id == project_id)
if share_server_id:
query = query.filter(
models.ShareInstance.share_server_id == share_server_id)
if share_group_id: if share_group_id:
query = query.filter( filters['share_group_id'] = share_group_id
models.Share.share_group_id == share_group_id) if share_server_id:
filters['share_server_id'] = share_server_id
# Apply filters query = _process_share_filters(
if not filters: query, filters, project_id, is_public=is_public)
filters = {}
export_location_id = filters.get('export_location_id')
export_location_path = filters.get('export_location_path')
if export_location_id or export_location_path:
query = query.join(
models.ShareInstanceExportLocations,
models.ShareInstanceExportLocations.share_instance_id ==
models.ShareInstance.id)
if export_location_path:
query = query.filter(
models.ShareInstanceExportLocations.path ==
export_location_path)
if export_location_id:
query = query.filter(
models.ShareInstanceExportLocations.uuid ==
export_location_id)
if 'metadata' in filters:
for k, v in filters['metadata'].items():
# pylint: disable=no-member
query = query.filter(
or_(models.Share.share_metadata.any(
key=k, value=v)))
if 'extra_specs' in filters:
query = query.join(
models.ShareTypeExtraSpecs,
models.ShareTypeExtraSpecs.share_type_id ==
models.ShareInstance.share_type_id)
for k, v in filters['extra_specs'].items():
query = query.filter(or_(models.ShareTypeExtraSpecs.key == k,
models.ShareTypeExtraSpecs.value == v))
try: try:
query = apply_sorting(models.Share, query, sort_key, sort_dir) query = apply_sorting(models.Share, query, sort_key, sort_dir)
@ -2103,8 +2166,12 @@ def _share_get_all_with_filters(context, project_id=None, share_server_id=None,
@require_admin_context @require_admin_context
def share_get_all(context, filters=None, sort_key=None, sort_dir=None): def share_get_all(context, filters=None, sort_key=None, sort_dir=None):
project_id = filters.pop('project_id', None) if filters else None
query = _share_get_all_with_filters( query = _share_get_all_with_filters(
context, filters=filters, sort_key=sort_key, sort_dir=sort_dir) context,
project_id=project_id,
filters=filters, sort_key=sort_key, sort_dir=sort_dir)
return query return query

View File

@ -1772,12 +1772,18 @@ class API(base.Base):
# Prepare filters # Prepare filters
filters = {} filters = {}
if 'export_location_id' in search_opts:
filters['export_location_id'] = search_opts.pop( filter_keys = [
'export_location_id') 'display_name', 'share_group_id', 'display_name~',
if 'export_location_path' in search_opts: 'display_description', 'display_description~', 'snapshot_id',
filters['export_location_path'] = search_opts.pop( 'status', 'share_type_id', 'project_id', 'export_location_id',
'export_location_path') 'export_location_path', 'limit', 'offset', 'host',
'share_network_id']
for key in filter_keys:
if key in search_opts:
filters[key] = search_opts.pop(key)
if 'metadata' in search_opts: if 'metadata' in search_opts:
filters['metadata'] = search_opts.pop('metadata') filters['metadata'] = search_opts.pop('metadata')
if not isinstance(filters['metadata'], dict): if not isinstance(filters['metadata'], dict):
@ -1792,10 +1798,7 @@ class API(base.Base):
msg = _("Wrong extra specs filter provided: " msg = _("Wrong extra specs filter provided: "
"%s.") % six.text_type(filters['extra_specs']) "%s.") % six.text_type(filters['extra_specs'])
raise exception.InvalidInput(reason=msg) raise exception.InvalidInput(reason=msg)
if 'limit' in search_opts:
filters['limit'] = search_opts.pop('limit')
if 'offset' in search_opts:
filters['offset'] = search_opts.pop('offset')
if not (isinstance(sort_key, six.string_types) and sort_key): if not (isinstance(sort_key, six.string_types) and sort_key):
msg = _("Wrong sort_key filter provided: " msg = _("Wrong sort_key filter provided: "
"'%s'.") % six.text_type(sort_key) "'%s'.") % six.text_type(sort_key)
@ -1809,7 +1812,7 @@ class API(base.Base):
is_public = strutils.bool_from_string(is_public, strict=True) is_public = strutils.bool_from_string(is_public, strict=True)
# Get filtered list of shares # Get filtered list of shares
if 'host' in search_opts: if 'host' in filters:
policy.check_policy(context, 'share', 'list_by_host') policy.check_policy(context, 'share', 'list_by_host')
if 'share_server_id' in search_opts: if 'share_server_id' in search_opts:
# NOTE(vponomaryov): this is project_id independent # NOTE(vponomaryov): this is project_id independent

View File

@ -586,6 +586,57 @@ class ShareDatabaseAPITestCase(test.TestCase):
self.assertEqual(0, len( self.assertEqual(0, len(
set(shares_requested_ids) & set(shares_not_requested_ids))) set(shares_requested_ids) & set(shares_not_requested_ids)))
@ddt.data(
({'status': constants.STATUS_AVAILABLE}, 'status',
[constants.STATUS_AVAILABLE, constants.STATUS_ERROR]),
({'share_group_id': 'fake_group_id'}, 'share_group_id',
['fake_group_id', 'group_id']),
({'snapshot_id': 'fake_snapshot_id'}, 'snapshot_id',
['fake_snapshot_id', 'snapshot_id']),
({'share_type_id': 'fake_type_id'}, 'share_type_id',
['fake_type_id', 'type_id']),
({'host': 'fakehost@fakebackend#fakepool'}, 'host',
['fakehost@fakebackend#fakepool', 'foo@bar#test']),
({'share_network_id': 'fake_net_id'}, 'share_network_id',
['fake_net_id', 'net_id']),
({'display_name': 'fake_share_name'}, 'display_name',
['fake_share_name', 'share_name']),
({'display_description': 'fake description'}, 'display_description',
['fake description', 'description'])
)
@ddt.unpack
def test_share_get_all_with_filters(self, filters, key, share_values):
for value in share_values:
kwargs = {key: value}
db_utils.create_share(**kwargs)
results = db_api.share_get_all(self.ctxt, filters=filters)
for share in results:
self.assertEqual(share[key], filters[key])
@ddt.data(
('display_name~', 'display_name',
['fake_name_1', 'fake_name_2', 'fake_name_3'], 'fake_name'),
('display_description~', 'display_description',
['fake desc 1', 'fake desc 2', 'fake desc 3'], 'fake desc')
)
@ddt.unpack
def test_share_get_all_like_filters(
self, filter_name, key, share_values, like_value):
for value in share_values:
kwargs = {key: value}
db_utils.create_share(**kwargs)
db_utils.create_share(
display_name='irrelevant_name',
display_description='should not be queried')
filters = {filter_name: like_value}
results = db_api.share_get_all(self.ctxt, filters=filters)
self.assertEqual(len(share_values), len(results))
@ddt.data(None, 'writable') @ddt.data(None, 'writable')
def test_share_get_has_replicas_field(self, replication_type): def test_share_get_has_replicas_field(self, replication_type):
share = db_utils.create_share(replication_type=replication_type) share = db_utils.create_share(replication_type=replication_type)

View File

@ -398,29 +398,35 @@ class ShareAPITestCase(test.TestCase):
def test_get_all_admin_filter_by_status(self): def test_get_all_admin_filter_by_status(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all_by_project', expected_filter = {'status': constants.STATUS_AVAILABLE}
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1:])) self.mock_object(
db_api, 'share_get_all_by_project',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[0::2]))
shares = self.api.get_all(ctx, {'status': constants.STATUS_AVAILABLE}) shares = self.api.get_all(ctx, {'status': constants.STATUS_AVAILABLE})
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'), mock.call(ctx, 'share', 'get_all'),
]) ])
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters={}, is_public=False project_id='fake_pid_2', filters=expected_filter, is_public=False
) )
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[2::4], shares) self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[0::2], shares)
def test_get_all_admin_filter_by_status_and_all_tenants(self): def test_get_all_admin_filter_by_status_and_all_tenants(self):
ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True) ctx = context.RequestContext('fake_uid', 'fake_pid_2', is_admin=True)
self.mock_object(db_api, 'share_get_all', self.mock_object(
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES)) db_api, 'share_get_all',
mock.Mock(return_value=_FAKE_LIST_OF_ALL_SHARES[1::2]))
expected_filter = {'status': constants.STATUS_ERROR}
shares = self.api.get_all( shares = self.api.get_all(
ctx, {'status': constants.STATUS_ERROR, 'all_tenants': 1}) ctx, {'status': constants.STATUS_ERROR, 'all_tenants': 1})
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'), mock.call(ctx, 'share', 'get_all'),
]) ])
db_api.share_get_all.assert_called_once_with( db_api.share_get_all.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', filters={}) ctx, sort_dir='desc', sort_key='created_at',
filters=expected_filter)
self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[1::2], shares) self.assertEqual(_FAKE_LIST_OF_ALL_SHARES[1::2], shares)
def test_get_all_non_admin_filter_by_all_tenants(self): def test_get_all_non_admin_filter_by_all_tenants(self):
@ -447,9 +453,12 @@ class ShareAPITestCase(test.TestCase):
share_api.policy.check_policy.assert_has_calls([ share_api.policy.check_policy.assert_has_calls([
mock.call(ctx, 'share', 'get_all'), mock.call(ctx, 'share', 'get_all'),
]) ])
expected_filter_1 = {'status': constants.STATUS_ERROR}
expected_filter_2 = {'status': constants.STATUS_AVAILABLE}
db_api.share_get_all_by_project.assert_called_once_with( db_api.share_get_all_by_project.assert_called_once_with(
ctx, sort_dir='desc', sort_key='created_at', ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters={}, is_public=False project_id='fake_pid_2', filters=expected_filter_1, is_public=False
) )
# two items expected, one filtered # two items expected, one filtered
@ -464,10 +473,14 @@ class ShareAPITestCase(test.TestCase):
mock.call(ctx, 'share', 'get_all'), mock.call(ctx, 'share', 'get_all'),
]) ])
db_api.share_get_all_by_project.assert_has_calls([ db_api.share_get_all_by_project.assert_has_calls([
mock.call(ctx, sort_dir='desc', sort_key='created_at', mock.call(
project_id='fake_pid_2', filters={}, is_public=False), ctx, sort_dir='desc', sort_key='created_at',
mock.call(ctx, sort_dir='desc', sort_key='created_at', project_id='fake_pid_2', filters=expected_filter_1,
project_id='fake_pid_2', filters={}, is_public=False), is_public=False),
mock.call(
ctx, sort_dir='desc', sort_key='created_at',
project_id='fake_pid_2', filters=expected_filter_2,
is_public=False),
]) ])
@ddt.data('True', 'true', '1', 'yes', 'y', 'on', 't', True) @ddt.data('True', 'true', '1', 'yes', 'y', 'on', 't', True)