Merge "Fix filtering in list API calls"
This commit is contained in:
commit
82d9f44c62
@ -146,7 +146,7 @@ amphora-role:
|
||||
amphora-status:
|
||||
description: |
|
||||
The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``,
|
||||
``PENDING_DELETE``, ``DELETED``, ``ERROR``.
|
||||
``PENDING_CREATE``, ``PENDING_DELETE``, ``DELETED``, ``ERROR``.
|
||||
in: body
|
||||
required: true
|
||||
type: string
|
||||
|
@ -23,6 +23,7 @@ from octavia.api.common import types
|
||||
from octavia.common.config import cfg
|
||||
from octavia.common import constants
|
||||
from octavia.common import exceptions
|
||||
from octavia.db import models
|
||||
|
||||
CONF = cfg.CONF
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -33,6 +34,10 @@ class PaginationHelper(object):
|
||||
|
||||
Pass this class to `db.repositories` to apply it on query
|
||||
"""
|
||||
_auxiliary_arguments = ('limit', 'marker',
|
||||
'sort', 'sort_key', 'sort_dir',
|
||||
'fields', 'page_reverse',
|
||||
)
|
||||
|
||||
def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR):
|
||||
"""Pagination Helper takes params and a default sort direction
|
||||
@ -173,7 +178,7 @@ class PaginationHelper(object):
|
||||
links = [types.PageType(**link) for link in links]
|
||||
return links
|
||||
|
||||
def apply(self, query, model):
|
||||
def apply(self, query, model, enforce_valid_params=True):
|
||||
"""Returns a query with sorting / pagination criteria added.
|
||||
|
||||
Pagination works by requiring a unique sort_key specified by sort_keys.
|
||||
@ -191,6 +196,7 @@ class PaginationHelper(object):
|
||||
:param query: the query object to which we should add
|
||||
paging/sorting/filtering
|
||||
:param model: the ORM model class
|
||||
:param enforce_valid_params: check for invalid enteries in self.params
|
||||
|
||||
:rtype: sqlalchemy.orm.query.Query
|
||||
:returns: The query with sorting/pagination/filtering added.
|
||||
@ -198,18 +204,35 @@ class PaginationHelper(object):
|
||||
|
||||
# Add filtering
|
||||
if CONF.api_settings.allow_filtering:
|
||||
filter_attrs = [attr for attr in dir(
|
||||
model.__v2_wsme__
|
||||
) if not callable(
|
||||
getattr(model.__v2_wsme__, attr)
|
||||
) and not attr.startswith("_")]
|
||||
self.filters = {k: v for (k, v) in self.params.items()
|
||||
if k in filter_attrs}
|
||||
# Exclude (valid) arguments that are not used for data filtering
|
||||
filter_params = {k: v for k, v in self.params.items(
|
||||
) if k not in self._auxiliary_arguments}
|
||||
|
||||
secondary_query_filter = filter_params.pop(
|
||||
"project_id", None) if (model == models.Amphora) else None
|
||||
|
||||
# Tranlate arguments from API standard to data model's field name
|
||||
filter_params = (
|
||||
model.__v2_wsme__.translate_dict_keys_to_data_model(
|
||||
filter_params)
|
||||
)
|
||||
if 'loadbalancer_id' in filter_params:
|
||||
filter_params['load_balancer_id'] = filter_params.pop(
|
||||
'loadbalancer_id')
|
||||
|
||||
# Drop invalid arguments
|
||||
self.filters = {k: v for (k, v) in filter_params.items()
|
||||
if k in vars(model.__data_model__())}
|
||||
|
||||
if enforce_valid_params and (
|
||||
len(self.filters) < len(filter_params)
|
||||
):
|
||||
raise exceptions.InvalidFilterArgument()
|
||||
|
||||
query = model.apply_filter(query, model, self.filters)
|
||||
if model.__name__ == "Amphora" and 'project_id' in self.params:
|
||||
if secondary_query_filter is not None:
|
||||
query = query.filter(model.load_balancer.has(
|
||||
project_id=self.params['project_id']))
|
||||
project_id=secondary_query_filter))
|
||||
|
||||
# Add sorting
|
||||
if CONF.api_settings.allow_sorting:
|
||||
|
@ -106,6 +106,24 @@ class BaseType(wtypes.Base):
|
||||
del new_dict[key]
|
||||
return cls(**new_dict)
|
||||
|
||||
@classmethod
|
||||
def translate_dict_keys_to_data_model(cls, wsme_dict):
|
||||
"""Translate the keys from wsme class type, to data_model."""
|
||||
if not hasattr(cls, '_type_to_model_map'):
|
||||
return wsme_dict
|
||||
res = {}
|
||||
for (k, v) in wsme_dict.items():
|
||||
if k in cls._type_to_model_map:
|
||||
k = cls._type_to_model_map[k]
|
||||
if '.' in k:
|
||||
parent, child = k.split('.')
|
||||
if parent not in res:
|
||||
res[parent] = {}
|
||||
res[parent][child] = v
|
||||
continue
|
||||
res[k] = v
|
||||
return res
|
||||
|
||||
def to_dict(self, render_unsets=False):
|
||||
"""Converts Octavia WSME type to dictionary.
|
||||
|
||||
@ -119,7 +137,7 @@ class BaseType(wtypes.Base):
|
||||
if (isinstance(self.project_id, wtypes.UnsetType) and
|
||||
not isinstance(self.tenant_id, wtypes.UnsetType)):
|
||||
self.project_id = self.tenant_id
|
||||
ret_dict = {}
|
||||
wsme_dict = {}
|
||||
for attr in dir(self):
|
||||
if attr.startswith('_'):
|
||||
continue
|
||||
@ -143,19 +161,8 @@ class BaseType(wtypes.Base):
|
||||
value = None
|
||||
else:
|
||||
continue
|
||||
attr_name = attr
|
||||
if (hasattr(self, '_type_to_model_map') and
|
||||
attr in self._type_to_model_map):
|
||||
renamed = self._type_to_model_map[attr]
|
||||
if '.' in renamed:
|
||||
parent, child = renamed.split('.')
|
||||
if parent not in ret_dict:
|
||||
ret_dict[parent] = {}
|
||||
ret_dict[parent][child] = value
|
||||
continue
|
||||
attr_name = renamed
|
||||
ret_dict[attr_name] = value
|
||||
return ret_dict
|
||||
wsme_dict[attr] = value
|
||||
return self.translate_dict_keys_to_data_model(wsme_dict)
|
||||
|
||||
|
||||
class IdOnlyType(BaseType):
|
||||
|
@ -99,7 +99,8 @@ MUTABLE_STATUSES = (ACTIVE,)
|
||||
DELETABLE_STATUSES = (ACTIVE, ERROR)
|
||||
|
||||
SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR,
|
||||
AMPHORA_READY, DELETED, PENDING_DELETE)
|
||||
AMPHORA_READY, DELETED,
|
||||
PENDING_CREATE, PENDING_DELETE)
|
||||
|
||||
ONLINE = 'ONLINE'
|
||||
OFFLINE = 'OFFLINE'
|
||||
|
@ -86,6 +86,11 @@ class InvalidOption(APIException):
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidFilterArgument(APIException):
|
||||
msg = "One or more arguments are either duplicate or invalid"
|
||||
code = 400
|
||||
|
||||
|
||||
class DisabledOption(APIException):
|
||||
msg = _("The selected %(option)s is not allowed in this deployment: "
|
||||
"%(value)s")
|
||||
|
@ -209,13 +209,16 @@ class TestAmphora(base.BaseAPITest):
|
||||
self.AMPHORAE_PATH,
|
||||
params={'project_id': self.project_id}
|
||||
).json.get(self.root_tag_list)
|
||||
|
||||
self.assertEqual(1, len(amps))
|
||||
|
||||
false_project_id = uuidutils.generate_uuid()
|
||||
amps = self.get(
|
||||
self.AMPHORAE_PATH,
|
||||
params={'project_id': uuidutils.generate_uuid()}
|
||||
params={'project_id': false_project_id}
|
||||
).json.get(self.root_tag_list)
|
||||
self.assertEqual(0, len(amps))
|
||||
|
||||
self.assertEqual(int(false_project_id == self.project_id),
|
||||
len(amps))
|
||||
|
||||
def test_get_all_sorted(self):
|
||||
self._create_additional_amp()
|
||||
|
@ -95,6 +95,17 @@ class TestTypeDataModelRenames(base.TestCase):
|
||||
self.assertEqual(new_type.child_one, child_dict.get('one'))
|
||||
self.assertEqual(new_type.child_two, child_dict.get('two'))
|
||||
|
||||
def test_translate_dict_keys_to_data_model(self):
|
||||
new_type = TestTypeRename.from_data_model(self.model)
|
||||
new_type_vars = {
|
||||
k: getattr(new_type, k) for k in dir(new_type) if not (
|
||||
callable(getattr(new_type, k)) or k.startswith('_'))
|
||||
}
|
||||
self.assertEqual(
|
||||
set(vars(self.model)),
|
||||
set(new_type.translate_dict_keys_to_data_model(new_type_vars)),
|
||||
)
|
||||
|
||||
def test_type_to_dict_with_tenant_id(self):
|
||||
type_dict = TestTypeTenantProject(tenant_id='1234').to_dict()
|
||||
self.assertEqual('1234', type_dict['project_id'])
|
||||
|
@ -105,13 +105,57 @@ class TestPaginationHelper(base.TestCase):
|
||||
|
||||
@mock.patch('octavia.api.common.pagination.request')
|
||||
def test_filter_mismatched_params(self, request_mock):
|
||||
params = {'id': 'fake_id', 'fields': 'id'}
|
||||
params = {
|
||||
'id': 'fake_id',
|
||||
'fields': 'field',
|
||||
'limit': '10',
|
||||
'sort': None,
|
||||
}
|
||||
|
||||
filters = {'id': 'fake_id'}
|
||||
|
||||
helper = pagination.PaginationHelper(params)
|
||||
query_mock = mock.MagicMock()
|
||||
|
||||
helper.apply(query_mock, models.LoadBalancer)
|
||||
self.assertEqual(filters, helper.filters)
|
||||
helper.apply(query_mock, models.LoadBalancer,
|
||||
enforce_valid_params=True)
|
||||
self.assertEqual(filters, helper.filters)
|
||||
|
||||
@mock.patch('octavia.api.common.pagination.request')
|
||||
def test_filter_with_invalid_params(self, request_mock):
|
||||
params = {'id': 'fake_id', 'no_such_param': 'id'}
|
||||
filters = {'id': 'fake_id'}
|
||||
helper = pagination.PaginationHelper(params)
|
||||
query_mock = mock.MagicMock()
|
||||
|
||||
helper.apply(query_mock, models.LoadBalancer,
|
||||
# silently ignore invalid parameter
|
||||
enforce_valid_params=False)
|
||||
self.assertEqual(filters, helper.filters)
|
||||
|
||||
self.assertRaises(
|
||||
exceptions.InvalidFilterArgument,
|
||||
pagination.PaginationHelper.apply,
|
||||
helper,
|
||||
query_mock,
|
||||
models.Amphora,
|
||||
)
|
||||
|
||||
@mock.patch('octavia.api.common.pagination.request')
|
||||
def test_duplicate_argument(self, request_mock):
|
||||
params = {'loadbalacer_id': 'id1', 'load_balacer_id': 'id2'}
|
||||
query_mock = mock.MagicMock()
|
||||
helper = pagination.PaginationHelper(params)
|
||||
|
||||
self.assertRaises(
|
||||
exceptions.InvalidFilterArgument,
|
||||
pagination.PaginationHelper.apply,
|
||||
helper,
|
||||
query_mock,
|
||||
models.Amphora,
|
||||
)
|
||||
|
||||
@mock.patch('octavia.api.common.pagination.request')
|
||||
def test_fields_not_passed(self, request_mock):
|
||||
|
Loading…
Reference in New Issue
Block a user