Merge "Fix filtering in list API calls"

This commit is contained in:
Zuul 2017-12-06 01:21:07 +00:00 committed by Gerrit Code Review
commit 82d9f44c62
8 changed files with 124 additions and 30 deletions

View File

@ -146,7 +146,7 @@ amphora-role:
amphora-status:
description: |
The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``,
``PENDING_DELETE``, ``DELETED``, ``ERROR``.
``PENDING_CREATE``, ``PENDING_DELETE``, ``DELETED``, ``ERROR``.
in: body
required: true
type: string

View File

@ -23,6 +23,7 @@ from octavia.api.common import types
from octavia.common.config import cfg
from octavia.common import constants
from octavia.common import exceptions
from octavia.db import models
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
@ -33,6 +34,10 @@ class PaginationHelper(object):
Pass this class to `db.repositories` to apply it on query
"""
_auxiliary_arguments = ('limit', 'marker',
'sort', 'sort_key', 'sort_dir',
'fields', 'page_reverse',
)
def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR):
"""Pagination Helper takes params and a default sort direction
@ -173,7 +178,7 @@ class PaginationHelper(object):
links = [types.PageType(**link) for link in links]
return links
def apply(self, query, model):
def apply(self, query, model, enforce_valid_params=True):
"""Returns a query with sorting / pagination criteria added.
Pagination works by requiring a unique sort_key specified by sort_keys.
@ -191,6 +196,7 @@ class PaginationHelper(object):
:param query: the query object to which we should add
paging/sorting/filtering
:param model: the ORM model class
:param enforce_valid_params: check for invalid enteries in self.params
:rtype: sqlalchemy.orm.query.Query
:returns: The query with sorting/pagination/filtering added.
@ -198,18 +204,35 @@ class PaginationHelper(object):
# Add filtering
if CONF.api_settings.allow_filtering:
filter_attrs = [attr for attr in dir(
model.__v2_wsme__
) if not callable(
getattr(model.__v2_wsme__, attr)
) and not attr.startswith("_")]
self.filters = {k: v for (k, v) in self.params.items()
if k in filter_attrs}
# Exclude (valid) arguments that are not used for data filtering
filter_params = {k: v for k, v in self.params.items(
) if k not in self._auxiliary_arguments}
secondary_query_filter = filter_params.pop(
"project_id", None) if (model == models.Amphora) else None
# Tranlate arguments from API standard to data model's field name
filter_params = (
model.__v2_wsme__.translate_dict_keys_to_data_model(
filter_params)
)
if 'loadbalancer_id' in filter_params:
filter_params['load_balancer_id'] = filter_params.pop(
'loadbalancer_id')
# Drop invalid arguments
self.filters = {k: v for (k, v) in filter_params.items()
if k in vars(model.__data_model__())}
if enforce_valid_params and (
len(self.filters) < len(filter_params)
):
raise exceptions.InvalidFilterArgument()
query = model.apply_filter(query, model, self.filters)
if model.__name__ == "Amphora" and 'project_id' in self.params:
if secondary_query_filter is not None:
query = query.filter(model.load_balancer.has(
project_id=self.params['project_id']))
project_id=secondary_query_filter))
# Add sorting
if CONF.api_settings.allow_sorting:

View File

@ -106,6 +106,24 @@ class BaseType(wtypes.Base):
del new_dict[key]
return cls(**new_dict)
@classmethod
def translate_dict_keys_to_data_model(cls, wsme_dict):
"""Translate the keys from wsme class type, to data_model."""
if not hasattr(cls, '_type_to_model_map'):
return wsme_dict
res = {}
for (k, v) in wsme_dict.items():
if k in cls._type_to_model_map:
k = cls._type_to_model_map[k]
if '.' in k:
parent, child = k.split('.')
if parent not in res:
res[parent] = {}
res[parent][child] = v
continue
res[k] = v
return res
def to_dict(self, render_unsets=False):
"""Converts Octavia WSME type to dictionary.
@ -119,7 +137,7 @@ class BaseType(wtypes.Base):
if (isinstance(self.project_id, wtypes.UnsetType) and
not isinstance(self.tenant_id, wtypes.UnsetType)):
self.project_id = self.tenant_id
ret_dict = {}
wsme_dict = {}
for attr in dir(self):
if attr.startswith('_'):
continue
@ -143,19 +161,8 @@ class BaseType(wtypes.Base):
value = None
else:
continue
attr_name = attr
if (hasattr(self, '_type_to_model_map') and
attr in self._type_to_model_map):
renamed = self._type_to_model_map[attr]
if '.' in renamed:
parent, child = renamed.split('.')
if parent not in ret_dict:
ret_dict[parent] = {}
ret_dict[parent][child] = value
continue
attr_name = renamed
ret_dict[attr_name] = value
return ret_dict
wsme_dict[attr] = value
return self.translate_dict_keys_to_data_model(wsme_dict)
class IdOnlyType(BaseType):

View File

@ -99,7 +99,8 @@ MUTABLE_STATUSES = (ACTIVE,)
DELETABLE_STATUSES = (ACTIVE, ERROR)
SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR,
AMPHORA_READY, DELETED, PENDING_DELETE)
AMPHORA_READY, DELETED,
PENDING_CREATE, PENDING_DELETE)
ONLINE = 'ONLINE'
OFFLINE = 'OFFLINE'

View File

@ -86,6 +86,11 @@ class InvalidOption(APIException):
code = 400
class InvalidFilterArgument(APIException):
msg = "One or more arguments are either duplicate or invalid"
code = 400
class DisabledOption(APIException):
msg = _("The selected %(option)s is not allowed in this deployment: "
"%(value)s")

View File

@ -209,13 +209,16 @@ class TestAmphora(base.BaseAPITest):
self.AMPHORAE_PATH,
params={'project_id': self.project_id}
).json.get(self.root_tag_list)
self.assertEqual(1, len(amps))
false_project_id = uuidutils.generate_uuid()
amps = self.get(
self.AMPHORAE_PATH,
params={'project_id': uuidutils.generate_uuid()}
params={'project_id': false_project_id}
).json.get(self.root_tag_list)
self.assertEqual(0, len(amps))
self.assertEqual(int(false_project_id == self.project_id),
len(amps))
def test_get_all_sorted(self):
self._create_additional_amp()

View File

@ -95,6 +95,17 @@ class TestTypeDataModelRenames(base.TestCase):
self.assertEqual(new_type.child_one, child_dict.get('one'))
self.assertEqual(new_type.child_two, child_dict.get('two'))
def test_translate_dict_keys_to_data_model(self):
new_type = TestTypeRename.from_data_model(self.model)
new_type_vars = {
k: getattr(new_type, k) for k in dir(new_type) if not (
callable(getattr(new_type, k)) or k.startswith('_'))
}
self.assertEqual(
set(vars(self.model)),
set(new_type.translate_dict_keys_to_data_model(new_type_vars)),
)
def test_type_to_dict_with_tenant_id(self):
type_dict = TestTypeTenantProject(tenant_id='1234').to_dict()
self.assertEqual('1234', type_dict['project_id'])

View File

@ -105,13 +105,57 @@ class TestPaginationHelper(base.TestCase):
@mock.patch('octavia.api.common.pagination.request')
def test_filter_mismatched_params(self, request_mock):
params = {'id': 'fake_id', 'fields': 'id'}
params = {
'id': 'fake_id',
'fields': 'field',
'limit': '10',
'sort': None,
}
filters = {'id': 'fake_id'}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer)
self.assertEqual(filters, helper.filters)
helper.apply(query_mock, models.LoadBalancer,
enforce_valid_params=True)
self.assertEqual(filters, helper.filters)
@mock.patch('octavia.api.common.pagination.request')
def test_filter_with_invalid_params(self, request_mock):
params = {'id': 'fake_id', 'no_such_param': 'id'}
filters = {'id': 'fake_id'}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer,
# silently ignore invalid parameter
enforce_valid_params=False)
self.assertEqual(filters, helper.filters)
self.assertRaises(
exceptions.InvalidFilterArgument,
pagination.PaginationHelper.apply,
helper,
query_mock,
models.Amphora,
)
@mock.patch('octavia.api.common.pagination.request')
def test_duplicate_argument(self, request_mock):
params = {'loadbalacer_id': 'id1', 'load_balacer_id': 'id2'}
query_mock = mock.MagicMock()
helper = pagination.PaginationHelper(params)
self.assertRaises(
exceptions.InvalidFilterArgument,
pagination.PaginationHelper.apply,
helper,
query_mock,
models.Amphora,
)
@mock.patch('octavia.api.common.pagination.request')
def test_fields_not_passed(self, request_mock):