Add filtering and field selection to API

This patch implements API filtering based off of
query parameters passed to the Octavia API. Additonally
this patch implements field selection for the Octavia
API.

Change-Id: I9fe26abe37f464d9c028b8c476485007143d3b5c
This commit is contained in:
Jude Cross 2017-05-30 15:09:19 -07:00
parent e47c2cb584
commit 487750a877
28 changed files with 503 additions and 19 deletions

View File

@ -49,6 +49,7 @@ class PaginationHelper(object):
self.limit = self._parse_limit(params)
self.sort_keys = self._parse_sort_keys(params)
self.params = params
self.filters = None
@staticmethod
def _parse_limit(params):
@ -184,13 +185,26 @@ class PaginationHelper(object):
Typically, the id of the last row is used as the client-facing
pagination marker, then the actual marker object must be fetched from
the db and passed in to us as marker.
:param query: the query object to which we should add paging/sorting
:param query: the query object to which we should add
paging/sorting/filtering
:param model: the ORM model class
:rtype: sqlalchemy.orm.query.Query
:returns: The query with sorting/pagination added.
:returns: The query with sorting/pagination/filtering added.
"""
# Add filtering
if CONF.allow_filtering:
filter_attrs = [attr for attr in dir(
model.__v2_wsme__
) if not callable(
getattr(model.__v2_wsme__, attr)
) and not attr.startswith("_")]
self.filters = {k: v for (k, v) in self.params.items()
if k in filter_attrs}
query = model.apply_filter(query, model, self.filters)
# Add sorting
if CONF.allow_sorting:
# Add default sort keys (if they are OK for the model)
@ -263,7 +277,9 @@ class PaginationHelper(object):
query = query.limit(self.limit)
model_list = query.all()
links = None
if CONF.allow_pagination:
links = self._make_links(model_list)
return model_list, links

View File

@ -44,7 +44,7 @@ class BaseController(rest.RestController):
"""Converts a data model into an Octavia WSME type
:param db_entity: data model to convert
:param to_type: converts db_entity to this time
:param to_type: converts db_entity to this type
"""
if isinstance(to_type, list):
to_type = to_type[0]
@ -174,3 +174,18 @@ class BaseController(rest.RestController):
rbac_obj=self.RBAC_TYPE, action=action)
target = {'project_id': project_id}
context.policy.authorize(action, target)
def _filter_fields(self, object_list, fields):
if CONF.allow_field_selection:
for index, obj in enumerate(object_list):
members = self._get_attrs(obj)
for member in members:
if member not in fields:
delattr(object_list[index], member)
return object_list
@staticmethod
def _get_attrs(obj):
attrs = [attr for attr in dir(obj) if not callable(
getattr(obj, attr)) and not attr.startswith("_")]
return attrs

View File

@ -68,8 +68,8 @@ class HealthMonitorController(base.BaseController):
return hm_types.HealthMonitorRootResponse(healthmonitor=result)
@wsme_pecan.wsexpose(hm_types.HealthMonitorsRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None):
[wtypes.text], ignore_extra_args=True)
def get_all(self, project_id=None, fields=None):
"""Gets all health monitors."""
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
@ -82,6 +82,8 @@ class HealthMonitorController(base.BaseController):
**query_filter)
result = self._convert_db_to_type(
db_hm, [hm_types.HealthMonitorResponse])
if fields is not None:
result = self._filter_fields(result, fields)
return hm_types.HealthMonitorsRootResponse(
healthmonitors=result, healthmonitors_links=links)

View File

@ -57,8 +57,8 @@ class L7PolicyController(base.BaseController):
return l7policy_types.L7PolicyRootResponse(l7policy=result)
@wsme_pecan.wsexpose(l7policy_types.L7PoliciesRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None):
[wtypes.text], ignore_extra_args=True)
def get_all(self, project_id=None, fields=None):
"""Lists all l7policies of a listener."""
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
@ -71,6 +71,8 @@ class L7PolicyController(base.BaseController):
**query_filter)
result = self._convert_db_to_type(
db_l7policies, [l7policy_types.L7PolicyResponse])
if fields is not None:
result = self._filter_fields(result, fields)
return l7policy_types.L7PoliciesRootResponse(
l7policies=result, l7policies_links=links)

View File

@ -55,8 +55,8 @@ class L7RuleController(base.BaseController):
return l7rule_types.L7RuleRootResponse(rule=result)
@wsme_pecan.wsexpose(l7rule_types.L7RulesRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self):
[wtypes.text], ignore_extra_args=True)
def get_all(self, fields=None):
"""Lists all l7rules of a l7policy."""
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
@ -74,6 +74,8 @@ class L7RuleController(base.BaseController):
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
result = self._convert_db_to_type(
db_l7rules, [l7rule_types.L7RuleResponse])
if fields is not None:
result = self._filter_fields(result, fields)
return l7rule_types.L7RulesRootResponse(
rules=result, rules_links=links)

View File

@ -70,8 +70,8 @@ class ListenersController(base.BaseController):
return listener_types.ListenerRootResponse(listener=result)
@wsme_pecan.wsexpose(listener_types.ListenersRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None):
[wtypes.text], ignore_extra_args=True)
def get_all(self, project_id=None, fields=None):
"""Lists all listeners."""
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
@ -84,6 +84,8 @@ class ListenersController(base.BaseController):
**query_filter)
result = self._convert_db_to_type(
db_listeners, [listener_types.ListenerResponse])
if fields is not None:
result = self._filter_fields(result, fields)
return listener_types.ListenersRootResponse(
listeners=result, listeners_links=links)

View File

@ -61,8 +61,8 @@ class LoadBalancersController(base.BaseController):
return lb_types.LoadBalancerRootResponse(loadbalancer=result)
@wsme_pecan.wsexpose(lb_types.LoadBalancersRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None):
[wtypes.text], ignore_extra_args=True)
def get_all(self, project_id=None, fields=None):
"""Lists all load balancers."""
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
@ -75,6 +75,8 @@ class LoadBalancersController(base.BaseController):
**query_filter)
result = self._convert_db_to_type(
load_balancers, [lb_types.LoadBalancerResponse])
if fields is not None:
result = self._filter_fields(result, fields)
return lb_types.LoadBalancersRootResponse(
loadbalancers=result, loadbalancers_links=links)

View File

@ -56,8 +56,8 @@ class MembersController(base.BaseController):
return member_types.MemberRootResponse(member=result)
@wsme_pecan.wsexpose(member_types.MembersRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self):
[wtypes.text], ignore_extra_args=True)
def get_all(self, fields=None):
"""Lists all pool members of a pool."""
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
@ -76,6 +76,8 @@ class MembersController(base.BaseController):
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
result = self._convert_db_to_type(
db_members, [member_types.MemberResponse])
if fields is not None:
result = self._filter_fields(result, fields)
return member_types.MembersRootResponse(
members=result, members_links=links)

View File

@ -58,8 +58,8 @@ class PoolsController(base.BaseController):
return pool_types.PoolRootResponse(pool=result)
@wsme_pecan.wsexpose(pool_types.PoolsRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None):
[wtypes.text], ignore_extra_args=True)
def get_all(self, project_id=None, fields=None):
"""Lists all pools."""
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
@ -71,6 +71,8 @@ class PoolsController(base.BaseController):
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter)
result = self._convert_db_to_type(db_pools, [pool_types.PoolResponse])
if fields is not None:
result = self._filter_fields(result, fields)
return pool_types.PoolsRootResponse(pools=result, pools_links=links)
def _get_affected_listener_ids(self, pool):

View File

@ -22,6 +22,7 @@ class BaseHealthMonitorType(types.BaseType):
_type_to_model_map = {'admin_state_up': 'enabled',
'max_retries': 'rise_threshold',
'max_retries_down': 'fall_threshold'}
_child_map = {}
class HealthMonitorResponse(BaseHealthMonitorType):

View File

@ -22,6 +22,7 @@ from octavia.common import constants
class BaseL7PolicyType(types.BaseType):
_type_to_model_map = {'admin_state_up': 'enabled'}
_child_map = {}
class L7PolicyResponse(BaseL7PolicyType):

View File

@ -20,6 +20,7 @@ from octavia.common import constants
class BaseL7Type(types.BaseType):
_type_to_model_map = {'admin_state_up': 'enabled'}
_child_map = {}
class L7RuleResponse(BaseL7Type):

View File

@ -23,6 +23,7 @@ from octavia.common import constants
class BaseListenerType(types.BaseType):
_type_to_model_map = {'admin_state_up': 'enabled',
'default_tls_container_ref': 'tls_certificate_id'}
_child_map = {}
class ListenerResponse(BaseListenerType):

View File

@ -25,6 +25,11 @@ class BaseLoadBalancerType(types.BaseType):
'vip_port_id': 'vip.port_id',
'vip_network_id': 'vip.network_id',
'admin_state_up': 'enabled'}
_child_map = {'vip': {
'ip_address': 'vip_address',
'subnet_id': 'vip_subnet_id',
'port_id': 'vip_port_id',
'network_id': 'vip_network_id'}}
class LoadBalancerResponse(BaseLoadBalancerType):

View File

@ -21,6 +21,7 @@ from octavia.common import constants
class BaseMemberType(types.BaseType):
_type_to_model_map = {'admin_state_up': 'enabled',
'address': 'ip_address'}
_child_map = {}
class MemberResponse(BaseMemberType):

View File

@ -42,6 +42,7 @@ class SessionPersistencePUT(types.BaseType):
class BasePoolType(types.BaseType):
_type_to_model_map = {'admin_state_up': 'enabled',
'healthmonitor': 'health_monitor'}
_child_map = {}
class PoolResponse(BasePoolType):

View File

@ -53,6 +53,9 @@ class QuotaAllBase(base.BaseType):
pool = wtypes.wsattr(wtypes.IntegerType())
health_monitor = wtypes.wsattr(wtypes.IntegerType())
_type_to_model_map = {}
_child_map = {}
@classmethod
def from_data_model(cls, data_model, children=False):
quotas = super(QuotaAllBase, cls).from_data_model(

View File

@ -45,6 +45,10 @@ core_opts = [
help=_("Allow the usage of the pagination")),
cfg.BoolOpt('allow_sorting', default=True,
help=_("Allow the usage of the sorting")),
cfg.BoolOpt('allow_filtering', default=True,
help=_("Allow the usage of filtering")),
cfg.BoolOpt('allow_field_selection', default=True,
help=_("Allow the usage of field selection")),
cfg.StrOpt('pagination_max_limit',
default=str(constants.DEFAULT_PAGE_SIZE),
help=_("The maximum number of items returned in a single "

View File

@ -13,6 +13,7 @@
# under the License.
from oslo_db.sqlalchemy import models
from oslo_utils import strutils
from oslo_utils import uuidutils
import sqlalchemy as sa
from sqlalchemy.ext import declarative
@ -102,6 +103,28 @@ class OctaviaBase(models.ModelBase):
listref.append(item)
return dm_self
@staticmethod
def apply_filter(query, model, filters):
translated_filters = {}
child_map = {}
# Convert admin_state_up to proper type
if 'admin_state_up' in filters:
filters['admin_state_up'] = strutils.bool_from_string(
filters['admin_state_up'])
for attr, name_map in model.__v2_wsme__._child_map.items():
for k, v in name_map.items():
if v in filters:
child_map[attr] = {k: filters.pop(v)}
for k, v in model.__v2_wsme__._type_to_model_map.items():
if k in filters:
translated_filters[v] = filters.pop(k)
translated_filters.update(filters)
if translated_filters:
query = query.filter_by(**translated_filters)
for k, v in child_map.items():
query = query.join(getattr(model, k)).filter_by(**v)
return query
class LookupTableMixin(object):
"""Mixin to add to classes that are lookup tables."""

View File

@ -20,6 +20,14 @@ from sqlalchemy import orm
from sqlalchemy.orm import validates
from sqlalchemy.sql import func
from octavia.api.v2.types import health_monitor
from octavia.api.v2.types import l7policy
from octavia.api.v2.types import l7rule
from octavia.api.v2.types import listener
from octavia.api.v2.types import load_balancer
from octavia.api.v2.types import member
from octavia.api.v2.types import pool
from octavia.api.v2.types import quotas
from octavia.common import data_models
from octavia.db import base_models
@ -164,6 +172,9 @@ class Member(base_models.BASE, base_models.IdMixin, base_models.ProjectMixin,
__data_model__ = data_models.Member
__tablename__ = "member"
__v2_wsme__ = member.MemberResponse
__table_args__ = (
sa.UniqueConstraint('pool_id', 'ip_address', 'protocol_port',
name='uq_member_pool_id_address_protocol_port'),
@ -203,6 +214,8 @@ class HealthMonitor(base_models.BASE, base_models.IdMixin,
__tablename__ = "health_monitor"
__v2_wsme__ = health_monitor.HealthMonitorResponse
type = sa.Column(
sa.String(36),
sa.ForeignKey("health_monitor_type.name",
@ -243,6 +256,8 @@ class Pool(base_models.BASE, base_models.IdMixin, base_models.ProjectMixin,
__tablename__ = "pool"
__v2_wsme__ = pool.PoolResponse
description = sa.Column(sa.String(255), nullable=True)
protocol = sa.Column(
sa.String(16),
@ -298,6 +313,8 @@ class LoadBalancer(base_models.BASE, base_models.IdMixin,
__tablename__ = "load_balancer"
__v2_wsme__ = load_balancer.LoadBalancerResponse
description = sa.Column(sa.String(255), nullable=True)
provisioning_status = sa.Column(
sa.String(16),
@ -371,6 +388,9 @@ class Listener(base_models.BASE, base_models.IdMixin,
__data_model__ = data_models.Listener
__tablename__ = "listener"
__v2_wsme__ = listener.ListenerResponse
__table_args__ = (
sa.UniqueConstraint('load_balancer_id', 'protocol_port',
name='uq_listener_load_balancer_id_protocol_port'),
@ -508,6 +528,8 @@ class L7Rule(base_models.BASE, base_models.IdMixin, base_models.ProjectMixin,
__tablename__ = "l7rule"
__v2_wsme__ = l7rule.L7RuleResponse
l7policy_id = sa.Column(
sa.String(36),
sa.ForeignKey("l7policy.id", name="fk_l7rule_l7policy_id"),
@ -551,6 +573,8 @@ class L7Policy(base_models.BASE, base_models.IdMixin, base_models.ProjectMixin,
__tablename__ = "l7policy"
__v2_wsme__ = l7policy.L7PolicyResponse
description = sa.Column(sa.String(255), nullable=True)
listener_id = sa.Column(
sa.String(36),
@ -601,6 +625,8 @@ class Quotas(base_models.BASE):
__tablename__ = "quotas"
__v2_wsme__ = quotas.QuotaAllBase
project_id = sa.Column(sa.String(36), primary_key=True)
health_monitor = sa.Column(sa.Integer(), nullable=True)
listener = sa.Column(sa.Integer(), nullable=True)

View File

@ -475,6 +475,83 @@ class TestHealthMonitor(base.BaseAPITest):
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_all_fields_filter(self):
pool1 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1').get('pool')
self.set_lb_status(self.lb_id)
pool2 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2').get('pool')
self.set_lb_status(self.lb_id)
pool3 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3').get('pool')
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool1.get('id'), constants.HEALTH_MONITOR_HTTP,
1, 1, 1, 1, name='hm1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool2.get('id'), constants.HEALTH_MONITOR_PING,
1, 1, 1, 1, name='hm2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool3.get('id'), constants.HEALTH_MONITOR_TCP,
1, 1, 1, 1, name='hm3').get(self.root_tag)
self.set_lb_status(self.lb_id)
hms = self.get(self.HMS_PATH, params={
'fields': ['id', 'project_id']}).json
for hm in hms['healthmonitors']:
self.assertIn(u'id', hm.keys())
self.assertIn(u'project_id', hm.keys())
self.assertNotIn(u'description', hm.keys())
def test_get_all_filter(self):
pool1 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1').get('pool')
self.set_lb_status(self.lb_id)
pool2 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2').get('pool')
self.set_lb_status(self.lb_id)
pool3 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3').get('pool')
self.set_lb_status(self.lb_id)
hm1 = self.create_health_monitor(
pool1.get('id'), constants.HEALTH_MONITOR_HTTP,
1, 1, 1, 1, name='hm1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool2.get('id'), constants.HEALTH_MONITOR_PING,
1, 1, 1, 1, name='hm2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool3.get('id'), constants.HEALTH_MONITOR_TCP,
1, 1, 1, 1, name='hm3').get(self.root_tag)
self.set_lb_status(self.lb_id)
hms = self.get(self.HMS_PATH, params={
'id': hm1['id']}).json
self.assertEqual(1, len(hms['healthmonitors']))
self.assertEqual(hm1['id'],
hms['healthmonitors'][0]['id'])
def test_empty_get_all(self):
response = self.get(self.HMS_PATH).json.get(self.root_tag_list)
self.assertIsInstance(response, list)

View File

@ -427,6 +427,51 @@ class TestL7Policy(base.BaseAPITest):
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_all_fields_filter(self):
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REJECT,
name='policy1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_POOL,
position=2, redirect_pool_id=self.pool_id,
name='policy2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_URL,
redirect_url='http://localhost/',
name='policy3').get(self.root_tag)
self.set_lb_status(self.lb_id)
l7pos = self.get(self.L7POLICIES_PATH, params={
'fields': ['id', 'project_id']}).json
for l7po in l7pos['l7policies']:
self.assertIn(u'id', l7po.keys())
self.assertIn(u'project_id', l7po.keys())
self.assertNotIn(u'description', l7po.keys())
def test_get_all_filter(self):
policy1 = self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REJECT,
name='policy1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_POOL,
position=2, redirect_pool_id=self.pool_id,
name='policy2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_URL,
redirect_url='http://localhost/',
name='policy3').get(self.root_tag)
self.set_lb_status(self.lb_id)
l7pos = self.get(self.L7POLICIES_PATH, params={
'id': policy1['id']}).json
self.assertEqual(1, len(l7pos['l7policies']))
self.assertEqual(policy1['id'],
l7pos['l7policies'][0]['id'])
def test_empty_get_all(self):
response = self.get(self.L7POLICIES_PATH).json.get(self.root_tag_list)
self.assertIsInstance(response, list)

View File

@ -291,6 +291,54 @@ class TestL7Rule(base.BaseAPITest):
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_all_fields_filter(self):
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_PATH,
constants.L7RULE_COMPARE_TYPE_STARTS_WITH,
'/api').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_COOKIE,
constants.L7RULE_COMPARE_TYPE_CONTAINS, 'some-value',
key='some-cookie').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_HOST_NAME,
constants.L7RULE_COMPARE_TYPE_EQUAL_TO,
'www.example.com').get(self.root_tag)
self.set_lb_status(self.lb_id)
l7rus = self.get(self.l7rules_path, params={
'fields': ['id', 'project_id']}).json
for l7ru in l7rus['rules']:
self.assertIn(u'id', l7ru.keys())
self.assertIn(u'project_id', l7ru.keys())
self.assertNotIn(u'description', l7ru.keys())
def test_get_all_filter(self):
ru1 = self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_PATH,
constants.L7RULE_COMPARE_TYPE_STARTS_WITH,
'/api').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_COOKIE,
constants.L7RULE_COMPARE_TYPE_CONTAINS, 'some-value',
key='some-cookie').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_HOST_NAME,
constants.L7RULE_COMPARE_TYPE_EQUAL_TO,
'www.example.com').get(self.root_tag)
self.set_lb_status(self.lb_id)
l7rus = self.get(self.l7rules_path, params={
'id': ru1['id']}).json
self.assertEqual(1, len(l7rus['rules']))
self.assertEqual(ru1['id'],
l7rus['rules'][0]['id'])
def test_empty_get_all(self):
response = self.get(self.l7rules_path).json.get(self.root_tag_list)
self.assertIsInstance(response, list)

View File

@ -303,6 +303,50 @@ class TestListener(base.BaseAPITest):
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_all_fields_filter(self):
self.create_listener(constants.PROTOCOL_HTTP, 80,
self.lb_id,
name='listener1')
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP, 81,
self.lb_id,
name='listener2')
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP, 82,
self.lb_id,
name='listener3')
self.set_lb_status(self.lb_id)
lis = self.get(self.LISTENERS_PATH, params={
'fields': ['id', 'project_id']}).json
for li in lis['listeners']:
self.assertIn(u'id', li.keys())
self.assertIn(u'project_id', li.keys())
self.assertNotIn(u'description', li.keys())
def test_get_all_filter(self):
li1 = self.create_listener(constants.PROTOCOL_HTTP,
80,
self.lb_id,
name='listener1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP,
81,
self.lb_id,
name='listener2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP,
82,
self.lb_id,
name='listener3').get(self.root_tag)
self.set_lb_status(self.lb_id)
lis = self.get(self.LISTENERS_PATH, params={
'id': li1['id']}).json
self.assertEqual(1, len(lis['listeners']))
self.assertEqual(li1['id'],
lis['listeners'][0]['id'])
def test_get(self):
listener = self.create_listener(
constants.PROTOCOL_HTTP, 80, self.lb_id).get(self.root_tag)

View File

@ -608,6 +608,43 @@ class TestLoadBalancer(base.BaseAPITest):
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_all_fields_filter(self):
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb1',
project_id=self.project_id)
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb2',
project_id=self.project_id)
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb3',
project_id=self.project_id)
lbs = self.get(self.LBS_PATH, params={
'fields': ['id', 'project_id']}).json
for lb in lbs['loadbalancers']:
self.assertIn(u'id', lb.keys())
self.assertIn(u'project_id', lb.keys())
self.assertNotIn(u'description', lb.keys())
def test_get_all_filter(self):
lb1 = self.create_load_balancer(
uuidutils.generate_uuid(),
name='lb1',
project_id=self.project_id).get(self.root_tag)
self.create_load_balancer(
uuidutils.generate_uuid(),
name='lb2',
project_id=self.project_id).get(self.root_tag)
self.create_load_balancer(
uuidutils.generate_uuid(),
name='lb3',
project_id=self.project_id).get(self.root_tag)
lbs = self.get(self.LBS_PATH, params={
'id': lb1['id']}).json
self.assertEqual(1, len(lbs['loadbalancers']))
self.assertEqual(lb1['id'],
lbs['loadbalancers'][0]['id'])
def test_get(self):
project_id = uuidutils.generate_uuid()
subnet = network_models.Subnet(id=uuidutils.generate_uuid())

View File

@ -280,6 +280,44 @@ class TestMember(base.BaseAPITest):
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_all_fields_filter(self):
self.create_member(self.pool_id, '10.0.0.1', 80, name='member1')
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id, '10.0.0.2', 80, name='member2')
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id, '10.0.0.3', 80, name='member3')
self.set_lb_status(self.lb_id)
members = self.get(self.members_path, params={
'fields': ['id', 'project_id']}).json
for member in members['members']:
self.assertIn(u'id', member.keys())
self.assertIn(u'project_id', member.keys())
self.assertNotIn(u'description', member.keys())
def test_get_all_filter(self):
mem1 = self.create_member(self.pool_id,
'10.0.0.1',
80,
name='member1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id,
'10.0.0.2',
80,
name='member2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id,
'10.0.0.3',
80,
name='member3').get(self.root_tag)
self.set_lb_status(self.lb_id)
members = self.get(self.members_path, params={
'id': mem1['id']}).json
self.assertEqual(1, len(members['members']))
self.assertEqual(mem1['id'],
members['members'][0]['id'])
def test_empty_get_all(self):
response = self.get(self.members_path).json.get(self.root_tag_list)
self.assertIsInstance(response, list)

View File

@ -450,6 +450,59 @@ class TestPool(base.BaseAPITest):
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_all_fields_filter(self):
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1')
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2')
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3')
self.set_lb_status(lb_id=self.lb_id)
pools = self.get(self.POOLS_PATH, params={
'fields': ['id', 'project_id']}).json
for pool in pools['pools']:
self.assertIn(u'id', pool.keys())
self.assertIn(u'project_id', pool.keys())
self.assertNotIn(u'description', pool.keys())
def test_get_all_filter(self):
po1 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1').get(self.root_tag)
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2').get(self.root_tag)
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3').get(self.root_tag)
self.set_lb_status(lb_id=self.lb_id)
pools = self.get(self.POOLS_PATH, params={
'id': po1['id']}).json
self.assertEqual(1, len(pools['pools']))
self.assertEqual(po1['id'],
pools['pools'][0]['id'])
def test_empty_get_all(self):
response = self.get(self.POOLS_PATH).json.get(self.root_tag_list)
self.assertIsInstance(response, list)

View File

@ -38,7 +38,8 @@ class TestPaginationHelper(base.TestCase):
self.assertEqual(DEFAULT_SORTS, helper.sort_keys)
self.assertIsNone(helper.marker)
self.assertEqual(1000, helper.limit)
query_mock.order_by().order_by().limit.assert_called_with(1000)
query_mock.order_by().order_by().limit.assert_called_with(
1000)
def test_sort_empty(self):
sort_params = ""
@ -90,7 +91,36 @@ class TestPaginationHelper(base.TestCase):
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer)
query_mock.order_by().order_by().limit.assert_called_with(limit)
query_mock.order_by().order_by().limit.assert_called_with(
limit)
@mock.patch('octavia.api.common.pagination.request')
def test_filter_correct_params(self, request_mock):
params = {'id': 'fake_id'}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer)
self.assertEqual(params, helper.filters)
@mock.patch('octavia.api.common.pagination.request')
def test_filter_mismatched_params(self, request_mock):
params = {'id': 'fake_id', 'fields': 'id'}
filters = {'id': 'fake_id'}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer)
self.assertEqual(filters, helper.filters)
@mock.patch('octavia.api.common.pagination.request')
def test_fields_not_passed(self, request_mock):
params = {'fields': 'id'}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer)
self.assertEqual({}, helper.filters)
@mock.patch('octavia.api.common.pagination.request')
def test_make_links_next(self, request_mock):