Merge "Fix filtering in list API calls"

This commit is contained in:
Zuul 2017-12-06 01:21:07 +00:00 committed by Gerrit Code Review
commit 82d9f44c62
8 changed files with 124 additions and 30 deletions

View File

@ -146,7 +146,7 @@ amphora-role:
amphora-status: amphora-status:
description: | description: |
The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``, The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``,
``PENDING_DELETE``, ``DELETED``, ``ERROR``. ``PENDING_CREATE``, ``PENDING_DELETE``, ``DELETED``, ``ERROR``.
in: body in: body
required: true required: true
type: string type: string

View File

@ -23,6 +23,7 @@ from octavia.api.common import types
from octavia.common.config import cfg from octavia.common.config import cfg
from octavia.common import constants from octavia.common import constants
from octavia.common import exceptions from octavia.common import exceptions
from octavia.db import models
CONF = cfg.CONF CONF = cfg.CONF
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -33,6 +34,10 @@ class PaginationHelper(object):
Pass this class to `db.repositories` to apply it on query Pass this class to `db.repositories` to apply it on query
""" """
_auxiliary_arguments = ('limit', 'marker',
'sort', 'sort_key', 'sort_dir',
'fields', 'page_reverse',
)
def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR): def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR):
"""Pagination Helper takes params and a default sort direction """Pagination Helper takes params and a default sort direction
@ -173,7 +178,7 @@ class PaginationHelper(object):
links = [types.PageType(**link) for link in links] links = [types.PageType(**link) for link in links]
return links return links
def apply(self, query, model): def apply(self, query, model, enforce_valid_params=True):
"""Returns a query with sorting / pagination criteria added. """Returns a query with sorting / pagination criteria added.
Pagination works by requiring a unique sort_key specified by sort_keys. Pagination works by requiring a unique sort_key specified by sort_keys.
@ -191,6 +196,7 @@ class PaginationHelper(object):
:param query: the query object to which we should add :param query: the query object to which we should add
paging/sorting/filtering paging/sorting/filtering
:param model: the ORM model class :param model: the ORM model class
:param enforce_valid_params: check for invalid enteries in self.params
:rtype: sqlalchemy.orm.query.Query :rtype: sqlalchemy.orm.query.Query
:returns: The query with sorting/pagination/filtering added. :returns: The query with sorting/pagination/filtering added.
@ -198,18 +204,35 @@ class PaginationHelper(object):
# Add filtering # Add filtering
if CONF.api_settings.allow_filtering: if CONF.api_settings.allow_filtering:
filter_attrs = [attr for attr in dir( # Exclude (valid) arguments that are not used for data filtering
model.__v2_wsme__ filter_params = {k: v for k, v in self.params.items(
) if not callable( ) if k not in self._auxiliary_arguments}
getattr(model.__v2_wsme__, attr)
) and not attr.startswith("_")] secondary_query_filter = filter_params.pop(
self.filters = {k: v for (k, v) in self.params.items() "project_id", None) if (model == models.Amphora) else None
if k in filter_attrs}
# Tranlate arguments from API standard to data model's field name
filter_params = (
model.__v2_wsme__.translate_dict_keys_to_data_model(
filter_params)
)
if 'loadbalancer_id' in filter_params:
filter_params['load_balancer_id'] = filter_params.pop(
'loadbalancer_id')
# Drop invalid arguments
self.filters = {k: v for (k, v) in filter_params.items()
if k in vars(model.__data_model__())}
if enforce_valid_params and (
len(self.filters) < len(filter_params)
):
raise exceptions.InvalidFilterArgument()
query = model.apply_filter(query, model, self.filters) query = model.apply_filter(query, model, self.filters)
if model.__name__ == "Amphora" and 'project_id' in self.params: if secondary_query_filter is not None:
query = query.filter(model.load_balancer.has( query = query.filter(model.load_balancer.has(
project_id=self.params['project_id'])) project_id=secondary_query_filter))
# Add sorting # Add sorting
if CONF.api_settings.allow_sorting: if CONF.api_settings.allow_sorting:

View File

@ -106,6 +106,24 @@ class BaseType(wtypes.Base):
del new_dict[key] del new_dict[key]
return cls(**new_dict) return cls(**new_dict)
@classmethod
def translate_dict_keys_to_data_model(cls, wsme_dict):
"""Translate the keys from wsme class type, to data_model."""
if not hasattr(cls, '_type_to_model_map'):
return wsme_dict
res = {}
for (k, v) in wsme_dict.items():
if k in cls._type_to_model_map:
k = cls._type_to_model_map[k]
if '.' in k:
parent, child = k.split('.')
if parent not in res:
res[parent] = {}
res[parent][child] = v
continue
res[k] = v
return res
def to_dict(self, render_unsets=False): def to_dict(self, render_unsets=False):
"""Converts Octavia WSME type to dictionary. """Converts Octavia WSME type to dictionary.
@ -119,7 +137,7 @@ class BaseType(wtypes.Base):
if (isinstance(self.project_id, wtypes.UnsetType) and if (isinstance(self.project_id, wtypes.UnsetType) and
not isinstance(self.tenant_id, wtypes.UnsetType)): not isinstance(self.tenant_id, wtypes.UnsetType)):
self.project_id = self.tenant_id self.project_id = self.tenant_id
ret_dict = {} wsme_dict = {}
for attr in dir(self): for attr in dir(self):
if attr.startswith('_'): if attr.startswith('_'):
continue continue
@ -143,19 +161,8 @@ class BaseType(wtypes.Base):
value = None value = None
else: else:
continue continue
attr_name = attr wsme_dict[attr] = value
if (hasattr(self, '_type_to_model_map') and return self.translate_dict_keys_to_data_model(wsme_dict)
attr in self._type_to_model_map):
renamed = self._type_to_model_map[attr]
if '.' in renamed:
parent, child = renamed.split('.')
if parent not in ret_dict:
ret_dict[parent] = {}
ret_dict[parent][child] = value
continue
attr_name = renamed
ret_dict[attr_name] = value
return ret_dict
class IdOnlyType(BaseType): class IdOnlyType(BaseType):

View File

@ -99,7 +99,8 @@ MUTABLE_STATUSES = (ACTIVE,)
DELETABLE_STATUSES = (ACTIVE, ERROR) DELETABLE_STATUSES = (ACTIVE, ERROR)
SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR, SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR,
AMPHORA_READY, DELETED, PENDING_DELETE) AMPHORA_READY, DELETED,
PENDING_CREATE, PENDING_DELETE)
ONLINE = 'ONLINE' ONLINE = 'ONLINE'
OFFLINE = 'OFFLINE' OFFLINE = 'OFFLINE'

View File

@ -86,6 +86,11 @@ class InvalidOption(APIException):
code = 400 code = 400
class InvalidFilterArgument(APIException):
msg = "One or more arguments are either duplicate or invalid"
code = 400
class DisabledOption(APIException): class DisabledOption(APIException):
msg = _("The selected %(option)s is not allowed in this deployment: " msg = _("The selected %(option)s is not allowed in this deployment: "
"%(value)s") "%(value)s")

View File

@ -209,13 +209,16 @@ class TestAmphora(base.BaseAPITest):
self.AMPHORAE_PATH, self.AMPHORAE_PATH,
params={'project_id': self.project_id} params={'project_id': self.project_id}
).json.get(self.root_tag_list) ).json.get(self.root_tag_list)
self.assertEqual(1, len(amps)) self.assertEqual(1, len(amps))
false_project_id = uuidutils.generate_uuid()
amps = self.get( amps = self.get(
self.AMPHORAE_PATH, self.AMPHORAE_PATH,
params={'project_id': uuidutils.generate_uuid()} params={'project_id': false_project_id}
).json.get(self.root_tag_list) ).json.get(self.root_tag_list)
self.assertEqual(0, len(amps))
self.assertEqual(int(false_project_id == self.project_id),
len(amps))
def test_get_all_sorted(self): def test_get_all_sorted(self):
self._create_additional_amp() self._create_additional_amp()

View File

@ -95,6 +95,17 @@ class TestTypeDataModelRenames(base.TestCase):
self.assertEqual(new_type.child_one, child_dict.get('one')) self.assertEqual(new_type.child_one, child_dict.get('one'))
self.assertEqual(new_type.child_two, child_dict.get('two')) self.assertEqual(new_type.child_two, child_dict.get('two'))
def test_translate_dict_keys_to_data_model(self):
new_type = TestTypeRename.from_data_model(self.model)
new_type_vars = {
k: getattr(new_type, k) for k in dir(new_type) if not (
callable(getattr(new_type, k)) or k.startswith('_'))
}
self.assertEqual(
set(vars(self.model)),
set(new_type.translate_dict_keys_to_data_model(new_type_vars)),
)
def test_type_to_dict_with_tenant_id(self): def test_type_to_dict_with_tenant_id(self):
type_dict = TestTypeTenantProject(tenant_id='1234').to_dict() type_dict = TestTypeTenantProject(tenant_id='1234').to_dict()
self.assertEqual('1234', type_dict['project_id']) self.assertEqual('1234', type_dict['project_id'])

View File

@ -105,13 +105,57 @@ class TestPaginationHelper(base.TestCase):
@mock.patch('octavia.api.common.pagination.request') @mock.patch('octavia.api.common.pagination.request')
def test_filter_mismatched_params(self, request_mock): def test_filter_mismatched_params(self, request_mock):
params = {'id': 'fake_id', 'fields': 'id'} params = {
'id': 'fake_id',
'fields': 'field',
'limit': '10',
'sort': None,
}
filters = {'id': 'fake_id'} filters = {'id': 'fake_id'}
helper = pagination.PaginationHelper(params) helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock() query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer) helper.apply(query_mock, models.LoadBalancer)
self.assertEqual(filters, helper.filters) self.assertEqual(filters, helper.filters)
helper.apply(query_mock, models.LoadBalancer,
enforce_valid_params=True)
self.assertEqual(filters, helper.filters)
@mock.patch('octavia.api.common.pagination.request')
def test_filter_with_invalid_params(self, request_mock):
params = {'id': 'fake_id', 'no_such_param': 'id'}
filters = {'id': 'fake_id'}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer,
# silently ignore invalid parameter
enforce_valid_params=False)
self.assertEqual(filters, helper.filters)
self.assertRaises(
exceptions.InvalidFilterArgument,
pagination.PaginationHelper.apply,
helper,
query_mock,
models.Amphora,
)
@mock.patch('octavia.api.common.pagination.request')
def test_duplicate_argument(self, request_mock):
params = {'loadbalacer_id': 'id1', 'load_balacer_id': 'id2'}
query_mock = mock.MagicMock()
helper = pagination.PaginationHelper(params)
self.assertRaises(
exceptions.InvalidFilterArgument,
pagination.PaginationHelper.apply,
helper,
query_mock,
models.Amphora,
)
@mock.patch('octavia.api.common.pagination.request') @mock.patch('octavia.api.common.pagination.request')
def test_fields_not_passed(self, request_mock): def test_fields_not_passed(self, request_mock):