From 3ada33d4c7821554982d5d20f1e8753c17e402e9 Mon Sep 17 00:00:00 2001 From: Bar RH Date: Fri, 24 Nov 2017 04:36:39 +0200 Subject: [PATCH] Fix filtering in list API calls The API's filtering arguments were not handled properly, and therefore, some were consistently ignored. This patch resolves this by translating the argument names to the ORM data model's fields, and then validating them. Additionally, enforcing of arguments validity is now the default behavior. Should unrecognized filtering arguments be entered, the API call will fail with code 400. Task: 5844 Story: 2001224 Change-Id: I8f61880d6c11037d32b96e9827fb4e810dc219c2 --- api-ref/source/parameters.yaml | 2 +- octavia/api/common/pagination.py | 43 +++++++++++++---- octavia/api/common/types.py | 35 ++++++++------ octavia/common/constants.py | 3 +- octavia/common/exceptions.py | 5 ++ .../tests/functional/api/v2/test_amphora.py | 9 ++-- octavia/tests/unit/api/common/test_types.py | 11 +++++ .../unit/api/hooks/test_query_parameters.py | 46 ++++++++++++++++++- 8 files changed, 124 insertions(+), 30 deletions(-) diff --git a/api-ref/source/parameters.yaml b/api-ref/source/parameters.yaml index 0af41a7998..2adfcce803 100644 --- a/api-ref/source/parameters.yaml +++ b/api-ref/source/parameters.yaml @@ -146,7 +146,7 @@ amphora-role: amphora-status: description: | The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``, - ``PENDING_DELETE``, ``DELETED``, ``ERROR``. + ``PENDING_CREATE``, ``PENDING_DELETE``, ``DELETED``, ``ERROR``. in: body required: true type: string diff --git a/octavia/api/common/pagination.py b/octavia/api/common/pagination.py index 00c660c49f..bf18341232 100644 --- a/octavia/api/common/pagination.py +++ b/octavia/api/common/pagination.py @@ -23,6 +23,7 @@ from octavia.api.common import types from octavia.common.config import cfg from octavia.common import constants from octavia.common import exceptions +from octavia.db import models CONF = cfg.CONF LOG = logging.getLogger(__name__) @@ -33,6 +34,10 @@ class PaginationHelper(object): Pass this class to `db.repositories` to apply it on query """ + _auxiliary_arguments = ('limit', 'marker', + 'sort', 'sort_key', 'sort_dir', + 'fields', 'page_reverse', + ) def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR): """Pagination Helper takes params and a default sort direction @@ -173,7 +178,7 @@ class PaginationHelper(object): links = [types.PageType(**link) for link in links] return links - def apply(self, query, model): + def apply(self, query, model, enforce_valid_params=True): """Returns a query with sorting / pagination criteria added. Pagination works by requiring a unique sort_key specified by sort_keys. @@ -191,6 +196,7 @@ class PaginationHelper(object): :param query: the query object to which we should add paging/sorting/filtering :param model: the ORM model class + :param enforce_valid_params: check for invalid enteries in self.params :rtype: sqlalchemy.orm.query.Query :returns: The query with sorting/pagination/filtering added. @@ -198,18 +204,35 @@ class PaginationHelper(object): # Add filtering if CONF.api_settings.allow_filtering: - filter_attrs = [attr for attr in dir( - model.__v2_wsme__ - ) if not callable( - getattr(model.__v2_wsme__, attr) - ) and not attr.startswith("_")] - self.filters = {k: v for (k, v) in self.params.items() - if k in filter_attrs} + # Exclude (valid) arguments that are not used for data filtering + filter_params = {k: v for k, v in self.params.items( + ) if k not in self._auxiliary_arguments} + + secondary_query_filter = filter_params.pop( + "project_id", None) if (model == models.Amphora) else None + + # Tranlate arguments from API standard to data model's field name + filter_params = ( + model.__v2_wsme__.translate_dict_keys_to_data_model( + filter_params) + ) + if 'loadbalancer_id' in filter_params: + filter_params['load_balancer_id'] = filter_params.pop( + 'loadbalancer_id') + + # Drop invalid arguments + self.filters = {k: v for (k, v) in filter_params.items() + if k in vars(model.__data_model__())} + + if enforce_valid_params and ( + len(self.filters) < len(filter_params) + ): + raise exceptions.InvalidFilterArgument() query = model.apply_filter(query, model, self.filters) - if model.__name__ == "Amphora" and 'project_id' in self.params: + if secondary_query_filter is not None: query = query.filter(model.load_balancer.has( - project_id=self.params['project_id'])) + project_id=secondary_query_filter)) # Add sorting if CONF.api_settings.allow_sorting: diff --git a/octavia/api/common/types.py b/octavia/api/common/types.py index 77a9c5069c..28997462ff 100644 --- a/octavia/api/common/types.py +++ b/octavia/api/common/types.py @@ -106,6 +106,24 @@ class BaseType(wtypes.Base): del new_dict[key] return cls(**new_dict) + @classmethod + def translate_dict_keys_to_data_model(cls, wsme_dict): + """Translate the keys from wsme class type, to data_model.""" + if not hasattr(cls, '_type_to_model_map'): + return wsme_dict + res = {} + for (k, v) in wsme_dict.items(): + if k in cls._type_to_model_map: + k = cls._type_to_model_map[k] + if '.' in k: + parent, child = k.split('.') + if parent not in res: + res[parent] = {} + res[parent][child] = v + continue + res[k] = v + return res + def to_dict(self, render_unsets=False): """Converts Octavia WSME type to dictionary. @@ -119,7 +137,7 @@ class BaseType(wtypes.Base): if (isinstance(self.project_id, wtypes.UnsetType) and not isinstance(self.tenant_id, wtypes.UnsetType)): self.project_id = self.tenant_id - ret_dict = {} + wsme_dict = {} for attr in dir(self): if attr.startswith('_'): continue @@ -143,19 +161,8 @@ class BaseType(wtypes.Base): value = None else: continue - attr_name = attr - if (hasattr(self, '_type_to_model_map') and - attr in self._type_to_model_map): - renamed = self._type_to_model_map[attr] - if '.' in renamed: - parent, child = renamed.split('.') - if parent not in ret_dict: - ret_dict[parent] = {} - ret_dict[parent][child] = value - continue - attr_name = renamed - ret_dict[attr_name] = value - return ret_dict + wsme_dict[attr] = value + return self.translate_dict_keys_to_data_model(wsme_dict) class IdOnlyType(BaseType): diff --git a/octavia/common/constants.py b/octavia/common/constants.py index 54622727af..558413510c 100644 --- a/octavia/common/constants.py +++ b/octavia/common/constants.py @@ -99,7 +99,8 @@ MUTABLE_STATUSES = (ACTIVE,) DELETABLE_STATUSES = (ACTIVE, ERROR) SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR, - AMPHORA_READY, DELETED, PENDING_DELETE) + AMPHORA_READY, DELETED, + PENDING_CREATE, PENDING_DELETE) ONLINE = 'ONLINE' OFFLINE = 'OFFLINE' diff --git a/octavia/common/exceptions.py b/octavia/common/exceptions.py index 20f0976f26..65458e13ba 100644 --- a/octavia/common/exceptions.py +++ b/octavia/common/exceptions.py @@ -86,6 +86,11 @@ class InvalidOption(APIException): code = 400 +class InvalidFilterArgument(APIException): + msg = "One or more arguments are either duplicate or invalid" + code = 400 + + class DisabledOption(APIException): msg = _("The selected %(option)s is not allowed in this deployment: " "%(value)s") diff --git a/octavia/tests/functional/api/v2/test_amphora.py b/octavia/tests/functional/api/v2/test_amphora.py index 7c78aeba1d..d16f08e68f 100644 --- a/octavia/tests/functional/api/v2/test_amphora.py +++ b/octavia/tests/functional/api/v2/test_amphora.py @@ -209,13 +209,16 @@ class TestAmphora(base.BaseAPITest): self.AMPHORAE_PATH, params={'project_id': self.project_id} ).json.get(self.root_tag_list) - self.assertEqual(1, len(amps)) + + false_project_id = uuidutils.generate_uuid() amps = self.get( self.AMPHORAE_PATH, - params={'project_id': uuidutils.generate_uuid()} + params={'project_id': false_project_id} ).json.get(self.root_tag_list) - self.assertEqual(0, len(amps)) + + self.assertEqual(int(false_project_id == self.project_id), + len(amps)) def test_get_all_sorted(self): self._create_additional_amp() diff --git a/octavia/tests/unit/api/common/test_types.py b/octavia/tests/unit/api/common/test_types.py index d55a3c8756..d0b0724941 100644 --- a/octavia/tests/unit/api/common/test_types.py +++ b/octavia/tests/unit/api/common/test_types.py @@ -95,6 +95,17 @@ class TestTypeDataModelRenames(base.TestCase): self.assertEqual(new_type.child_one, child_dict.get('one')) self.assertEqual(new_type.child_two, child_dict.get('two')) + def test_translate_dict_keys_to_data_model(self): + new_type = TestTypeRename.from_data_model(self.model) + new_type_vars = { + k: getattr(new_type, k) for k in dir(new_type) if not ( + callable(getattr(new_type, k)) or k.startswith('_')) + } + self.assertEqual( + set(vars(self.model)), + set(new_type.translate_dict_keys_to_data_model(new_type_vars)), + ) + def test_type_to_dict_with_tenant_id(self): type_dict = TestTypeTenantProject(tenant_id='1234').to_dict() self.assertEqual('1234', type_dict['project_id']) diff --git a/octavia/tests/unit/api/hooks/test_query_parameters.py b/octavia/tests/unit/api/hooks/test_query_parameters.py index 286e618333..78c7871c2d 100644 --- a/octavia/tests/unit/api/hooks/test_query_parameters.py +++ b/octavia/tests/unit/api/hooks/test_query_parameters.py @@ -105,13 +105,57 @@ class TestPaginationHelper(base.TestCase): @mock.patch('octavia.api.common.pagination.request') def test_filter_mismatched_params(self, request_mock): - params = {'id': 'fake_id', 'fields': 'id'} + params = { + 'id': 'fake_id', + 'fields': 'field', + 'limit': '10', + 'sort': None, + } + filters = {'id': 'fake_id'} + helper = pagination.PaginationHelper(params) query_mock = mock.MagicMock() helper.apply(query_mock, models.LoadBalancer) self.assertEqual(filters, helper.filters) + helper.apply(query_mock, models.LoadBalancer, + enforce_valid_params=True) + self.assertEqual(filters, helper.filters) + + @mock.patch('octavia.api.common.pagination.request') + def test_filter_with_invalid_params(self, request_mock): + params = {'id': 'fake_id', 'no_such_param': 'id'} + filters = {'id': 'fake_id'} + helper = pagination.PaginationHelper(params) + query_mock = mock.MagicMock() + + helper.apply(query_mock, models.LoadBalancer, + # silently ignore invalid parameter + enforce_valid_params=False) + self.assertEqual(filters, helper.filters) + + self.assertRaises( + exceptions.InvalidFilterArgument, + pagination.PaginationHelper.apply, + helper, + query_mock, + models.Amphora, + ) + + @mock.patch('octavia.api.common.pagination.request') + def test_duplicate_argument(self, request_mock): + params = {'loadbalacer_id': 'id1', 'load_balacer_id': 'id2'} + query_mock = mock.MagicMock() + helper = pagination.PaginationHelper(params) + + self.assertRaises( + exceptions.InvalidFilterArgument, + pagination.PaginationHelper.apply, + helper, + query_mock, + models.Amphora, + ) @mock.patch('octavia.api.common.pagination.request') def test_fields_not_passed(self, request_mock):