diff --git a/api-ref/source/parameters.yaml b/api-ref/source/parameters.yaml index 0af41a7998..2adfcce803 100644 --- a/api-ref/source/parameters.yaml +++ b/api-ref/source/parameters.yaml @@ -146,7 +146,7 @@ amphora-role: amphora-status: description: | The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``, - ``PENDING_DELETE``, ``DELETED``, ``ERROR``. + ``PENDING_CREATE``, ``PENDING_DELETE``, ``DELETED``, ``ERROR``. in: body required: true type: string diff --git a/octavia/api/common/pagination.py b/octavia/api/common/pagination.py index 00c660c49f..bf18341232 100644 --- a/octavia/api/common/pagination.py +++ b/octavia/api/common/pagination.py @@ -23,6 +23,7 @@ from octavia.api.common import types from octavia.common.config import cfg from octavia.common import constants from octavia.common import exceptions +from octavia.db import models CONF = cfg.CONF LOG = logging.getLogger(__name__) @@ -33,6 +34,10 @@ class PaginationHelper(object): Pass this class to `db.repositories` to apply it on query """ + _auxiliary_arguments = ('limit', 'marker', + 'sort', 'sort_key', 'sort_dir', + 'fields', 'page_reverse', + ) def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR): """Pagination Helper takes params and a default sort direction @@ -173,7 +178,7 @@ class PaginationHelper(object): links = [types.PageType(**link) for link in links] return links - def apply(self, query, model): + def apply(self, query, model, enforce_valid_params=True): """Returns a query with sorting / pagination criteria added. Pagination works by requiring a unique sort_key specified by sort_keys. @@ -191,6 +196,7 @@ class PaginationHelper(object): :param query: the query object to which we should add paging/sorting/filtering :param model: the ORM model class + :param enforce_valid_params: check for invalid enteries in self.params :rtype: sqlalchemy.orm.query.Query :returns: The query with sorting/pagination/filtering added. @@ -198,18 +204,35 @@ class PaginationHelper(object): # Add filtering if CONF.api_settings.allow_filtering: - filter_attrs = [attr for attr in dir( - model.__v2_wsme__ - ) if not callable( - getattr(model.__v2_wsme__, attr) - ) and not attr.startswith("_")] - self.filters = {k: v for (k, v) in self.params.items() - if k in filter_attrs} + # Exclude (valid) arguments that are not used for data filtering + filter_params = {k: v for k, v in self.params.items( + ) if k not in self._auxiliary_arguments} + + secondary_query_filter = filter_params.pop( + "project_id", None) if (model == models.Amphora) else None + + # Tranlate arguments from API standard to data model's field name + filter_params = ( + model.__v2_wsme__.translate_dict_keys_to_data_model( + filter_params) + ) + if 'loadbalancer_id' in filter_params: + filter_params['load_balancer_id'] = filter_params.pop( + 'loadbalancer_id') + + # Drop invalid arguments + self.filters = {k: v for (k, v) in filter_params.items() + if k in vars(model.__data_model__())} + + if enforce_valid_params and ( + len(self.filters) < len(filter_params) + ): + raise exceptions.InvalidFilterArgument() query = model.apply_filter(query, model, self.filters) - if model.__name__ == "Amphora" and 'project_id' in self.params: + if secondary_query_filter is not None: query = query.filter(model.load_balancer.has( - project_id=self.params['project_id'])) + project_id=secondary_query_filter)) # Add sorting if CONF.api_settings.allow_sorting: diff --git a/octavia/api/common/types.py b/octavia/api/common/types.py index 77a9c5069c..28997462ff 100644 --- a/octavia/api/common/types.py +++ b/octavia/api/common/types.py @@ -106,6 +106,24 @@ class BaseType(wtypes.Base): del new_dict[key] return cls(**new_dict) + @classmethod + def translate_dict_keys_to_data_model(cls, wsme_dict): + """Translate the keys from wsme class type, to data_model.""" + if not hasattr(cls, '_type_to_model_map'): + return wsme_dict + res = {} + for (k, v) in wsme_dict.items(): + if k in cls._type_to_model_map: + k = cls._type_to_model_map[k] + if '.' in k: + parent, child = k.split('.') + if parent not in res: + res[parent] = {} + res[parent][child] = v + continue + res[k] = v + return res + def to_dict(self, render_unsets=False): """Converts Octavia WSME type to dictionary. @@ -119,7 +137,7 @@ class BaseType(wtypes.Base): if (isinstance(self.project_id, wtypes.UnsetType) and not isinstance(self.tenant_id, wtypes.UnsetType)): self.project_id = self.tenant_id - ret_dict = {} + wsme_dict = {} for attr in dir(self): if attr.startswith('_'): continue @@ -143,19 +161,8 @@ class BaseType(wtypes.Base): value = None else: continue - attr_name = attr - if (hasattr(self, '_type_to_model_map') and - attr in self._type_to_model_map): - renamed = self._type_to_model_map[attr] - if '.' in renamed: - parent, child = renamed.split('.') - if parent not in ret_dict: - ret_dict[parent] = {} - ret_dict[parent][child] = value - continue - attr_name = renamed - ret_dict[attr_name] = value - return ret_dict + wsme_dict[attr] = value + return self.translate_dict_keys_to_data_model(wsme_dict) class IdOnlyType(BaseType): diff --git a/octavia/common/constants.py b/octavia/common/constants.py index 54622727af..558413510c 100644 --- a/octavia/common/constants.py +++ b/octavia/common/constants.py @@ -99,7 +99,8 @@ MUTABLE_STATUSES = (ACTIVE,) DELETABLE_STATUSES = (ACTIVE, ERROR) SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR, - AMPHORA_READY, DELETED, PENDING_DELETE) + AMPHORA_READY, DELETED, + PENDING_CREATE, PENDING_DELETE) ONLINE = 'ONLINE' OFFLINE = 'OFFLINE' diff --git a/octavia/common/exceptions.py b/octavia/common/exceptions.py index 20f0976f26..65458e13ba 100644 --- a/octavia/common/exceptions.py +++ b/octavia/common/exceptions.py @@ -86,6 +86,11 @@ class InvalidOption(APIException): code = 400 +class InvalidFilterArgument(APIException): + msg = "One or more arguments are either duplicate or invalid" + code = 400 + + class DisabledOption(APIException): msg = _("The selected %(option)s is not allowed in this deployment: " "%(value)s") diff --git a/octavia/tests/functional/api/v2/test_amphora.py b/octavia/tests/functional/api/v2/test_amphora.py index 7c78aeba1d..d16f08e68f 100644 --- a/octavia/tests/functional/api/v2/test_amphora.py +++ b/octavia/tests/functional/api/v2/test_amphora.py @@ -209,13 +209,16 @@ class TestAmphora(base.BaseAPITest): self.AMPHORAE_PATH, params={'project_id': self.project_id} ).json.get(self.root_tag_list) - self.assertEqual(1, len(amps)) + + false_project_id = uuidutils.generate_uuid() amps = self.get( self.AMPHORAE_PATH, - params={'project_id': uuidutils.generate_uuid()} + params={'project_id': false_project_id} ).json.get(self.root_tag_list) - self.assertEqual(0, len(amps)) + + self.assertEqual(int(false_project_id == self.project_id), + len(amps)) def test_get_all_sorted(self): self._create_additional_amp() diff --git a/octavia/tests/unit/api/common/test_types.py b/octavia/tests/unit/api/common/test_types.py index d55a3c8756..d0b0724941 100644 --- a/octavia/tests/unit/api/common/test_types.py +++ b/octavia/tests/unit/api/common/test_types.py @@ -95,6 +95,17 @@ class TestTypeDataModelRenames(base.TestCase): self.assertEqual(new_type.child_one, child_dict.get('one')) self.assertEqual(new_type.child_two, child_dict.get('two')) + def test_translate_dict_keys_to_data_model(self): + new_type = TestTypeRename.from_data_model(self.model) + new_type_vars = { + k: getattr(new_type, k) for k in dir(new_type) if not ( + callable(getattr(new_type, k)) or k.startswith('_')) + } + self.assertEqual( + set(vars(self.model)), + set(new_type.translate_dict_keys_to_data_model(new_type_vars)), + ) + def test_type_to_dict_with_tenant_id(self): type_dict = TestTypeTenantProject(tenant_id='1234').to_dict() self.assertEqual('1234', type_dict['project_id']) diff --git a/octavia/tests/unit/api/hooks/test_query_parameters.py b/octavia/tests/unit/api/hooks/test_query_parameters.py index 286e618333..78c7871c2d 100644 --- a/octavia/tests/unit/api/hooks/test_query_parameters.py +++ b/octavia/tests/unit/api/hooks/test_query_parameters.py @@ -105,13 +105,57 @@ class TestPaginationHelper(base.TestCase): @mock.patch('octavia.api.common.pagination.request') def test_filter_mismatched_params(self, request_mock): - params = {'id': 'fake_id', 'fields': 'id'} + params = { + 'id': 'fake_id', + 'fields': 'field', + 'limit': '10', + 'sort': None, + } + filters = {'id': 'fake_id'} + helper = pagination.PaginationHelper(params) query_mock = mock.MagicMock() helper.apply(query_mock, models.LoadBalancer) self.assertEqual(filters, helper.filters) + helper.apply(query_mock, models.LoadBalancer, + enforce_valid_params=True) + self.assertEqual(filters, helper.filters) + + @mock.patch('octavia.api.common.pagination.request') + def test_filter_with_invalid_params(self, request_mock): + params = {'id': 'fake_id', 'no_such_param': 'id'} + filters = {'id': 'fake_id'} + helper = pagination.PaginationHelper(params) + query_mock = mock.MagicMock() + + helper.apply(query_mock, models.LoadBalancer, + # silently ignore invalid parameter + enforce_valid_params=False) + self.assertEqual(filters, helper.filters) + + self.assertRaises( + exceptions.InvalidFilterArgument, + pagination.PaginationHelper.apply, + helper, + query_mock, + models.Amphora, + ) + + @mock.patch('octavia.api.common.pagination.request') + def test_duplicate_argument(self, request_mock): + params = {'loadbalacer_id': 'id1', 'load_balacer_id': 'id2'} + query_mock = mock.MagicMock() + helper = pagination.PaginationHelper(params) + + self.assertRaises( + exceptions.InvalidFilterArgument, + pagination.PaginationHelper.apply, + helper, + query_mock, + models.Amphora, + ) @mock.patch('octavia.api.common.pagination.request') def test_fields_not_passed(self, request_mock):