Merge "Fix filtering in list API calls"
This commit is contained in:
commit
82d9f44c62
@ -146,7 +146,7 @@ amphora-role:
|
|||||||
amphora-status:
|
amphora-status:
|
||||||
description: |
|
description: |
|
||||||
The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``,
|
The status of the amphora. One of: ``BOOTING``, ``ALLOCATED``, ``READY``,
|
||||||
``PENDING_DELETE``, ``DELETED``, ``ERROR``.
|
``PENDING_CREATE``, ``PENDING_DELETE``, ``DELETED``, ``ERROR``.
|
||||||
in: body
|
in: body
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
|
@ -23,6 +23,7 @@ from octavia.api.common import types
|
|||||||
from octavia.common.config import cfg
|
from octavia.common.config import cfg
|
||||||
from octavia.common import constants
|
from octavia.common import constants
|
||||||
from octavia.common import exceptions
|
from octavia.common import exceptions
|
||||||
|
from octavia.db import models
|
||||||
|
|
||||||
CONF = cfg.CONF
|
CONF = cfg.CONF
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -33,6 +34,10 @@ class PaginationHelper(object):
|
|||||||
|
|
||||||
Pass this class to `db.repositories` to apply it on query
|
Pass this class to `db.repositories` to apply it on query
|
||||||
"""
|
"""
|
||||||
|
_auxiliary_arguments = ('limit', 'marker',
|
||||||
|
'sort', 'sort_key', 'sort_dir',
|
||||||
|
'fields', 'page_reverse',
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR):
|
def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR):
|
||||||
"""Pagination Helper takes params and a default sort direction
|
"""Pagination Helper takes params and a default sort direction
|
||||||
@ -173,7 +178,7 @@ class PaginationHelper(object):
|
|||||||
links = [types.PageType(**link) for link in links]
|
links = [types.PageType(**link) for link in links]
|
||||||
return links
|
return links
|
||||||
|
|
||||||
def apply(self, query, model):
|
def apply(self, query, model, enforce_valid_params=True):
|
||||||
"""Returns a query with sorting / pagination criteria added.
|
"""Returns a query with sorting / pagination criteria added.
|
||||||
|
|
||||||
Pagination works by requiring a unique sort_key specified by sort_keys.
|
Pagination works by requiring a unique sort_key specified by sort_keys.
|
||||||
@ -191,6 +196,7 @@ class PaginationHelper(object):
|
|||||||
:param query: the query object to which we should add
|
:param query: the query object to which we should add
|
||||||
paging/sorting/filtering
|
paging/sorting/filtering
|
||||||
:param model: the ORM model class
|
:param model: the ORM model class
|
||||||
|
:param enforce_valid_params: check for invalid enteries in self.params
|
||||||
|
|
||||||
:rtype: sqlalchemy.orm.query.Query
|
:rtype: sqlalchemy.orm.query.Query
|
||||||
:returns: The query with sorting/pagination/filtering added.
|
:returns: The query with sorting/pagination/filtering added.
|
||||||
@ -198,18 +204,35 @@ class PaginationHelper(object):
|
|||||||
|
|
||||||
# Add filtering
|
# Add filtering
|
||||||
if CONF.api_settings.allow_filtering:
|
if CONF.api_settings.allow_filtering:
|
||||||
filter_attrs = [attr for attr in dir(
|
# Exclude (valid) arguments that are not used for data filtering
|
||||||
model.__v2_wsme__
|
filter_params = {k: v for k, v in self.params.items(
|
||||||
) if not callable(
|
) if k not in self._auxiliary_arguments}
|
||||||
getattr(model.__v2_wsme__, attr)
|
|
||||||
) and not attr.startswith("_")]
|
secondary_query_filter = filter_params.pop(
|
||||||
self.filters = {k: v for (k, v) in self.params.items()
|
"project_id", None) if (model == models.Amphora) else None
|
||||||
if k in filter_attrs}
|
|
||||||
|
# Tranlate arguments from API standard to data model's field name
|
||||||
|
filter_params = (
|
||||||
|
model.__v2_wsme__.translate_dict_keys_to_data_model(
|
||||||
|
filter_params)
|
||||||
|
)
|
||||||
|
if 'loadbalancer_id' in filter_params:
|
||||||
|
filter_params['load_balancer_id'] = filter_params.pop(
|
||||||
|
'loadbalancer_id')
|
||||||
|
|
||||||
|
# Drop invalid arguments
|
||||||
|
self.filters = {k: v for (k, v) in filter_params.items()
|
||||||
|
if k in vars(model.__data_model__())}
|
||||||
|
|
||||||
|
if enforce_valid_params and (
|
||||||
|
len(self.filters) < len(filter_params)
|
||||||
|
):
|
||||||
|
raise exceptions.InvalidFilterArgument()
|
||||||
|
|
||||||
query = model.apply_filter(query, model, self.filters)
|
query = model.apply_filter(query, model, self.filters)
|
||||||
if model.__name__ == "Amphora" and 'project_id' in self.params:
|
if secondary_query_filter is not None:
|
||||||
query = query.filter(model.load_balancer.has(
|
query = query.filter(model.load_balancer.has(
|
||||||
project_id=self.params['project_id']))
|
project_id=secondary_query_filter))
|
||||||
|
|
||||||
# Add sorting
|
# Add sorting
|
||||||
if CONF.api_settings.allow_sorting:
|
if CONF.api_settings.allow_sorting:
|
||||||
|
@ -106,6 +106,24 @@ class BaseType(wtypes.Base):
|
|||||||
del new_dict[key]
|
del new_dict[key]
|
||||||
return cls(**new_dict)
|
return cls(**new_dict)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def translate_dict_keys_to_data_model(cls, wsme_dict):
|
||||||
|
"""Translate the keys from wsme class type, to data_model."""
|
||||||
|
if not hasattr(cls, '_type_to_model_map'):
|
||||||
|
return wsme_dict
|
||||||
|
res = {}
|
||||||
|
for (k, v) in wsme_dict.items():
|
||||||
|
if k in cls._type_to_model_map:
|
||||||
|
k = cls._type_to_model_map[k]
|
||||||
|
if '.' in k:
|
||||||
|
parent, child = k.split('.')
|
||||||
|
if parent not in res:
|
||||||
|
res[parent] = {}
|
||||||
|
res[parent][child] = v
|
||||||
|
continue
|
||||||
|
res[k] = v
|
||||||
|
return res
|
||||||
|
|
||||||
def to_dict(self, render_unsets=False):
|
def to_dict(self, render_unsets=False):
|
||||||
"""Converts Octavia WSME type to dictionary.
|
"""Converts Octavia WSME type to dictionary.
|
||||||
|
|
||||||
@ -119,7 +137,7 @@ class BaseType(wtypes.Base):
|
|||||||
if (isinstance(self.project_id, wtypes.UnsetType) and
|
if (isinstance(self.project_id, wtypes.UnsetType) and
|
||||||
not isinstance(self.tenant_id, wtypes.UnsetType)):
|
not isinstance(self.tenant_id, wtypes.UnsetType)):
|
||||||
self.project_id = self.tenant_id
|
self.project_id = self.tenant_id
|
||||||
ret_dict = {}
|
wsme_dict = {}
|
||||||
for attr in dir(self):
|
for attr in dir(self):
|
||||||
if attr.startswith('_'):
|
if attr.startswith('_'):
|
||||||
continue
|
continue
|
||||||
@ -143,19 +161,8 @@ class BaseType(wtypes.Base):
|
|||||||
value = None
|
value = None
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
attr_name = attr
|
wsme_dict[attr] = value
|
||||||
if (hasattr(self, '_type_to_model_map') and
|
return self.translate_dict_keys_to_data_model(wsme_dict)
|
||||||
attr in self._type_to_model_map):
|
|
||||||
renamed = self._type_to_model_map[attr]
|
|
||||||
if '.' in renamed:
|
|
||||||
parent, child = renamed.split('.')
|
|
||||||
if parent not in ret_dict:
|
|
||||||
ret_dict[parent] = {}
|
|
||||||
ret_dict[parent][child] = value
|
|
||||||
continue
|
|
||||||
attr_name = renamed
|
|
||||||
ret_dict[attr_name] = value
|
|
||||||
return ret_dict
|
|
||||||
|
|
||||||
|
|
||||||
class IdOnlyType(BaseType):
|
class IdOnlyType(BaseType):
|
||||||
|
@ -99,7 +99,8 @@ MUTABLE_STATUSES = (ACTIVE,)
|
|||||||
DELETABLE_STATUSES = (ACTIVE, ERROR)
|
DELETABLE_STATUSES = (ACTIVE, ERROR)
|
||||||
|
|
||||||
SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR,
|
SUPPORTED_AMPHORA_STATUSES = (AMPHORA_ALLOCATED, AMPHORA_BOOTING, ERROR,
|
||||||
AMPHORA_READY, DELETED, PENDING_DELETE)
|
AMPHORA_READY, DELETED,
|
||||||
|
PENDING_CREATE, PENDING_DELETE)
|
||||||
|
|
||||||
ONLINE = 'ONLINE'
|
ONLINE = 'ONLINE'
|
||||||
OFFLINE = 'OFFLINE'
|
OFFLINE = 'OFFLINE'
|
||||||
|
@ -86,6 +86,11 @@ class InvalidOption(APIException):
|
|||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFilterArgument(APIException):
|
||||||
|
msg = "One or more arguments are either duplicate or invalid"
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class DisabledOption(APIException):
|
class DisabledOption(APIException):
|
||||||
msg = _("The selected %(option)s is not allowed in this deployment: "
|
msg = _("The selected %(option)s is not allowed in this deployment: "
|
||||||
"%(value)s")
|
"%(value)s")
|
||||||
|
@ -209,13 +209,16 @@ class TestAmphora(base.BaseAPITest):
|
|||||||
self.AMPHORAE_PATH,
|
self.AMPHORAE_PATH,
|
||||||
params={'project_id': self.project_id}
|
params={'project_id': self.project_id}
|
||||||
).json.get(self.root_tag_list)
|
).json.get(self.root_tag_list)
|
||||||
|
|
||||||
self.assertEqual(1, len(amps))
|
self.assertEqual(1, len(amps))
|
||||||
|
|
||||||
|
false_project_id = uuidutils.generate_uuid()
|
||||||
amps = self.get(
|
amps = self.get(
|
||||||
self.AMPHORAE_PATH,
|
self.AMPHORAE_PATH,
|
||||||
params={'project_id': uuidutils.generate_uuid()}
|
params={'project_id': false_project_id}
|
||||||
).json.get(self.root_tag_list)
|
).json.get(self.root_tag_list)
|
||||||
self.assertEqual(0, len(amps))
|
|
||||||
|
self.assertEqual(int(false_project_id == self.project_id),
|
||||||
|
len(amps))
|
||||||
|
|
||||||
def test_get_all_sorted(self):
|
def test_get_all_sorted(self):
|
||||||
self._create_additional_amp()
|
self._create_additional_amp()
|
||||||
|
@ -95,6 +95,17 @@ class TestTypeDataModelRenames(base.TestCase):
|
|||||||
self.assertEqual(new_type.child_one, child_dict.get('one'))
|
self.assertEqual(new_type.child_one, child_dict.get('one'))
|
||||||
self.assertEqual(new_type.child_two, child_dict.get('two'))
|
self.assertEqual(new_type.child_two, child_dict.get('two'))
|
||||||
|
|
||||||
|
def test_translate_dict_keys_to_data_model(self):
|
||||||
|
new_type = TestTypeRename.from_data_model(self.model)
|
||||||
|
new_type_vars = {
|
||||||
|
k: getattr(new_type, k) for k in dir(new_type) if not (
|
||||||
|
callable(getattr(new_type, k)) or k.startswith('_'))
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
set(vars(self.model)),
|
||||||
|
set(new_type.translate_dict_keys_to_data_model(new_type_vars)),
|
||||||
|
)
|
||||||
|
|
||||||
def test_type_to_dict_with_tenant_id(self):
|
def test_type_to_dict_with_tenant_id(self):
|
||||||
type_dict = TestTypeTenantProject(tenant_id='1234').to_dict()
|
type_dict = TestTypeTenantProject(tenant_id='1234').to_dict()
|
||||||
self.assertEqual('1234', type_dict['project_id'])
|
self.assertEqual('1234', type_dict['project_id'])
|
||||||
|
@ -105,13 +105,57 @@ class TestPaginationHelper(base.TestCase):
|
|||||||
|
|
||||||
@mock.patch('octavia.api.common.pagination.request')
|
@mock.patch('octavia.api.common.pagination.request')
|
||||||
def test_filter_mismatched_params(self, request_mock):
|
def test_filter_mismatched_params(self, request_mock):
|
||||||
params = {'id': 'fake_id', 'fields': 'id'}
|
params = {
|
||||||
|
'id': 'fake_id',
|
||||||
|
'fields': 'field',
|
||||||
|
'limit': '10',
|
||||||
|
'sort': None,
|
||||||
|
}
|
||||||
|
|
||||||
filters = {'id': 'fake_id'}
|
filters = {'id': 'fake_id'}
|
||||||
|
|
||||||
helper = pagination.PaginationHelper(params)
|
helper = pagination.PaginationHelper(params)
|
||||||
query_mock = mock.MagicMock()
|
query_mock = mock.MagicMock()
|
||||||
|
|
||||||
helper.apply(query_mock, models.LoadBalancer)
|
helper.apply(query_mock, models.LoadBalancer)
|
||||||
self.assertEqual(filters, helper.filters)
|
self.assertEqual(filters, helper.filters)
|
||||||
|
helper.apply(query_mock, models.LoadBalancer,
|
||||||
|
enforce_valid_params=True)
|
||||||
|
self.assertEqual(filters, helper.filters)
|
||||||
|
|
||||||
|
@mock.patch('octavia.api.common.pagination.request')
|
||||||
|
def test_filter_with_invalid_params(self, request_mock):
|
||||||
|
params = {'id': 'fake_id', 'no_such_param': 'id'}
|
||||||
|
filters = {'id': 'fake_id'}
|
||||||
|
helper = pagination.PaginationHelper(params)
|
||||||
|
query_mock = mock.MagicMock()
|
||||||
|
|
||||||
|
helper.apply(query_mock, models.LoadBalancer,
|
||||||
|
# silently ignore invalid parameter
|
||||||
|
enforce_valid_params=False)
|
||||||
|
self.assertEqual(filters, helper.filters)
|
||||||
|
|
||||||
|
self.assertRaises(
|
||||||
|
exceptions.InvalidFilterArgument,
|
||||||
|
pagination.PaginationHelper.apply,
|
||||||
|
helper,
|
||||||
|
query_mock,
|
||||||
|
models.Amphora,
|
||||||
|
)
|
||||||
|
|
||||||
|
@mock.patch('octavia.api.common.pagination.request')
|
||||||
|
def test_duplicate_argument(self, request_mock):
|
||||||
|
params = {'loadbalacer_id': 'id1', 'load_balacer_id': 'id2'}
|
||||||
|
query_mock = mock.MagicMock()
|
||||||
|
helper = pagination.PaginationHelper(params)
|
||||||
|
|
||||||
|
self.assertRaises(
|
||||||
|
exceptions.InvalidFilterArgument,
|
||||||
|
pagination.PaginationHelper.apply,
|
||||||
|
helper,
|
||||||
|
query_mock,
|
||||||
|
models.Amphora,
|
||||||
|
)
|
||||||
|
|
||||||
@mock.patch('octavia.api.common.pagination.request')
|
@mock.patch('octavia.api.common.pagination.request')
|
||||||
def test_fields_not_passed(self, request_mock):
|
def test_fields_not_passed(self, request_mock):
|
||||||
|
Loading…
Reference in New Issue
Block a user